Skip to content

Commit 71c0c74

Browse files
authored
chore(client): refactor streaming slightly to better future proof it
2 parents 29c22c9 + 5054810 commit 71c0c74

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

src/openai/_streaming.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,11 @@ def __stream__(self) -> Iterator[_T]:
5959
if sse.data.startswith("[DONE]"):
6060
break
6161

62-
if sse.event is None or (
63-
sse.event.startswith("response.") or
64-
sse.event.startswith("transcript.") or
65-
sse.event.startswith("image_edit.") or
66-
sse.event.startswith("image_generation.")
67-
):
62+
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
63+
if sse.event and sse.event.startswith("thread."):
6864
data = sse.json()
69-
if is_mapping(data) and data.get("error"):
65+
66+
if sse.event == "error" and is_mapping(data) and data.get("error"):
7067
message = None
7168
error = data.get("error")
7269
if is_mapping(error):
@@ -80,12 +77,10 @@ def __stream__(self) -> Iterator[_T]:
8077
body=data["error"],
8178
)
8279

83-
yield process_data(data=data, cast_to=cast_to, response=response)
84-
80+
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
8581
else:
8682
data = sse.json()
87-
88-
if sse.event == "error" and is_mapping(data) and data.get("error"):
83+
if is_mapping(data) and data.get("error"):
8984
message = None
9085
error = data.get("error")
9186
if is_mapping(error):
@@ -99,7 +94,7 @@ def __stream__(self) -> Iterator[_T]:
9994
body=data["error"],
10095
)
10196

102-
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
97+
yield process_data(data=data, cast_to=cast_to, response=response)
10398

10499
# Ensure the entire stream is consumed
105100
for _sse in iterator:
@@ -166,9 +161,11 @@ async def __stream__(self) -> AsyncIterator[_T]:
166161
if sse.data.startswith("[DONE]"):
167162
break
168163

169-
if sse.event is None or sse.event.startswith("response.") or sse.event.startswith("transcript."):
164+
# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
165+
if sse.event and sse.event.startswith("thread."):
170166
data = sse.json()
171-
if is_mapping(data) and data.get("error"):
167+
168+
if sse.event == "error" and is_mapping(data) and data.get("error"):
172169
message = None
173170
error = data.get("error")
174171
if is_mapping(error):
@@ -182,12 +179,10 @@ async def __stream__(self) -> AsyncIterator[_T]:
182179
body=data["error"],
183180
)
184181

185-
yield process_data(data=data, cast_to=cast_to, response=response)
186-
182+
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
187183
else:
188184
data = sse.json()
189-
190-
if sse.event == "error" and is_mapping(data) and data.get("error"):
185+
if is_mapping(data) and data.get("error"):
191186
message = None
192187
error = data.get("error")
193188
if is_mapping(error):
@@ -201,7 +196,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
201196
body=data["error"],
202197
)
203198

204-
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
199+
yield process_data(data=data, cast_to=cast_to, response=response)
205200

206201
# Ensure the entire stream is consumed
207202
async for _sse in iterator:

0 commit comments

Comments
 (0)