Commit 5054810

flip logic around
1 parent 39caee7 commit 5054810

2 files changed: +16 -16 lines changed

examples/image_stream.py

Lines changed: 1 addition & 1 deletion
@@ -50,4 +50,4 @@ def main() -> None:
 try:
     main()
 except Exception as error:
-    print(f"Error generating image: {error}")
+    print(f"Error generating image: {error}")

src/openai/_streaming.py

Lines changed: 15 additions & 15 deletions
@@ -59,9 +59,11 @@ def __stream__(self) -> Iterator[_T]:
             if sse.data.startswith("[DONE]"):
                 break

-            if sse.event is not None and not sse.event.startswith("thread."):
+            # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
+            if sse.event and sse.event.startswith("thread."):
                 data = sse.json()
-                if is_mapping(data) and data.get("error"):
+
+                if sse.event == "error" and is_mapping(data) and data.get("error"):
                     message = None
                     error = data.get("error")
                     if is_mapping(error):
@@ -75,12 +77,10 @@ def __stream__(self) -> Iterator[_T]:
                         body=data["error"],
                     )

-                yield process_data(data=data, cast_to=cast_to, response=response)
-
+                yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
             else:
                 data = sse.json()
-
-                if sse.event == "error" and is_mapping(data) and data.get("error"):
+                if is_mapping(data) and data.get("error"):
                     message = None
                     error = data.get("error")
                     if is_mapping(error):
@@ -93,8 +93,8 @@ def __stream__(self) -> Iterator[_T]:
                         request=self.response.request,
                         body=data["error"],
                     )
-                # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
-                yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
+
+                yield process_data(data=data, cast_to=cast_to, response=response)

         # Ensure the entire stream is consumed
         for _sse in iterator:
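Taken together, the three hunks above flip the branch order in the synchronous __stream__: Assistants `thread.*` events are now the special-cased branch, and every other event falls through to the plain path. The following is a minimal, self-contained sketch of that routing; FakeSSE and route() are illustrative stand-ins rather than the SDK's real ServerSentEvent or __stream__ internals, and RuntimeError stands in for APIError.

from collections.abc import Mapping
from typing import Any, Optional


class FakeSSE:
    # stand-in for a server-sent event: just an event name plus a JSON payload
    def __init__(self, event: Optional[str], data: Any) -> None:
        self.event = event
        self._data = data

    def json(self) -> Any:
        return self._data


def route(sse: FakeSSE) -> Any:
    # mirrors the post-commit branch order shown in the hunks above
    # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
    if sse.event and sse.event.startswith("thread."):
        data = sse.json()

        if sse.event == "error" and isinstance(data, Mapping) and data.get("error"):
            raise RuntimeError(str(data["error"]))  # stands in for raising APIError

        # thread.* payloads are wrapped so the event name travels with the data
        return {"data": data, "event": sse.event}
    else:
        data = sse.json()
        if isinstance(data, Mapping) and data.get("error"):
            raise RuntimeError(str(data["error"]))  # stands in for raising APIError

        # everything else is yielded as plain data
        return data


print(route(FakeSSE("thread.message.delta", {"id": "msg_1"})))
# {'data': {'id': 'msg_1'}, 'event': 'thread.message.delta'}
print(route(FakeSSE(None, {"choices": []})))
# {'choices': []}

The async hunks below apply the same flip to the asynchronous __stream__.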
@@ -161,9 +161,11 @@ async def __stream__(self) -> AsyncIterator[_T]:
             if sse.data.startswith("[DONE]"):
                 break

-            if sse.event is None or sse.event.startswith("response.") or sse.event.startswith("transcript."):
+            # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
+            if sse.event and sse.event.startswith("thread."):
                 data = sse.json()
-                if is_mapping(data) and data.get("error"):
+
+                if sse.event == "error" and is_mapping(data) and data.get("error"):
                     message = None
                     error = data.get("error")
                     if is_mapping(error):
@@ -177,12 +179,10 @@ async def __stream__(self) -> AsyncIterator[_T]:
                         body=data["error"],
                     )

-                yield process_data(data=data, cast_to=cast_to, response=response)
-
+                yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
             else:
                 data = sse.json()
-
-                if sse.event == "error" and is_mapping(data) and data.get("error"):
+                if is_mapping(data) and data.get("error"):
                     message = None
                     error = data.get("error")
                     if is_mapping(error):
@@ -196,7 +196,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
                         body=data["error"],
                     )

-                yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
+                yield process_data(data=data, cast_to=cast_to, response=response)

         # Ensure the entire stream is consumed
         async for _sse in iterator:
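After this commit, consumers of the async stream see the same shapes as the sync one: wrapped {"data": ..., "event": ...} payloads for thread.* events and plain data otherwise. As a hedged illustration of how a caller might tell the two apart, here is a small sketch; events() is a hypothetical async generator, not the SDK's public API.

import asyncio
from typing import Any, AsyncIterator, Dict


async def events() -> AsyncIterator[Dict[str, Any]]:
    # hypothetical feed: one wrapped thread.* payload and one plain payload
    yield {"data": {"id": "msg_1"}, "event": "thread.message.delta"}
    yield {"choices": []}


async def main() -> None:
    async for item in events():
        if "event" in item and "data" in item:
            # wrapped Assistants-style payload: the event name travels with the data
            print("thread event:", item["event"], item["data"])
        else:
            print("plain payload:", item)


asyncio.run(main())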
