Skip to content

feat(mcp): Add list_prompts, get_prompt methods #160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from mcp import ClientSession, ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import GetPromptResult, ListPromptsResult
from mcp.types import ImageContent as MCPImageContent
from mcp.types import TextContent as MCPTextContent

Expand Down Expand Up @@ -165,6 +166,54 @@ async def _list_tools_async() -> ListToolsResult:
self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools))
return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor)

def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult:
"""Synchronously retrieves the list of available prompts from the MCP server.

This method calls the asynchronous list_prompts method on the MCP session
and returns the raw ListPromptsResult with pagination support.

Args:
pagination_token: Optional token for pagination

Returns:
ListPromptsResult: The raw MCP response containing prompts and pagination info
"""
self._log_debug_with_thread("listing MCP prompts synchronously")
if not self._is_session_active():
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _list_prompts_async() -> ListPromptsResult:
return await self._background_thread_session.list_prompts(cursor=pagination_token)

list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
for prompt in list_prompts_result.prompts:
self._log_debug_with_thread(prompt.name)

return list_prompts_result

def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult:
"""Synchronously retrieves a prompt from the MCP server.

Args:
prompt_id: The ID of the prompt to retrieve
args: Optional arguments to pass to the prompt

Returns:
GetPromptResult: The prompt response from the MCP server
"""
self._log_debug_with_thread("getting MCP prompt synchronously")
if not self._is_session_active():
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _get_prompt_async() -> GetPromptResult:
return await self._background_thread_session.get_prompt(prompt_id, arguments=args)

get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
self._log_debug_with_thread("received prompt from MCP server")

return get_prompt_result

def call_tool_sync(
self,
tool_use_id: str,
Expand Down
62 changes: 62 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from mcp import ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage
from mcp.types import TextContent as MCPTextContent
from mcp.types import Tool as MCPTool

Expand Down Expand Up @@ -404,3 +405,64 @@ def test_exception_when_future_not_running():

# Verify that set_exception was not called since the future was not running
mock_future.set_exception.assert_not_called()


# Prompt Tests - Sync Methods


def test_list_prompts_sync(mock_transport, mock_session):
"""Test that list_prompts_sync correctly retrieves prompts."""
mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt])

with MCPClient(mock_transport["transport_callable"]) as client:
result = client.list_prompts_sync()

mock_session.list_prompts.assert_called_once_with(cursor=None)
assert len(result.prompts) == 1
assert result.prompts[0].name == "test_prompt"
assert result.nextCursor is None


def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session):
"""Test that list_prompts_sync correctly passes pagination token and returns next cursor."""
mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token")

with MCPClient(mock_transport["transport_callable"]) as client:
result = client.list_prompts_sync(pagination_token="current_page_token")

mock_session.list_prompts.assert_called_once_with(cursor="current_page_token")
assert len(result.prompts) == 1
assert result.prompts[0].name == "test_prompt"
assert result.nextCursor == "next_page_token"


def test_list_prompts_sync_session_not_active():
"""Test that list_prompts_sync raises an error when session is not active."""
client = MCPClient(MagicMock())

with pytest.raises(MCPClientInitializationError, match="client session is not running"):
client.list_prompts_sync()


def test_get_prompt_sync(mock_transport, mock_session):
"""Test that get_prompt_sync correctly retrieves a prompt."""
mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt"))
mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message])

with MCPClient(mock_transport["transport_callable"]) as client:
result = client.get_prompt_sync("test_prompt_id", {"key": "value"})

mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"})
assert len(result.messages) == 1
assert result.messages[0].role == "user"
assert result.messages[0].content.text == "This is a test prompt"


def test_get_prompt_sync_session_not_active():
"""Test that get_prompt_sync raises an error when session is not active."""
client = MCPClient(MagicMock())

with pytest.raises(MCPClientInitializationError, match="client session is not running"):
client.get_prompt_sync("test_prompt_id", {})
83 changes: 73 additions & 10 deletions tests_integ/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
from strands.types.tools import ToolUse


def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int):
def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port=int):
"""
Initialize and start an MCP calculator server for integration testing.
Initialize and start a comprehensive MCP server for integration testing.

This function creates a FastMCP server instance that provides a simple
calculator tool for performing addition operations. The server uses
Server-Sent Events (SSE) transport for communication, making it accessible
over HTTP.
This function creates a FastMCP server instance that provides tools, prompts,
and resources all in one server for comprehensive testing. The server uses
Server-Sent Events (SSE) or streamable HTTP transport for communication.
"""
from mcp.server import FastMCP

mcp = FastMCP("Calculator Server", port=port)
mcp = FastMCP("Comprehensive MCP Server", port=port)

@mcp.tool(description="Calculator tool which performs calculations")
def calculator(x: int, y: int) -> int:
Expand All @@ -44,6 +43,15 @@ def generate_custom_image() -> MCPImageContent:
except Exception as e:
print("Error while generating custom image: {}".format(e))

# Prompts
@mcp.prompt(description="A greeting prompt template")
def greeting_prompt(name: str = "World") -> str:
return f"Hello, {name}! How are you today?"

@mcp.prompt(description="A math problem prompt template")
def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str:
return f"Create a {difficulty} {operation} math problem and solve it step by step."

mcp.run(transport=transport)


Expand All @@ -58,8 +66,9 @@ def test_mcp_client():
{'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]}
""" # noqa: E501

# Start comprehensive server with tools, prompts, and resources
server_thread = threading.Thread(
target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
)
server_thread.start()
time.sleep(2) # wait for server to startup completely
Expand All @@ -68,8 +77,14 @@ def test_mcp_client():
stdio_mcp_client = MCPClient(
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
)

with sse_mcp_client, stdio_mcp_client:
agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
# Test Tools functionality
sse_tools = sse_mcp_client.list_tools_sync()
stdio_tools = stdio_mcp_client.list_tools_sync()
all_tools = sse_tools + stdio_tools

agent = Agent(tools=all_tools)
agent("add 1 and 2, then echo the result back to me")

tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
Expand All @@ -88,6 +103,43 @@ def test_mcp_client():
]
)

# Test Prompts functionality
prompts_result = sse_mcp_client.list_prompts_sync()
assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts

prompt_names = [prompt.name for prompt in prompts_result.prompts]
assert "greeting_prompt" in prompt_names
assert "math_prompt" in prompt_names

# Test get_prompt_sync with greeting prompt
greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"})
assert len(greeting_result.messages) > 0
prompt_text = greeting_result.messages[0].content.text
assert "Hello, Alice!" in prompt_text
assert "How are you today?" in prompt_text

# Test get_prompt_sync with math prompt
math_result = sse_mcp_client.get_prompt_sync(
"math_prompt", {"operation": "multiplication", "difficulty": "medium"}
)
assert len(math_result.messages) > 0
math_text = math_result.messages[0].content.text
assert "multiplication" in math_text
assert "medium" in math_text
assert "step by step" in math_text

# Test pagination support for prompts
prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None)
assert len(prompts_with_token.prompts) >= 0

# Test pagination support for tools (existing functionality)
tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None)
assert len(tools_with_token) >= 0

# TODO: Add resources testing when resources are implemented
# resources_result = sse_mcp_client.list_resources_sync()
# assert len(resources_result.resources) >= 0

tool_use_id = "test-structured-content-123"
result = stdio_mcp_client.call_tool_sync(
tool_use_id=tool_use_id,
Expand Down Expand Up @@ -185,8 +237,9 @@ def test_mcp_client_without_structured_content():
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
)
def test_streamable_http_mcp_client():
"""Test comprehensive MCP client with streamable HTTP transport."""
server_thread = threading.Thread(
target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
)
server_thread.start()
time.sleep(2) # wait for server to startup completely
Expand All @@ -196,12 +249,22 @@ def transport_callback() -> MCPTransport:

streamable_http_client = MCPClient(transport_callback)
with streamable_http_client:
# Test tools
agent = Agent(tools=streamable_http_client.list_tools_sync())
agent("add 1 and 2 using a calculator")

tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
assert any([block["name"] == "calculator" for block in tool_use_content_blocks])

# Test prompts
prompts_result = streamable_http_client.list_prompts_sync()
assert len(prompts_result.prompts) >= 2

greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"})
assert len(greeting_result.messages) > 0
prompt_text = greeting_result.messages[0].content.text
assert "Hello, Charlie!" in prompt_text


def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]
Loading