Skip to content

Commit 22e3ba1

Browse files
authored
Merge pull request #82 from YangZhiBoGreenHand/yzb/fix/sub-stream
fix: sub stream with submit tool outputs
2 parents 2aa5c91 + 2f101c3 commit 22e3ba1

File tree

3 files changed

+112
-3
lines changed

3 files changed

+112
-3
lines changed

app/api/v1/runs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ async def list_run_steps(
121121
return page.model_dump(by_alias=True)
122122

123123

124-
125124
@router.get(
126125
"/{thread_id}/runs/{run_id}/steps/{step_id}",
127126
response_model=RunStepRead,
@@ -160,7 +159,7 @@ async def submit_tool_outputs_to_run(
160159
db_run = await RunService.submit_tool_outputs_to_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
161160
# Resume async task
162161
if db_run.status == "queued":
163-
run_task.apply_async(args=(db_run.id,))
162+
run_task.apply_async(args=(db_run.id, body.stream))
164163

165164
if body.stream:
166165
return pub_handler.sub_stream(db_run.id, request)

app/core/runner/pub_handler.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ def read_event(channel: str, x_index: str = None) -> Tuple[Optional[str], Option
4949
return stream_id, event
5050

5151

52+
def save_last_stream_id(run_id: str, stream_id: str):
53+
"""
54+
保存当前 run_id 对应的最新 stream_id
55+
:param run_id: 当前的运行 ID
56+
:param stream_id: 最新的 stream_id
57+
"""
58+
redis_client.set(f"run:{run_id}:last_stream_id", stream_id, 10 * 60)
59+
60+
61+
def get_last_stream_id(run_id: str) -> str:
62+
"""
63+
获取上次保存的 stream_id
64+
:param run_id: 当前的运行 ID
65+
:return: 上次的 stream_id 或 None
66+
"""
67+
return redis_client.get(f"run:{run_id}:last_stream_id")
68+
69+
5270
def _data_adjust_tools(tools: List[dict]) -> List[dict]:
5371
def _adjust_tool(tool: dict):
5472
if tool["type"] not in {"code_interpreter", "file_search", "function"}:
@@ -108,7 +126,8 @@ async def _stream():
108126
for event in prefix_events:
109127
yield event
110128

111-
x_index = None
129+
last_index = get_last_stream_id(run_id) # 获取上次的 stream_id
130+
x_index = last_index or None
112131
while True:
113132
if await request.is_disconnected():
114133
break
@@ -120,12 +139,15 @@ async def _stream():
120139
break
121140

122141
if data["event"] == "done":
142+
save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
123143
break
124144

125145
if data["event"] == "error":
146+
save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
126147
raise InternalServerError(data["data"])
127148

128149
yield data
150+
save_last_stream_id(run_id, x_index) # 记录最新的 stream_id
129151

130152
for event in suffix_events:
131153
yield event

tests/e2e/sub_stream_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import json
2+
3+
from openai import AssistantEventHandler
4+
5+
6+
def test_sub_stream_with_submit_tool_outputs_stream(client):
7+
8+
def get_current_weather(location):
9+
return f"{location}今天是雨天。 "
10+
11+
assistant = client.beta.assistants.create(
12+
name="Assistant Demo",
13+
instructions="You are a helpful assistant. When asked a question, use tools wherever possible.",
14+
model="gpt-4o",
15+
tools=[
16+
{
17+
"type": "function",
18+
"function": {
19+
"name": "get_current_weather",
20+
"description": "当你想查询指定城市的天气时非常有用。",
21+
"parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "城市或县区,比如北京市、杭州市、余杭区等。"}}, "required": ["location"]}, # 查询天气时需要提供位置,因此参数设置为location
22+
},
23+
}
24+
],
25+
)
26+
print("=====> : %s\n", assistant)
27+
28+
thread = client.beta.threads.create()
29+
print("=====> : %s\n", thread)
30+
31+
message = client.beta.threads.messages.create(
32+
thread_id=thread.id,
33+
role="user",
34+
content="北京天气如何?",
35+
)
36+
print("=====> : %s\n", message)
37+
38+
funcs = [get_current_weather]
39+
40+
class EventHandler(AssistantEventHandler):
41+
42+
def on_event(self, event):
43+
print(event.event)
44+
if event.event == "thread.run.requires_action":
45+
print(event)
46+
run_id = event.data.id # Retrieve the run ID from the event data
47+
self.handle_requires_action(event.data, run_id)
48+
49+
def handle_requires_action(self, data, run_id):
50+
tool_outputs = []
51+
52+
for tool in data.required_action.submit_tool_outputs.tool_calls:
53+
func = next(iter([func for func in funcs if func.__name__ == tool.function.name]))
54+
try:
55+
output = func(**eval(tool.function.arguments))
56+
except Exception as e:
57+
output = "Error: " + str(e)
58+
59+
tool_outputs.append({"tool_call_id": tool.id, "output": json.dumps(output)})
60+
61+
print(tool_outputs)
62+
63+
# Submit all tool_outputs at the same time
64+
self.submit_tool_outputs(tool_outputs, run_id)
65+
66+
def submit_tool_outputs(self, tool_outputs, run_id):
67+
# Use the submit_tool_outputs_stream helper
68+
with client.beta.threads.runs.submit_tool_outputs_stream(
69+
thread_id=self.current_run.thread_id,
70+
run_id=self.current_run.id,
71+
tool_outputs=tool_outputs,
72+
event_handler=EventHandler(),
73+
) as stream:
74+
# for text in stream.text_deltas:
75+
# print(text, end="", flush=True)
76+
# print()
77+
stream.until_done()
78+
79+
def on_text_delta(self, delta, snapshot) -> None:
80+
print("=====> text delta")
81+
print("delta : %s", delta)
82+
83+
with client.beta.threads.runs.stream(
84+
thread_id=thread.id,
85+
assistant_id=assistant.id,
86+
event_handler=EventHandler(),
87+
) as stream:
88+
stream.until_done()

0 commit comments

Comments
 (0)