Skip to content

Commit 7c9d79b

Browse files
committed
Realtime: only cancel response if necessary
1 parent b459cc4 commit 7c9d79b

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/agents/realtime/openai_realtime.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(self) -> None:
140140
self._ongoing_response: bool = False
141141
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
142142
self._playback_tracker: RealtimePlaybackTracker | None = None
143+
self._created_session: OpenAISessionObject | None = None
143144

144145
async def connect(self, options: RealtimeModelConfig) -> None:
145146
"""Establish a connection to the model and keep it alive."""
@@ -349,7 +350,14 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
349350
int(elapsed_ms),
350351
)
351352
await self._send_raw_message(converted)
352-
await self._cancel_response()
353+
354+
automatic_response_cancellation_enabled = (
355+
self._created_session
356+
and self._created_session.turn_detection
357+
and self._created_session.turn_detection.interrupt_response
358+
)
359+
if not automatic_response_cancellation_enabled:
360+
await self._cancel_response()
353361

354362
self._audio_state_tracker.on_interrupted()
355363
if self._playback_tracker:
@@ -483,6 +491,9 @@ async def _handle_ws_event(self, event: dict[str, Any]):
483491
await self._emit_event(RealtimeModelTurnEndedEvent())
484492
elif parsed.type == "session.created":
485493
await self._send_tracing_config(self._tracing_config)
494+
self._update_created_session(parsed.session) # type: ignore
495+
elif parsed.type == "session.updated":
496+
self._update_created_session(parsed.session) # type: ignore
486497
elif parsed.type == "error":
487498
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
488499
elif parsed.type == "conversation.item.deleted":
@@ -532,6 +543,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
532543
):
533544
await self._handle_output_item(parsed.item)
534545

546+
def _update_created_session(self, session: OpenAISessionObject) -> None:
547+
self._created_session = session
548+
if session.output_audio_format:
549+
self._audio_state_tracker.set_audio_format(session.output_audio_format)
550+
if self._playback_tracker:
551+
self._playback_tracker.set_audio_format(session.output_audio_format)
552+
535553
async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
536554
session_config = self._get_session_config(model_settings)
537555
await self._send_raw_message(

0 commit comments

Comments
 (0)