Commit 49b3317

Add prompt caching (#891)
* Reflect the cost of cache reads and writes in the price calculation.
* Implement prompt caching.
  - Backend
    - Add a field `use_prompt_caching` to the models and schemas of custom bots:
      - `BotModel`
      - `BotInput`
      - `BotOutput`
      - `BotModifyInput`
      - `BotModifyOutput`
    - Add a column `UsePromptCaching` to the DynamoDB table `BotTableV3`.
    - Use prompt caching if it is enabled and the model supports it.
  - Frontend
    - Add a field `usePromptCaching` to the schemas of custom bots:
      - `BotDetails`
      - `RegisterBotRequest`
      - `RegisterBotResponse`
      - `UpdateBotRequest`
      - `UpdateBotResponse`
    - Add a 'Prompt Caching' section to `BotKbEditPage`.
* [Debug] Print the token count and price when `STREAMING_END` is received.
* Reformat the modified Python files:
  - `backend/app/config.py`
  - `backend/app/bedrock.py`
  - `backend/app/repositories/custom_bot.py`
  - `backend/app/usecases/chat.py`
* Use `ExpandableDrawerGroup` for the prompt caching settings.
* Revert "[Debug] Print the token count and price when `STREAMING_END` is received."
  This reverts commit ba5e584.
* Refactor the data structure of the agent settings.
* Rename `usePromptCaching` to `promptCachingEnabled`:
  - `use_prompt_caching` -> `prompt_caching_enabled`
  - Change `BotModel.prompt_caching_enabled` to a non-nullable type.
* Change `BotModifyInput.prompt_caching_enabled` to a non-nullable type.
* Add `token_count` and `price` to the payload of the `STREAMING_END` notification.
1 parent 54725bf commit 49b3317
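
For orientation: with the Bedrock Converse API, prompt caching is driven by `cachePoint` content blocks that mark where a reusable prompt prefix ends; later requests sharing that prefix are billed at the cheaper cache-read rate. The sketch below is a minimal, self-contained illustration using boto3. The model ID, region, and prompt text are placeholders, not values from this commit.

# Minimal sketch of prompt caching with the Bedrock Converse API (boto3).
# Model ID, region, and prompts are illustrative, not from this commit.
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.converse(
    modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
    # A cachePoint block marks everything before it as cacheable; repeat
    # requests with the same prefix read it back at the cache_read_input
    # rate instead of the full input rate.
    system=[
        {"text": "You are a support bot. <long, stable instructions here>"},
        {"cachePoint": {"type": "default"}},
    ],
    messages=[
        {"role": "user", "content": [{"text": "Summarize the refund policy."}]},
    ],
)

# The usage block reports cache activity; this commit feeds these counts
# into calculate_price() and the STREAMING_END payload.
usage = response["usage"]
print(usage.get("cacheReadInputTokens"), usage.get("cacheWriteInputTokens"))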

19 files changed (+363, −122 lines)

backend/app/bedrock.py

Lines changed: 99 additions & 9 deletions

@@ -2,7 +2,7 @@

 import logging
 import os
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypeGuard
+from typing import TYPE_CHECKING, Any, Dict, Optional, Literal, Tuple, TypeGuard

 from app.config import (
     BEDROCK_PRICING,
@@ -30,6 +30,7 @@
     InferenceConfigurationTypeDef,
     MessageTypeDef,
     SystemContentBlockTypeDef,
+    ToolTypeDef,
 )


@@ -81,6 +82,31 @@ def is_tooluse_supported(model: type_model_name) -> bool:
     ]


+def is_prompt_caching_supported(
+    model: type_model_name, target: Literal["system", "message", "tool"]
+) -> bool:
+    if target == "tool":
+        return model in [
+            "claude-v4-opus",
+            "claude-v4-sonnet",
+            "claude-v3.7-sonnet",
+            "claude-v3.5-sonnet-v2",
+            "claude-v3.5-haiku",
+        ]
+
+    else:
+        return model in [
+            "claude-v4-opus",
+            "claude-v4-sonnet",
+            "claude-v3.7-sonnet",
+            "claude-v3.5-sonnet-v2",
+            "claude-v3.5-haiku",
+            "amazon-nova-pro",
+            "amazon-nova-lite",
+            "amazon-nova-micro",
+        ]
+
+
 def _prepare_deepseek_model_params(
     model: type_model_name, generation_params: Optional[GenerationParamsModel] = None
 ) -> Tuple[InferenceConfigurationTypeDef, None]:
@@ -263,6 +289,7 @@ def compose_args_for_converse_api(
     tools: dict[str, AgentTool] | None = None,
     stream: bool = True,
     enable_reasoning: bool = False,
+    prompt_caching_enabled: bool = False,
 ) -> ConverseStreamRequestTypeDef:
     def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
         # Drop unsigned reasoning blocks only for DeepSeek R1
@@ -303,6 +330,16 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
         for message in messages
         if _is_conversation_role(message.role)
     ]
+    tool_specs: list[ToolTypeDef] | None = (
+        [
+            {
+                "toolSpec": tool.to_converse_spec(),
+            }
+            for tool in tools.values()
+        ]
+        if tools
+        else None
+    )

     # Prepare model-specific parameters
     inference_config: InferenceConfigurationTypeDef
@@ -457,6 +494,41 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
         if len(instruction) > 0
     ]

+    if prompt_caching_enabled and not (
+        tool_specs and not is_prompt_caching_supported(model, target="tool")
+    ):
+        if is_prompt_caching_supported(model, "system") and len(system_prompts) > 0:
+            system_prompts.append(
+                {
+                    "cachePoint": {
+                        "type": "default",
+                    },
+                }
+            )
+
+        if is_prompt_caching_supported(model, target="message"):
+            for order, message in enumerate(
+                filter(lambda m: m["role"] == "user", reversed(arg_messages))
+            ):
+                if order >= 2:
+                    break
+
+                message["content"] = [
+                    *(message["content"]),
+                    {
+                        "cachePoint": {"type": "default"},
+                    },
+                ]
+
+        if is_prompt_caching_supported(model, target="tool") and tool_specs:
+            tool_specs.append(
+                {
+                    "cachePoint": {
+                        "type": "default",
+                    },
+                }
+            )
+
     # Construct the base arguments
     args: ConverseStreamRequestTypeDef = {
         "inferenceConfig": inference_config,
@@ -480,14 +552,9 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
         args["guardrailConfig"]["streamProcessingMode"] = "async"

     # NOTE: Some models doesn't support tool use. https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
-    if tools:
+    if tool_specs:
         args["toolConfig"] = {
-            "tools": [
-                {
-                    "toolSpec": tool.to_converse_spec(),
-                }
-                for tool in tools.values()
-            ],
+            "tools": tool_specs,
         }

     return args
@@ -519,6 +586,8 @@ def calculate_price(
     model: type_model_name,
     input_tokens: int,
     output_tokens: int,
+    cache_read_input_tokens: int,
+    cache_write_input_tokens: int,
     region: str = BEDROCK_REGION,
 ) -> float:
     input_price = (
@@ -531,8 +600,29 @@
         .get(model, {})
         .get("output", BEDROCK_PRICING["default"][model]["output"])
     )
+    cache_read_input_price = (
+        BEDROCK_PRICING.get(region, {})
+        .get(model, {})
+        .get(
+            "cache_read_input",
+            BEDROCK_PRICING["default"][model].get("cache_read_input", input_price),
+        )
+    )
+    cache_write_input_price = (
+        BEDROCK_PRICING.get(region, {})
+        .get(model, {})
+        .get(
+            "cache_write_input",
+            BEDROCK_PRICING["default"][model].get("cache_write_input", input_price),
+        )
+    )

-    return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0
+    return (
+        input_price * input_tokens / 1000.0
+        + output_price * output_tokens / 1000.0
+        + cache_read_input_price * cache_read_input_tokens / 1000.0
+        + cache_write_input_price * cache_write_input_tokens / 1000.0
+    )


 def get_model_id(
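
Condensed, the cache-point placement in the hunks above does three things: append one cache point after the system prompts, one to each of the two most recent user messages, and one after the tool specs, each gated on model support. The helper below is a hypothetical distillation of that logic (not a function in this commit), assuming Converse-API-shaped dicts:

# Hypothetical distillation of the cache-point placement added above:
# one cache point after the system prompts, one on each of the last two
# user turns, and one after the tool specs.
from typing import Any

def add_cache_points(
    system_prompts: list[dict[str, Any]],
    messages: list[dict[str, Any]],
    tool_specs: list[dict[str, Any]] | None,
) -> None:
    cache_point = {"cachePoint": {"type": "default"}}

    if system_prompts:
        system_prompts.append(cache_point)

    # Mirrors the `if order >= 2: break` loop: only the two most recent
    # user messages receive a cache point.
    for order, message in enumerate(
        m for m in reversed(messages) if m["role"] == "user"
    ):
        if order >= 2:
            break
        message["content"] = [*message["content"], cache_point]

    if tool_specs:
        tool_specs.append(cache_point)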

backend/app/config.py

Lines changed: 144 additions & 24 deletions

@@ -63,19 +63,64 @@ class EmbeddingConfig(TypedDict):
 # See: https://aws.amazon.com/bedrock/pricing/
 BEDROCK_PRICING = {
     "us-east-1": {
-        "claude-v4-opus": {"input": 0.015, "output": 0.075},
-        "claude-v4-sonnet": {"input": 0.003, "output": 0.015},
+        "claude-v4-opus": {
+            "input": 0.015,
+            "output": 0.075,
+            "cache_write_input": 0.01875,
+            "cache_read_input": 0.0015,
+        },
+        "claude-v4-sonnet": {
+            "input": 0.003,
+            "output": 0.015,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
         "claude-v3-haiku": {"input": 0.00025, "output": 0.00125},
-        "claude-v3.5-haiku": {"input": 0.001, "output": 0.005},
-        "claude-v3.5-sonnet": {"input": 0.00300, "output": 0.01500},
-        "claude-v3.5-sonnet-v2": {"input": 0.00300, "output": 0.01500},
-        "claude-v3.7-sonnet": {"input": 0.00300, "output": 0.01500},
+        "claude-v3.5-haiku": {
+            "input": 0.001,
+            "output": 0.005,
+            "cache_write_input": 0.001,
+            "cache_read_input": 0.00008,
+        },
+        "claude-v3.5-sonnet": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
+        "claude-v3.5-sonnet-v2": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
+        "claude-v3.7-sonnet": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
         "mistral-7b-instruct": {"input": 0.00015, "output": 0.0002},
         "mixtral-8x7b-instruct": {"input": 0.00045, "output": 0.0007},
         "mistral-large": {"input": 0.004, "output": 0.012},
-        "amazon-nova-pro": {"input": 0.0008, "output": 0.0032},
-        "amazon-nova-lite": {"input": 0.00006, "output": 0.00024},
-        "amazon-nova-micro": {"input": 0.000035, "output": 0.00014},
+        "amazon-nova-pro": {
+            "input": 0.0008,
+            "output": 0.0032,
+            "cache_write_input": 0.0008,
+            "cache_read_input": 0.0002,
+        },
+        "amazon-nova-lite": {
+            "input": 0.00006,
+            "output": 0.00024,
+            "cache_write_input": 0.00006,
+            "cache_read_input": 0.000015,
+        },
+        "amazon-nova-micro": {
+            "input": 0.000035,
+            "output": 0.00014,
+            "cache_write_input": 0.000035,
+            "cache_read_input": 0.00000875,
+        },
         "deepseek-r1": {"input": 0.00135, "output": 0.0054},
         # Meta Llama 3 models (US region)
         "llama3-3-70b-instruct": {"input": 0.00072, "output": 0.00072},
@@ -85,17 +130,47 @@ class EmbeddingConfig(TypedDict):
         "llama3-2-90b-instruct": {"input": 0.00072, "output": 0.00072},
     },
     "us-west-2": {
-        "claude-v4-opus": {"input": 0.015, "output": 0.075},
-        "claude-v4-sonnet": {"input": 0.003, "output": 0.015},
-        "claude-v3.7-sonnet": {"input": 0.00300, "output": 0.01500},
+        "claude-v4-opus": {
+            "input": 0.015,
+            "output": 0.075,
+            "cache_write_input": 0.01875,
+            "cache_read_input": 0.0015,
+        },
+        "claude-v4-sonnet": {
+            "input": 0.003,
+            "output": 0.015,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
+        "claude-v3.7-sonnet": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
         "claude-v3-opus": {"input": 0.01500, "output": 0.07500},
         "mistral-7b-instruct": {"input": 0.00015, "output": 0.0002},
         "mixtral-8x7b-instruct": {"input": 0.00045, "output": 0.0007},
         "mistral-large": {"input": 0.004, "output": 0.012},
         "mistral-large-2": {"input": 0.002, "output": 0.06},
-        "amazon-nova-pro": {"input": 0.0008, "output": 0.0032},
-        "amazon-nova-lite": {"input": 0.00006, "output": 0.00024},
-        "amazon-nova-micro": {"input": 0.000035, "output": 0.00014},
+        "amazon-nova-pro": {
+            "input": 0.0008,
+            "output": 0.0032,
+            "cache_write_input": 0.0008,
+            "cache_read_input": 0.0002,
+        },
+        "amazon-nova-lite": {
+            "input": 0.00006,
+            "output": 0.00024,
+            "cache_write_input": 0.00006,
+            "cache_read_input": 0.000015,
+        },
+        "amazon-nova-micro": {
+            "input": 0.000035,
+            "output": 0.00014,
+            "cache_write_input": 0.000035,
+            "cache_read_input": 0.00000875,
+        },
         "deepseek-r1": {"input": 0.00135, "output": 0.0054},
         # Meta Llama 3 models (US region)
         "llama3-3-70b-instruct": {"input": 0.00072, "output": 0.00072},
@@ -106,21 +181,66 @@ class EmbeddingConfig(TypedDict):
     },
     "ap-northeast-1": {},
     "default": {
-        "claude-v4-opus": {"input": 0.015, "output": 0.075},
-        "claude-v4-sonnet": {"input": 0.003, "output": 0.015},
+        "claude-v4-opus": {
+            "input": 0.015,
+            "output": 0.075,
+            "cache_write_input": 0.01875,
+            "cache_read_input": 0.0015,
+        },
+        "claude-v4-sonnet": {
+            "input": 0.003,
+            "output": 0.015,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
         "claude-v3-haiku": {"input": 0.00025, "output": 0.00125},
-        "claude-v3.5-haiku": {"input": 0.001, "output": 0.005},
-        "claude-v3.5-sonnet": {"input": 0.00300, "output": 0.01500},
-        "claude-v3.5-sonnet-v2": {"input": 0.00300, "output": 0.01500},
-        "claude-v3.7-sonnet": {"input": 0.00300, "output": 0.01500},
+        "claude-v3.5-haiku": {
+            "input": 0.001,
+            "output": 0.005,
+            "cache_write_input": 0.001,
+            "cache_read_input": 0.00008,
+        },
+        "claude-v3.5-sonnet": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
+        "claude-v3.5-sonnet-v2": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
+        "claude-v3.7-sonnet": {
+            "input": 0.00300,
+            "output": 0.01500,
+            "cache_write_input": 0.00375,
+            "cache_read_input": 0.0003,
+        },
         "claude-v3-opus": {"input": 0.01500, "output": 0.07500},
         "mistral-7b-instruct": {"input": 0.00015, "output": 0.0002},
         "mixtral-8x7b-instruct": {"input": 0.00045, "output": 0.0007},
         "mistral-large": {"input": 0.004, "output": 0.012},
         "mistral-large-2": {"input": 0.002, "output": 0.06},
-        "amazon-nova-pro": {"input": 0.0008, "output": 0.0032},
-        "amazon-nova-lite": {"input": 0.00006, "output": 0.00024},
-        "amazon-nova-micro": {"input": 0.000035, "output": 0.00014},
+        "amazon-nova-pro": {
+            "input": 0.0008,
+            "output": 0.0032,
+            "cache_write_input": 0.0008,
+            "cache_read_input": 0.0002,
+        },
+        "amazon-nova-lite": {
+            "input": 0.00006,
+            "output": 0.00024,
+            "cache_write_input": 0.00006,
+            "cache_read_input": 0.000015,
+        },
+        "amazon-nova-micro": {
+            "input": 0.000035,
+            "output": 0.00014,
+            "cache_write_input": 0.000035,
+            "cache_read_input": 0.00000875,
+        },
         "deepseek-r1": {"input": 0.00135, "output": 0.0054},
         # Meta Llama 3 models (US region)
         "llama3-3-70b-instruct": {"input": 0.00072, "output": 0.00072},
