
Commit 2684b38

Merge branch 'main' into main
2 parents: a9c5d3d + 4e0e0a6

13 files changed: +1651 lines, -64 lines

pyproject.toml

Lines changed: 13 additions & 7 deletions
@@ -29,7 +29,7 @@ dependencies = [
     "boto3>=1.26.0,<2.0.0",
     "botocore>=1.29.0,<2.0.0",
     "docstring_parser>=0.15,<1.0",
-    "mcp>=1.8.0,<2.0.0",
+    "mcp>=1.11.0,<2.0.0",
     "pydantic>=2.0.0,<3.0.0",
     "typing-extensions>=4.13.2,<5.0.0",
     "watchdog>=6.0.0,<7.0.0",
@@ -89,8 +89,14 @@ writer = [
     "writer-sdk>=2.2.0,<3.0.0"
 ]
 
+sagemaker = [
+    "boto3>=1.26.0,<2.0.0",
+    "botocore>=1.29.0,<2.0.0",
+    "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0"
+]
+
 a2a = [
-    "a2a-sdk[sql]>=0.2.16,<1.0.0",
+    "a2a-sdk[sql]>=0.2.11,<1.0.0",
     "uvicorn>=0.34.2,<1.0.0",
     "httpx>=0.28.1,<1.0.0",
     "fastapi>=0.115.12,<1.0.0",
@@ -136,7 +142,7 @@ all = [
     "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
 
     # a2a
-    "a2a-sdk[sql]>=0.2.16,<1.0.0",
+    "a2a-sdk[sql]>=0.2.11,<1.0.0",
     "uvicorn>=0.34.2,<1.0.0",
     "httpx>=0.28.1,<1.0.0",
     "fastapi>=0.115.12,<1.0.0",
@@ -148,7 +154,7 @@ all = [
 source = "vcs"
 
 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -171,7 +177,7 @@ lint-fix = [
 ]
 
 [tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",
@@ -187,7 +193,7 @@ extra-args = [
 
 [tool.hatch.envs.dev]
 dev-mode = true
-features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"]
+features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"]
 
 [[tool.hatch.envs.hatch-test.matrix]]
 python = ["3.13", "3.12", "3.11", "3.10"]
@@ -315,4 +321,4 @@ style = [
     ["instruction", ""],
     ["text", ""],
     ["disabled", "fg:#858585 italic"]
-]
+]
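The new sagemaker optional-dependency group brings in boto3, botocore, and boto3-stubs[sagemaker-runtime]. As a rough sketch of the kind of call those dependencies support (not part of this commit; the endpoint name and region below are hypothetical):

import json

import boto3  # pulled in by the new "sagemaker" extra

# Hypothetical endpoint and region, for illustration only.
client = boto3.client("sagemaker-runtime", region_name="us-west-2")
response = client.invoke_endpoint(
    EndpointName="my-llm-endpoint",
    ContentType="application/json",
    Body=json.dumps({"inputs": "Hello"}),
)
print(response["Body"].read().decode("utf-8"))

The boto3-stubs[sagemaker-runtime] pin only supplies type stubs for that client, so it matters for mypy in the hatch-static-analysis environment rather than at runtime.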

src/strands/models/anthropic.py

Lines changed: 1 addition & 1 deletion
@@ -414,7 +414,7 @@ async def structured_output(
         stop_reason, messages, _, _ = event["stop"]
 
         if stop_reason != "tool_use":
-            raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
+            raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
 
         content = messages["content"]
         output_response: dict[str, Any] | None = None
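For context, the reworded exception now reports the stop reason that actually came back (for example end_turn or max_tokens) instead of a generic message. A standalone sketch of the check, using a hypothetical value:

stop_reason = "max_tokens"  # hypothetical; in the real code this comes from event["stop"]

if stop_reason != "tool_use":
    # Raises: ValueError: Model returned stop_reason: max_tokens instead of "tool_use".
    raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")

The same wording is used in the Bedrock provider below, keeping the two error paths consistent.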

src/strands/models/bedrock.py

Lines changed: 51 additions & 4 deletions
@@ -17,10 +17,10 @@
 
 from ..event_loop import streaming
 from ..tools import convert_pydantic_to_tool_spec
-from ..types.content import Messages
+from ..types.content import ContentBlock, Message, Messages
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
-from ..types.tools import ToolSpec
+from ..types.tools import ToolResult, ToolSpec
 from .model import Model
 
 logger = logging.getLogger(__name__)
@@ -181,7 +181,7 @@ def format_request(
         """
         return {
             "modelId": self.config["model_id"],
-            "messages": messages,
+            "messages": self._format_bedrock_messages(messages),
             "system": [
                 *([{"text": system_prompt}] if system_prompt else []),
                 *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
@@ -246,6 +246,53 @@ def format_request(
             ),
         }
 
+    def _format_bedrock_messages(self, messages: Messages) -> Messages:
+        """Format messages for Bedrock API compatibility.
+
+        This function ensures messages conform to Bedrock's expected format by:
+        - Cleaning tool result content blocks by removing additional fields that may be
+          useful for retaining information in hooks but would cause Bedrock validation
+          exceptions when presented with unexpected fields
+        - Ensuring all message content blocks are properly formatted for the Bedrock API
+
+        Args:
+            messages: List of messages to format
+
+        Returns:
+            Messages formatted for Bedrock API compatibility
+
+        Note:
+            Bedrock will throw validation exceptions when presented with additional
+            unexpected fields in tool result blocks.
+            https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
+        """
+        cleaned_messages = []
+
+        for message in messages:
+            cleaned_content: list[ContentBlock] = []
+
+            for content_block in message["content"]:
+                if "toolResult" in content_block:
+                    # Create a new content block with only the cleaned toolResult
+                    tool_result: ToolResult = content_block["toolResult"]
+
+                    # Keep only the required fields for Bedrock
+                    cleaned_tool_result = ToolResult(
+                        content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
+                    )
+
+                    cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
+                    cleaned_content.append(cleaned_block)
+                else:
+                    # Keep other content blocks as-is
+                    cleaned_content.append(content_block)
+
+            # Create new message with cleaned content
+            cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
+            cleaned_messages.append(cleaned_message)
+
+        return cleaned_messages
+
     def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
         """Check if guardrail data contains any blocked policies.
 
@@ -584,7 +631,7 @@ async def structured_output(
         stop_reason, messages, _, _ = event["stop"]
 
         if stop_reason != "tool_use":
-            raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
+            raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".")
 
         content = messages["content"]
         output_response: dict[str, Any] | None = None
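To illustrate the cleaning step added above, here is a hedged sketch (not from this commit) of a message whose toolResult block carries an extra bookkeeping field of the kind a hook might attach; the field name mcpMetadata is made up for the example, and model stands for a BedrockModel instance:

messages = [
    {
        "role": "user",
        "content": [
            {
                "toolResult": {
                    "toolUseId": "tooluse_abc123",
                    "status": "success",
                    "content": [{"text": "42"}],
                    "mcpMetadata": {"server": "calculator"},  # extra field Bedrock would reject
                }
            }
        ],
    }
]

# After model._format_bedrock_messages(messages), each toolResult is rebuilt with only
# toolUseId, status, and content, matching Bedrock's ToolResultBlock schema, so the
# Converse request no longer fails validation.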
