Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def __init__(self, agent: BaseAgent, skip_summarization: bool = False):

super().__init__(name=agent.name, description=agent.description)

@model_validator(mode='before')
@model_validator(mode="before")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might want to revert the style changes (unless they are coming from ruff and not your personal settings :)

@classmethod
def populate_name(cls, data: Any) -> Any:
data['name'] = data['agent'].name
data["name"] = data["agent"].name
return data

@override
Expand All @@ -73,11 +73,11 @@ def _get_declaration(self) -> types.FunctionDeclaration:
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'request': types.Schema(
"request": types.Schema(
type=types.Type.STRING,
),
},
required=['request'],
required=["request"],
),
description=self.agent.description,
name=self.name,
Expand Down Expand Up @@ -105,15 +105,14 @@ async def run_async(
) -> Any:
from ..agents.llm_agent import LlmAgent
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService

if self.skip_summarization:
tool_context.actions.skip_summarization = True

if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
input_value = self.agent.input_schema.model_validate(args)
content = types.Content(
role='user',
role="user",
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
Expand All @@ -122,15 +121,15 @@ async def run_async(
)
else:
content = types.Content(
role='user',
parts=[types.Part.from_text(text=args['request'])],
role="user",
parts=[types.Part.from_text(text=args["request"])],
)
runner = Runner(
app_name=self.agent.name,
agent=self.agent,
artifact_service=ForwardingArtifactService(tool_context),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
session_service=tool_context._invocation_context.session_service,
memory_service=tool_context._invocation_context.memory_service,
Comment on lines +131 to +132

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this change correctly sources the services from the invocation context, it relies on accessing the protected member _invocation_context. This pattern appears elsewhere in the codebase, but it breaks encapsulation and makes the public API of ToolContext less clear.

For better maintainability and a cleaner design, consider exposing these services as public properties on ToolContext. This would be a good opportunity for future refactoring.

Example of what could be added to ToolContext:

@property
def session_service(self) -> BaseSessionService:
    return self._invocation_context.session_service

@property
def memory_service(self) -> Optional[BaseMemoryService]:
    return self._invocation_context.memory_service

With this change, the code here would become tool_context.session_service and tool_context.memory_service, which is much cleaner.

credential_service=tool_context._invocation_context.credential_service,
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
)
Expand All @@ -154,8 +153,8 @@ async def run_async(
last_content = event.content

if not last_content:
return ''
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
return ""
merged_text = "\n".join(p.text for p in last_content.parts if p.text)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
merged_text
Expand Down