@@ -59,14 +59,11 @@ def __stream__(self) -> Iterator[_T]:
59
59
if sse .data .startswith ("[DONE]" ):
60
60
break
61
61
62
- if sse .event is None or (
63
- sse .event .startswith ("response." ) or
64
- sse .event .startswith ("transcript." ) or
65
- sse .event .startswith ("image_edit." ) or
66
- sse .event .startswith ("image_generation." )
67
- ):
62
+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
63
+ if sse .event and sse .event .startswith ("thread." ):
68
64
data = sse .json ()
69
- if is_mapping (data ) and data .get ("error" ):
65
+
66
+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
70
67
message = None
71
68
error = data .get ("error" )
72
69
if is_mapping (error ):
@@ -80,12 +77,10 @@ def __stream__(self) -> Iterator[_T]:
80
77
body = data ["error" ],
81
78
)
82
79
83
- yield process_data (data = data , cast_to = cast_to , response = response )
84
-
80
+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
85
81
else :
86
82
data = sse .json ()
87
-
88
- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
83
+ if is_mapping (data ) and data .get ("error" ):
89
84
message = None
90
85
error = data .get ("error" )
91
86
if is_mapping (error ):
@@ -99,7 +94,7 @@ def __stream__(self) -> Iterator[_T]:
99
94
body = data ["error" ],
100
95
)
101
96
102
- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
97
+ yield process_data (data = data , cast_to = cast_to , response = response )
103
98
104
99
# Ensure the entire stream is consumed
105
100
for _sse in iterator :
@@ -166,9 +161,11 @@ async def __stream__(self) -> AsyncIterator[_T]:
166
161
if sse .data .startswith ("[DONE]" ):
167
162
break
168
163
169
- if sse .event is None or sse .event .startswith ("response." ) or sse .event .startswith ("transcript." ):
164
+ # we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
165
+ if sse .event and sse .event .startswith ("thread." ):
170
166
data = sse .json ()
171
- if is_mapping (data ) and data .get ("error" ):
167
+
168
+ if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
172
169
message = None
173
170
error = data .get ("error" )
174
171
if is_mapping (error ):
@@ -182,12 +179,10 @@ async def __stream__(self) -> AsyncIterator[_T]:
182
179
body = data ["error" ],
183
180
)
184
181
185
- yield process_data (data = data , cast_to = cast_to , response = response )
186
-
182
+ yield process_data (data = {"data" : data , "event" : sse .event }, cast_to = cast_to , response = response )
187
183
else :
188
184
data = sse .json ()
189
-
190
- if sse .event == "error" and is_mapping (data ) and data .get ("error" ):
185
+ if is_mapping (data ) and data .get ("error" ):
191
186
message = None
192
187
error = data .get ("error" )
193
188
if is_mapping (error ):
@@ -201,7 +196,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
201
196
body = data ["error" ],
202
197
)
203
198
204
- yield process_data (data = { " data" : data , "event" : sse . event } , cast_to = cast_to , response = response )
199
+ yield process_data (data = data , cast_to = cast_to , response = response )
205
200
206
201
# Ensure the entire stream is consumed
207
202
async for _sse in iterator :
0 commit comments