@@ -59,9 +59,11 @@ def __stream__(self) -> Iterator[_T]:
59
59
if sse .data .startswith ("[DONE]" ):
60
60
break
61
61
62
- if sse .event is not None and not sse .event .startswith ("thread." ):
62
+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
63
+ if sse .event and sse .event .startswith ("thread." ):
63
64
data = sse .json ()
64
- if is_mapping (data ) and data .get ("error" ):
65
+
66
+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
65
67
message = None
66
68
error = data .get ("error" )
67
69
if is_mapping (error ):
@@ -75,12 +77,10 @@ def __stream__(self) -> Iterator[_T]:
75
77
body = data ["error" ],
76
78
)
77
79
78
- yield process_data (data = data , cast_to = cast_to , response = response )
79
-
80
+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
80
81
else :
81
82
data = sse .json ()
82
-
83
- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
83
+ if is_mapping (data ) and data .get ("error" ):
84
84
message = None
85
85
error = data .get ("error" )
86
86
if is_mapping (error ):
@@ -93,8 +93,8 @@ def __stream__(self) -> Iterator[_T]:
93
93
request = self .response .request ,
94
94
body = data ["error" ],
95
95
)
96
- # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
97
- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
96
+
97
+ yield process_data (data = data , cast_to = cast_to , response = response )
98
98
99
99
# Ensure the entire stream is consumed
100
100
for _sse in iterator :
@@ -161,9 +161,11 @@ async def __stream__(self) -> AsyncIterator[_T]:
161
161
if sse .data .startswith ("[DONE]" ):
162
162
break
163
163
164
- if sse .event is None or sse .event .startswith ("response." ) or sse .event .startswith ("transcript." ):
164
+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
165
+ if sse .event and sse .event .startswith ("thread." ):
165
166
data = sse .json ()
166
- if is_mapping (data ) and data .get ("error" ):
167
+
168
+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
167
169
message = None
168
170
error = data .get ("error" )
169
171
if is_mapping (error ):
@@ -177,12 +179,10 @@ async def __stream__(self) -> AsyncIterator[_T]:
177
179
body = data ["error" ],
178
180
)
179
181
180
- yield process_data (data = data , cast_to = cast_to , response = response )
181
-
182
+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
182
183
else :
183
184
data = sse .json ()
184
-
185
- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
185
+ if is_mapping (data ) and data .get ("error" ):
186
186
message = None
187
187
error = data .get ("error" )
188
188
if is_mapping (error ):
@@ -196,7 +196,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
196
196
body = data ["error" ],
197
197
)
198
198
199
- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
199
+ yield process_data (data = data , cast_to = cast_to , response = response )
200
200
201
201
# Ensure the entire stream is consumed
202
202
async for _sse in iterator :
0 commit comments