diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 8c680b611..a2a7171b7 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -53,10 +53,10 @@ def __init__(self, agent: BaseAgent, skip_summarization: bool = False): super().__init__(name=agent.name, description=agent.description) - @model_validator(mode='before') + @model_validator(mode="before") @classmethod def populate_name(cls, data: Any) -> Any: - data['name'] = data['agent'].name + data["name"] = data["agent"].name return data @override @@ -73,11 +73,11 @@ def _get_declaration(self) -> types.FunctionDeclaration: parameters=types.Schema( type=types.Type.OBJECT, properties={ - 'request': types.Schema( + "request": types.Schema( type=types.Type.STRING, ), }, - required=['request'], + required=["request"], ), description=self.agent.description, name=self.name, @@ -105,7 +105,6 @@ async def run_async( ) -> Any: from ..agents.llm_agent import LlmAgent from ..runners import Runner - from ..sessions.in_memory_session_service import InMemorySessionService if self.skip_summarization: tool_context.actions.skip_summarization = True @@ -113,7 +112,7 @@ async def run_async( if isinstance(self.agent, LlmAgent) and self.agent.input_schema: input_value = self.agent.input_schema.model_validate(args) content = types.Content( - role='user', + role="user", parts=[ types.Part.from_text( text=input_value.model_dump_json(exclude_none=True) @@ -122,15 +121,15 @@ async def run_async( ) else: content = types.Content( - role='user', - parts=[types.Part.from_text(text=args['request'])], + role="user", + parts=[types.Part.from_text(text=args["request"])], ) runner = Runner( app_name=self.agent.name, agent=self.agent, artifact_service=ForwardingArtifactService(tool_context), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), + session_service=tool_context._invocation_context.session_service, + memory_service=tool_context._invocation_context.memory_service, credential_service=tool_context._invocation_context.credential_service, plugins=list(tool_context._invocation_context.plugin_manager.plugins), ) @@ -154,8 +153,8 @@ async def run_async( last_content = event.content if not last_content: - return '' - merged_text = '\n'.join(p.text for p in last_content.parts if p.text) + return "" + merged_text = "\n".join(p.text for p in last_content.parts if p.text) if isinstance(self.agent, LlmAgent) and self.agent.output_schema: tool_result = self.agent.output_schema.model_validate_json( merged_text