solana-agent 31.1.7__py3-none-any.whl → 31.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1613 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import logging
6
+ from typing import Any, AsyncGenerator, Dict, Optional
7
+ import asyncio
8
+
9
+ import websockets
10
+
11
+ from solana_agent.interfaces.providers.realtime import (
12
+ BaseRealtimeSession,
13
+ RealtimeSessionOptions,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class OpenAIRealtimeWebSocketSession(BaseRealtimeSession):
20
+ """OpenAI Realtime WebSocket session (server-to-server) with audio in/out.
21
+
22
+ Notes:
23
+ - Expects 24kHz audio.
24
+ - Emits raw PCM16 bytes for output audio deltas.
25
+ - Provides separate async generators for input/output transcripts.
26
+ - You can toggle VAD via session.update and manually commit when disabled.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ api_key: str,
32
+ url: str = "wss://api.openai.com/v1/realtime",
33
+ options: Optional[RealtimeSessionOptions] = None,
34
+ ) -> None:
35
+ self.api_key = api_key
36
+ self.url = url
37
+ self.options = options or RealtimeSessionOptions()
38
+
39
+ # Queues and state
40
+ self._ws = None
41
+ self._event_queue = asyncio.Queue()
42
+ self._audio_queue = asyncio.Queue()
43
+ self._in_tr_queue = asyncio.Queue()
44
+ self._out_tr_queue = asyncio.Queue()
45
+
46
+ # Tool/function state
47
+ self._recv_task = None
48
+ self._tool_executor = None
49
+ self._pending_calls = {}
50
+ self._active_tool_calls = 0
51
+ # Map call_id or item_id -> asyncio.Event when server has created the function_call item
52
+ self._call_ready_events = {}
53
+ # Track identifiers already executed to avoid duplicates
54
+ self._executed_call_ids = set()
55
+ # Track mapping from tool call identifiers to the originating response.id
56
+ # Prefer addressing function_call_output to a response to avoid conversation lookup races
57
+ self._call_response_ids = {}
58
+ # Accumulate function call arguments when streamed
59
+ self._call_args_accum = {}
60
+ self._call_names = {}
61
+ # Map from function_call item_id -> call_id for reliable output addressing
62
+ self._item_call_ids = {}
63
+
64
+ # Input/response state
65
+ self._pending_input_bytes = 0
66
+ self._commit_evt = asyncio.Event()
67
+ self._commit_inflight = False
68
+ self._response_active = False
69
+ self._server_auto_create_enabled = False
70
+ self._last_commit_ts = 0.0
71
+
72
+ # Session/event tracking
73
+ self._session_created_evt = asyncio.Event()
74
+ self._session_updated_evt = asyncio.Event()
75
+ self._awaiting_session_updated = False
76
+ self._last_session_patch = {}
77
+ self._last_session_updated_payload = {}
78
+ self._last_input_item_id = None
79
+ # Response generation tracking to bind tool outputs to the active response
80
+ self._response_generation = 0
81
+ # Track the currently active response.id (fallback when some events omit it)
82
+ self._active_response_id = None
83
+ # Accumulate assistant output text by response.id and flush on response completion
84
+ self._out_text_buffers = {}
85
+
86
+ # Outbound event correlation
87
+ self._event_seq = 0
88
+ self._sent_events = {}
89
+
90
async def connect(self) -> None:  # pragma: no cover
    """Open the WebSocket, start the receive task and send the initial session.update.

    Steps, in order: make sure the session events exist, connect with a
    bearer Authorization header, spawn ``_recv_loop`` as a task, wait up to
    one second for ``session.created``, then build and send a
    ``session.update`` describing audio formats, voice, turn detection,
    instructions and (optionally) prompt, tools and tool_choice.
    """
    # Older builds may not have created these events in __init__: recreate
    # any that is missing or not an asyncio.Event.
    for evt_attr in ("_session_created_evt", "_session_updated_evt"):
        if not isinstance(getattr(self, evt_attr, None), asyncio.Event):
            setattr(self, evt_attr, asyncio.Event())

    auth_headers = [("Authorization", f"Bearer {self.api_key}")]
    model = self.options.model or "gpt-realtime"
    uri = f"{self.url}?model={model}"
    logger.info(
        "Realtime WS connecting: uri=%s, input=%s@%sHz, output=%s@%sHz, voice=%s, vad=%s",
        uri,
        self.options.input_mime,
        self.options.input_rate_hz,
        self.options.output_mime,
        self.options.output_rate_hz,
        self.options.voice,
        self.options.vad_enabled,
    )
    self._ws = await websockets.connect(
        uri, additional_headers=auth_headers, max_size=None
    )
    logger.info("Connected to OpenAI Realtime WS: %s", uri)
    self._recv_task = asyncio.create_task(self._recv_loop())

    # Give the server a short window to announce the session before the
    # first update is sent; the update is sent either way.
    try:
        await asyncio.wait_for(self._session_created_evt.wait(), timeout=1.0)
    except asyncio.TimeoutError:
        logger.info(
            "Realtime WS: no session.created within 1.0s; sending session.update anyway"
        )
    else:
        logger.info(
            "Realtime WS: session.created observed before first session.update"
        )

    # Optional prompt reference (id plus optional version / variables).
    prompt_block = None
    if getattr(self.options, "prompt_id", None):
        prompt_block = {"id": self.options.prompt_id}
        for opt_name, field in (
            ("prompt_version", "version"),
            ("prompt_variables", "variables"),
        ):
            if getattr(self.options, opt_name, None):
                prompt_block[field] = getattr(self.options, opt_name)

    def _without_strict(tools_val):
        """Return the tool list with any 'strict' key removed from each tool."""
        try:
            entries = list(tools_val or [])
        except Exception:
            return tools_val

        def _clean(tool):
            try:
                return {k: v for k, v in dict(tool).items() if k != "strict"}
            except Exception:
                return tool

        return [_clean(tool) for tool in entries]

    # turn_detection lives under audio.input; server VAD is always declared
    # and create_response mirrors the local vad_enabled option.
    input_cfg = {
        "format": {
            "type": self.options.input_mime or "audio/pcm",
            "rate": int(self.options.input_rate_hz or 24000),
        },
        "turn_detection": {
            "type": "server_vad",
            "create_response": bool(self.options.vad_enabled),
        },
    }
    output_cfg = {
        "format": {
            "type": self.options.output_mime or "audio/pcm",
            "rate": int(self.options.output_rate_hz or 24000),
        },
        "voice": self.options.voice,
        "speed": float(getattr(self.options, "voice_speed", 1.0) or 1.0),
    }

    # Keys are inserted in the same order as the wire payload expects them
    # to be serialised; there is no top-level turn_detection.
    session: Dict[str, Any] = {
        "type": "realtime",
        "output_modalities": ["audio"],
        "audio": {"input": input_cfg, "output": output_cfg},
    }
    if prompt_block:
        session["prompt"] = prompt_block
    session["instructions"] = self.options.instructions or ""
    if self.options.tools:
        session["tools"] = _without_strict(self.options.tools)
    if getattr(self.options, "tool_choice", None):
        session["tool_choice"] = self.options.tool_choice

    session_payload: Dict[str, Any] = {"type": "session.update", "session": session}

    logger.info(
        "Realtime WS: sending session.update (voice=%s, vad=%s, output=%s@%s)",
        self.options.voice,
        self.options.vad_enabled,
        (self.options.output_mime or "audio/pcm"),
        int(self.options.output_rate_hz or 24000),
    )
    # Dump the exact payload; a serialisation failure must not block connect.
    try:
        logger.info(
            "Realtime WS: sending session.update payload=%s",
            json.dumps(session_payload.get("session", {}), sort_keys=True),
        )
    except Exception:
        pass

    # Remember what was sent and arm a fresh event for session.updated.
    self._last_session_patch = session_payload.get("session", {})
    self._session_updated_evt = asyncio.Event()
    self._awaiting_session_updated = True

    # Sanity warnings for fields whose absence usually indicates misconfiguration.
    try:
        sent = self._last_session_patch
        instr = sent.get("instructions")
        voice = ((sent.get("audio") or {}).get("output") or {}).get("voice")
        blank_instructions = isinstance(instr, str) and instr.strip() == ""
        if instr is None or blank_instructions:
            logger.warning(
                "Realtime WS: instructions missing/empty in session.update"
            )
        if not voice:
            logger.warning("Realtime WS: voice missing in session.update")
    except Exception:
        pass

    await self._send_tracked(session_payload, label="session.update:init")
239
+
240
+ async def close(self) -> None: # pragma: no cover
241
+ if self._ws:
242
+ await self._ws.close()
243
+ self._ws = None
244
+ if self._recv_task:
245
+ self._recv_task.cancel()
246
+ self._recv_task = None
247
+ # Unblock any pending waiters to avoid dangling tasks
248
+ try:
249
+ self._commit_evt.set()
250
+ except Exception:
251
+ pass
252
+ try:
253
+ self._session_created_evt.set()
254
+ except Exception:
255
+ pass
256
+ try:
257
+ self._session_updated_evt.set()
258
+ except Exception:
259
+ pass
260
+
261
+ async def _send(self, payload: Dict[str, Any]) -> None: # pragma: no cover
262
+ if not self._ws:
263
+ raise RuntimeError("WebSocket not connected")
264
+ try:
265
+ await self._ws.send(json.dumps(payload))
266
+ finally:
267
+ try:
268
+ ptype = payload.get("type")
269
+ peid = payload.get("event_id")
270
+ except Exception:
271
+ ptype = str(type(payload))
272
+ peid = None
273
+ logger.debug("WS send: %s (event_id=%s)", ptype, peid)
274
+
275
+ def _next_event_id(self) -> str:
276
+ try:
277
+ self._event_seq += 1
278
+ return f"evt-{self._event_seq}"
279
+ except Exception:
280
+ # Fallback to random when counters fail
281
+ import uuid as _uuid
282
+
283
+ return str(_uuid.uuid4())
284
+
285
+ async def _send_tracked(
286
+ self, payload: Dict[str, Any], label: Optional[str] = None
287
+ ) -> str:
288
+ """Attach an event_id and retain a snapshot for correlating error events."""
289
+ try:
290
+ eid = payload.get("event_id") or self._next_event_id()
291
+ payload["event_id"] = eid
292
+ self._sent_events[eid] = {
293
+ "label": label or (payload.get("type") or "client.event"),
294
+ "type": payload.get("type"),
295
+ "payload": payload.copy(),
296
+ "ts": asyncio.get_event_loop().time(),
297
+ }
298
+ except Exception:
299
+ eid = payload.get("event_id") or ""
300
+ await self._send(payload)
301
+ return eid
302
+
303
async def _recv_loop(self) -> None:  # pragma: no cover
    """Read server events from the WebSocket and route them until it ends.

    For every received message this method:
      * parses it as JSON and remembers the latest ``response.id`` seen on
        any ``response.*`` event;
      * demultiplexes by event type: decoded audio deltas go to
        ``_audio_queue``, input transcripts to ``_in_tr_queue``, assistant
        text/audio-transcript deltas are buffered per response id and
        flushed to ``_out_tr_queue`` on completion events;
      * updates session/commit bookkeeping on ``session.*`` and
        ``input_audio_buffer.committed`` events and logs error-like events;
      * publishes the raw event on ``_event_queue``;
      * tracks function-call items and awaits ``_execute_pending_call`` for
        calls that are ready (that helper is defined outside this view).

    Any exception raised while handling a single message is swallowed and
    the loop moves on to the next message. When the loop exits for any
    reason, ``None`` is put on every queue as an end-of-stream marker.
    """
    assert self._ws is not None
    try:
        async for raw in self._ws:
            try:
                data = json.loads(raw)
                etype = data.get("type")
                eid = data.get("event_id")
                logger.debug("WS recv: %s (event_id=%s)", etype, eid)
                # Track active response.id from any response.* event that carries it
                try:
                    if isinstance(etype, str) and etype.startswith("response."):
                        _resp = data.get("response") or {}
                        _rid = _resp.get("id")
                        if _rid:
                            self._active_response_id = _rid
                except Exception:
                    pass
                # Demux streams
                if etype in ("response.output_audio.delta", "response.audio.delta"):
                    # Output audio arrives base64-encoded in "delta".
                    b64 = data.get("delta") or ""
                    if b64:
                        try:
                            chunk = base64.b64decode(b64)
                            self._audio_queue.put_nowait(chunk)
                            logger.info("Audio delta bytes=%d", len(chunk))
                            try:
                                # New response detected if we were previously inactive
                                if not getattr(self, "_response_active", False):
                                    self._response_generation = (
                                        int(
                                            getattr(self, "_response_generation", 0)
                                        )
                                        + 1
                                    )
                                self._response_active = True
                            except Exception:
                                pass
                        except Exception:
                            pass
                elif etype == "response.text.delta":
                    # Some servers emit generic text deltas with metadata marking transcription
                    metadata = data.get("response", {}).get("metadata", {})
                    if metadata.get("type") == "transcription":
                        delta = data.get("delta") or ""
                        if delta:
                            self._in_tr_queue.put_nowait(delta)
                            logger.info("Input transcript delta: %r", delta[:120])
                        else:
                            logger.debug("Input transcript delta: empty")
                elif etype in (
                    "response.function_call_arguments.delta",
                    "response.function_call_arguments.done",
                ):
                    # Capture streamed function-call arguments early and mark readiness on .done
                    try:
                        resp = data.get("response") or {}
                        rid = resp.get("id") or getattr(
                            self, "_active_response_id", None
                        )
                        # Many servers include call_id at top-level for these events
                        call_id = data.get("call_id") or data.get("id")
                        # Some servers include the function call item under 'item'
                        if not call_id:
                            call_id = (data.get("item") or {}).get("call_id") or (
                                data.get("item") or {}
                            ).get("id")
                        # Name can appear directly or under item
                        name = data.get("name") or (data.get("item") or {}).get(
                            "name"
                        )
                        if name:
                            self._call_names[call_id] = name
                        if rid and call_id:
                            self._call_response_ids[call_id] = rid
                        if etype.endswith("delta"):
                            # Append this fragment to the arguments gathered so far.
                            delta = data.get("delta") or ""
                            if call_id and delta:
                                self._call_args_accum[call_id] = (
                                    self._call_args_accum.get(call_id, "") + delta
                                )
                        else:  # .done
                            if call_id:
                                # Mark call ready and enqueue pending execution if not already
                                ev = self._call_ready_events.get(call_id)
                                if not ev:
                                    ev = asyncio.Event()
                                    self._call_ready_events[call_id] = ev
                                ev.set()
                                # Register pending call using accumulated args
                                if call_id not in getattr(
                                    self, "_pending_calls", {}
                                ):
                                    # Arguments default to an empty JSON object
                                    # when no delta was accumulated.
                                    args_text = (
                                        self._call_args_accum.get(call_id, "{}")
                                        or "{}"
                                    )
                                    if not hasattr(self, "_pending_calls"):
                                        self._pending_calls = {}
                                    self._pending_calls[call_id] = {
                                        "name": self._call_names.get(call_id),
                                        "args": args_text,
                                        "gen": int(
                                            getattr(self, "_response_generation", 0)
                                        ),
                                        "call_id": call_id,
                                        "item_id": None,
                                    }
                                    # Recorded before executing so the
                                    # response.done handler below skips it.
                                    self._executed_call_ids.add(call_id)
                                    await self._execute_pending_call(call_id)
                    except Exception:
                        pass
                elif etype == "response.output_text.delta":
                    # Assistant textual output delta. Buffer per response.id.
                    # Prefer the audio transcript stream for final transcript; only use text
                    # deltas if no audio transcript arrives.
                    try:
                        rid = (data.get("response") or {}).get("id") or getattr(
                            self, "_active_response_id", None
                        )
                        delta = data.get("delta") or ""
                        if rid and delta:
                            buf = self._out_text_buffers.setdefault(
                                rid, {"text": "", "has_audio": False}
                            )
                            # Only accumulate text when we don't yet have audio transcript
                            if not bool(buf.get("has_audio")):
                                buf["text"] = str(buf.get("text", "")) + delta
                                logger.debug(
                                    "Buffered assistant text delta (rid=%s, len=%d)",
                                    rid,
                                    len(delta),
                                )
                    except Exception:
                        pass
                elif etype == "conversation.item.input_audio_transcription.delta":
                    delta = data.get("delta") or ""
                    if delta:
                        self._in_tr_queue.put_nowait(delta)
                        logger.info("Input transcript delta (GA): %r", delta[:120])
                    else:
                        logger.debug("Input transcript delta (GA): empty")
                elif etype in (
                    "response.output_audio_transcript.delta",
                    "response.audio_transcript.delta",
                ):
                    # Assistant audio transcript delta (authoritative for spoken output)
                    try:
                        rid = (data.get("response") or {}).get("id") or getattr(
                            self, "_active_response_id", None
                        )
                        delta = data.get("delta") or ""
                        if rid and delta:
                            buf = self._out_text_buffers.setdefault(
                                rid, {"text": "", "has_audio": False}
                            )
                            # Once set, later output_text deltas for this
                            # response are no longer appended to the buffer.
                            buf["has_audio"] = True
                            buf["text"] = str(buf.get("text", "")) + delta
                            logger.debug(
                                "Buffered audio transcript delta (rid=%s, len=%d)",
                                rid,
                                len(delta),
                            )
                    except Exception:
                        pass
                elif etype == "input_audio_buffer.committed":
                    # Track the committed audio item id for OOB transcription referencing
                    item_id = data.get("item_id") or data.get("id")
                    if item_id:
                        self._last_input_item_id = item_id
                        try:
                            self._last_commit_ts = asyncio.get_event_loop().time()
                        except Exception:
                            pass
                        logger.info(
                            "Realtime WS: input_audio_buffer committed: item_id=%s",
                            item_id,
                        )
                    # Clear commit in-flight flag on ack and signal waiters
                    try:
                        self._commit_inflight = False
                    except Exception:
                        pass
                    try:
                        self._commit_evt.set()
                    except Exception:
                        pass
                elif etype in (
                    "response.output_audio.done",
                    "response.audio.done",
                ):
                    # End of audio stream for the response; stop audio iterator but keep WS open for transcripts
                    logger.info(
                        "Realtime WS: output audio done; ending audio stream"
                    )
                    # If we have a buffered transcript for this response, flush it now
                    try:
                        rid = (data.get("response") or {}).get("id") or getattr(
                            self, "_active_response_id", None
                        )
                        if rid and rid in self._out_text_buffers:
                            final = str(
                                self._out_text_buffers.get(rid, {}).get("text")
                                or ""
                            )
                            if final:
                                self._out_tr_queue.put_nowait(final)
                                logger.debug(
                                    "Flushed assistant transcript on audio.done (rid=%s, len=%d)",
                                    rid,
                                    len(final),
                                )
                            self._out_text_buffers.pop(rid, None)
                    except Exception:
                        pass
                    # None on the audio queue marks the end of this response's audio.
                    try:
                        self._audio_queue.put_nowait(None)
                    except Exception:
                        pass
                    try:
                        self._response_active = False
                    except Exception:
                        pass
                    try:
                        # Clear active response id when audio for that response is done
                        self._active_response_id = None
                    except Exception:
                        pass
                    # Don't break; we still want to receive input transcription events
                elif etype in (
                    "response.completed",
                    "response.complete",
                    "response.done",
                ):
                    # Do not terminate audio stream here; completion may be for a function_call
                    metadata = data.get("response", {}).get("metadata", {})
                    if metadata.get("type") == "transcription":
                        logger.info("Realtime WS: transcription response completed")
                    # Flush buffered assistant transcript, if any
                    try:
                        rid = (data.get("response") or {}).get("id")
                        if rid and rid in self._out_text_buffers:
                            final = str(
                                self._out_text_buffers.get(rid, {}).get("text")
                                or ""
                            )
                            if final:
                                self._out_tr_queue.put_nowait(final)
                                logger.debug(
                                    "Flushed assistant transcript on response.done (rid=%s, len=%d)",
                                    rid,
                                    len(final),
                                )
                            self._out_text_buffers.pop(rid, None)
                    except Exception:
                        pass
                    try:
                        self._response_active = False
                    except Exception:
                        pass
                    try:
                        # Response lifecycle ended; clear active id
                        self._active_response_id = None
                    except Exception:
                        pass
                elif etype in ("response.text.done", "response.output_text.done"):
                    metadata = data.get("response", {}).get("metadata", {})
                    if metadata.get("type") == "transcription":
                        # Completed text of a transcription-tagged response
                        # is delivered on the input transcript queue.
                        tr = data.get("text") or ""
                        if tr:
                            try:
                                self._in_tr_queue.put_nowait(tr)
                                logger.debug(
                                    "Input transcript completed: %r", tr[:120]
                                )
                            except Exception:
                                pass
                    else:
                        # For assistant text-only completions without audio, flush buffered text
                        try:
                            rid = (data.get("response") or {}).get("id") or getattr(
                                self, "_active_response_id", None
                            )
                            if rid and rid in self._out_text_buffers:
                                final = str(
                                    self._out_text_buffers.get(rid, {}).get("text")
                                    or ""
                                )
                                if final:
                                    self._out_tr_queue.put_nowait(final)
                                    logger.debug(
                                        "Flushed assistant transcript on text.done (rid=%s, len=%d)",
                                        rid,
                                        len(final),
                                    )
                                self._out_text_buffers.pop(rid, None)
                        except Exception:
                            pass
                elif (
                    etype == "conversation.item.input_audio_transcription.completed"
                ):
                    tr = data.get("transcript") or ""
                    if tr:
                        try:
                            self._in_tr_queue.put_nowait(tr)
                            logger.debug(
                                "Input transcript completed (GA): %r", tr[:120]
                            )
                        except Exception:
                            pass
                elif etype == "session.updated":
                    sess = data.get("session", {})
                    self._last_session_updated_payload = sess
                    # Track server auto-create enablement (turn_detection.create_response)
                    try:
                        # turn_detection is accepted at the session root or
                        # under audio.input; the root value wins when present.
                        td_root = sess.get("turn_detection")
                        td_input = (
                            (sess.get("audio") or {})
                            .get("input", {})
                            .get("turn_detection")
                        )
                        td = td_root if td_root is not None else td_input
                        # If local VAD is disabled, force auto-create disabled
                        if not bool(getattr(self.options, "vad_enabled", False)):
                            self._server_auto_create_enabled = False
                        else:
                            self._server_auto_create_enabled = bool(
                                isinstance(td, dict)
                                and bool(td.get("create_response"))
                            )
                        logger.debug(
                            "Realtime WS: server_auto_create_enabled=%s",
                            self._server_auto_create_enabled,
                        )
                    except Exception:
                        pass
                    # Mark that the latest update has been applied
                    try:
                        self._awaiting_session_updated = False
                        self._session_updated_evt.set()
                    except Exception:
                        pass
                    try:
                        logger.info(
                            "Realtime WS: session.updated payload=%s",
                            json.dumps(sess, sort_keys=True),
                        )
                    except Exception:
                        logger.info(
                            "Realtime WS: session updated (payload dump failed)"
                        )
                    # Extra guardrails: warn if key fields absent/empty
                    try:
                        instr = sess.get("instructions")
                        voice = sess.get("voice") or (
                            (sess.get("audio") or {}).get("output", {}).get("voice")
                        )
                        if instr is None or (
                            isinstance(instr, str) and instr.strip() == ""
                        ):
                            logger.warning(
                                "Realtime WS: session.updated has empty/missing instructions"
                            )
                        if not voice:
                            logger.warning(
                                "Realtime WS: session.updated missing voice"
                            )
                    except Exception:
                        pass
                elif etype == "session.created":
                    logger.info("Realtime WS: session created")
                    try:
                        self._session_created_evt.set()
                    except Exception:
                        pass
                elif etype == "error" or (
                    isinstance(etype, str)
                    and (
                        etype.endswith(".error")
                        or ("error" in etype)
                        or ("failed" in etype)
                    )
                ):
                    # Surface server errors explicitly
                    try:
                        logger.error(
                            "Realtime WS error event: %s",
                            json.dumps(data, sort_keys=True),
                        )
                    except Exception:
                        logger.error(
                            "Realtime WS error event (payload dump failed)"
                        )
                    # Correlate to the originating client event by event_id
                    try:
                        eid = data.get("event_id") or data.get("error", {}).get(
                            "event_id"
                        )
                        if eid and eid in self._sent_events:
                            sent = self._sent_events.get(eid) or {}
                            logger.error(
                                "Realtime WS error correlated: event_id=%s sent_label=%s sent_type=%s",
                                eid,
                                sent.get("label"),
                                sent.get("type"),
                            )
                    except Exception:
                        pass
                # No legacy fallback; rely on current config/state.
                # Always also publish raw events
                try:
                    self._event_queue.put_nowait(data)
                except Exception:
                    pass
                # Handle tool/function calls (GA-only triggers)
                # NOTE: this is a second, independent if/elif chain, so
                # response.done/completed events are handled both above
                # (transcript flush) and here (function-call detection).
                if etype in (
                    "conversation.item.created",
                    "conversation.item.added",
                ):
                    try:
                        item = data.get("item") or {}
                        if item.get("type") == "function_call":
                            call_id = item.get("call_id")
                            item_id = item.get("id")
                            # Mark readiness for both identifiers if available
                            for cid in filter(None, [call_id, item_id]):
                                ev = self._call_ready_events.get(cid)
                                if not ev:
                                    ev = asyncio.Event()
                                    self._call_ready_events[cid] = ev
                                ev.set()
                            logger.debug(
                                "Call ready via item.%s: call_id=%s item_id=%s",
                                "created" if etype.endswith("created") else "added",
                                call_id,
                                item_id,
                            )
                    except Exception:
                        pass
                elif etype in (
                    "response.output_item.added",
                    "response.output_item.created",
                ):
                    # Map response-scoped function_call items to response_id and ack function_call_output
                    try:
                        resp = data.get("response") or {}
                        rid = resp.get("id") or getattr(
                            self, "_active_response_id", None
                        )
                        item = data.get("item") or {}
                        itype = item.get("type")
                        if itype == "function_call":
                            call_id = item.get("call_id")
                            item_id = item.get("id")
                            name = item.get("name")
                            if name and call_id:
                                self._call_names[call_id] = name
                            if call_id and item_id:
                                self._item_call_ids[item_id] = call_id
                            # Bind call identifiers to the response id for response-scoped outputs
                            for cid in filter(None, [call_id, item_id]):
                                if rid:
                                    self._call_response_ids[cid] = rid
                            # Signal readiness for these identifiers
                            for cid in filter(None, [call_id, item_id]):
                                ev = self._call_ready_events.get(cid)
                                if not ev:
                                    ev = asyncio.Event()
                                    self._call_ready_events[cid] = ev
                                ev.set()
                            logger.debug(
                                "Mapped function_call via response.%s: call_id=%s item_id=%s response_id=%s",
                                "created" if etype.endswith("created") else "added",
                                call_id,
                                item_id,
                                rid,
                            )
                    except Exception:
                        pass
                elif etype in (
                    "response.done",
                    "response.completed",
                    "response.complete",
                ):
                    # Also detect GA function_call items in response.output
                    try:
                        resp = data.get("response", {})
                        rid = resp.get("id")
                        out_items = resp.get("output", [])
                        for item in out_items:
                            if item.get("type") == "function_call":
                                call_id = item.get("call_id")
                                item_id = item.get("id")
                                # Prefer a canonical id to reference later; try item_id first, then call_id
                                canonical_id = item_id or call_id
                                name = item.get("name")
                                args = item.get("arguments") or "{}"
                                status = item.get("status")
                                if call_id and item_id:
                                    self._item_call_ids[item_id] = call_id
                                # Map this call to the response id for response-scoped outputs
                                for cid in filter(
                                    None, [call_id, item_id, canonical_id]
                                ):
                                    if rid:
                                        self._call_response_ids[cid] = rid
                                # Skip calls already executed under any of
                                # their identifiers (e.g. via arguments.done).
                                if (
                                    canonical_id in self._executed_call_ids
                                    or (
                                        call_id
                                        and call_id in self._executed_call_ids
                                    )
                                    or (
                                        item_id
                                        and item_id in self._executed_call_ids
                                    )
                                ):
                                    continue
                                if status and str(status).lower() not in (
                                    "completed",
                                    "complete",
                                    "done",
                                ):
                                    # Wait for completion in a later event
                                    continue
                                if not hasattr(self, "_pending_calls"):
                                    self._pending_calls = {}
                                self._pending_calls[canonical_id] = {
                                    "name": name,
                                    "args": args,
                                    # Bind this call to the current response generation
                                    "gen": int(
                                        getattr(self, "_response_generation", 0)
                                    ),
                                    # Keep both identifiers for fallback when posting outputs
                                    "call_id": call_id,
                                    "item_id": item_id,
                                }
                                # Mark this call as ready (server has produced the item)
                                for cid in filter(
                                    None, [canonical_id, call_id, item_id]
                                ):
                                    ev = self._call_ready_events.get(cid)
                                    if not ev:
                                        ev = asyncio.Event()
                                        self._call_ready_events[cid] = ev
                                    ev.set()
                                logger.debug(
                                    "Call ready via response.done: call_id=%s item_id=%s canonical_id=%s",
                                    call_id,
                                    item_id,
                                    canonical_id,
                                )
                                for cid in filter(
                                    None, [canonical_id, call_id, item_id]
                                ):
                                    self._executed_call_ids.add(cid)
                                # Awaited inline: no further messages are read
                                # until this call returns.
                                await self._execute_pending_call(canonical_id)
                    except Exception:
                        pass

                # No ack/error event mapping; server emits concrete lifecycle events
            except Exception:
                # A malformed or unexpected message must not stop the loop.
                continue
    except Exception:
        logger.exception("Realtime WS receive loop error")
    finally:
        # Close queues gracefully
        for q in (
            self._audio_queue,
            self._in_tr_queue,
            self._out_tr_queue,
            self._event_queue,
        ):
            try:
                q.put_nowait(None)  # type: ignore
            except Exception:
                pass
881
+
882
+ # --- Client event helpers ---
883
+ async def update_session(
884
+ self, session_patch: Dict[str, Any]
885
+ ) -> None: # pragma: no cover
886
+ # Build nested session.update per docs. Only include provided fields.
887
+ raw = dict(session_patch or {})
888
+ patch: Dict[str, Any] = {}
889
+ audio_patch: Dict[str, Any] = {}
890
+
891
+ try:
892
+ audio = dict(raw.get("audio") or {})
893
+ # Normalize turn_detection to audio.input per server schema
894
+ include_td = False
895
+ turn_det = None
896
+ inp = dict(audio.get("input") or {})
897
+ if "turn_detection" in inp:
898
+ include_td = True
899
+ turn_det = inp.get("turn_detection")
900
+ if "turn_detection" in audio:
901
+ include_td = True
902
+ turn_det = audio.get("turn_detection")
903
+ if "turn_detection" in raw:
904
+ include_td = True
905
+ turn_det = raw.get("turn_detection")
906
+
907
+ inp = dict(audio.get("input") or {})
908
+ out = dict(audio.get("output") or {})
909
+
910
+ # Input format
911
+ fmt_in = inp.get("format", raw.get("input_audio_format"))
912
+ if fmt_in is not None:
913
+ if isinstance(fmt_in, dict):
914
+ audio_patch.setdefault("input", {})["format"] = fmt_in
915
+ else:
916
+ ftype = str(fmt_in)
917
+ if ftype == "pcm16":
918
+ ftype = "audio/pcm"
919
+ audio_patch.setdefault("input", {})["format"] = {
920
+ "type": ftype,
921
+ "rate": int(self.options.input_rate_hz or 24000),
922
+ }
923
+
924
+ # Optional input extras
925
+ for key in ("noise_reduction", "transcription"):
926
+ if key in audio:
927
+ audio_patch[key] = audio.get(key)
928
+
929
+ # Apply turn_detection under audio.input if provided (allow None to disable)
930
+ if include_td:
931
+ audio_patch.setdefault("input", {})["turn_detection"] = turn_det
932
+
933
+ # Output format/voice/speed
934
+ op: Dict[str, Any] = {}
935
+ fmt_out = out.get("format", raw.get("output_audio_format"))
936
+ if fmt_out is not None:
937
+ if isinstance(fmt_out, dict):
938
+ op["format"] = fmt_out
939
+ else:
940
+ ftype = str(fmt_out)
941
+ if ftype == "pcm16":
942
+ ftype = "audio/pcm"
943
+ op["format"] = {
944
+ "type": ftype,
945
+ "rate": int(self.options.output_rate_hz or 24000),
946
+ }
947
+ if "voice" in out:
948
+ op["voice"] = out.get("voice")
949
+ if "speed" in out:
950
+ op["speed"] = out.get("speed")
951
+ # Convenience: allow top-level overrides
952
+ if "voice" in raw and "voice" not in op:
953
+ op["voice"] = raw.get("voice")
954
+ if "speed" in raw and "speed" not in op:
955
+ op["speed"] = raw.get("speed")
956
+ if op:
957
+ audio_patch.setdefault("output", {}).update(op)
958
+ except Exception:
959
+ pass
960
+
961
+ if audio_patch:
962
+ patch["audio"] = audio_patch
963
+
964
+ # Always include session.type in updates
965
+ patch["type"] = "realtime"
966
+
967
+ # No top-level turn_detection
968
+
969
+ def _strip_tool_strict(tools_val):
970
+ try:
971
+ tools_list = list(tools_val or [])
972
+ except Exception:
973
+ return tools_val
974
+ cleaned = []
975
+ for t in tools_list:
976
+ try:
977
+ t2 = dict(t)
978
+ t2.pop("strict", None)
979
+ cleaned.append(t2)
980
+ except Exception:
981
+ cleaned.append(t)
982
+ return cleaned
983
+
984
+ # Pass through other documented fields if present
985
+ for k in (
986
+ "model",
987
+ "output_modalities",
988
+ "prompt",
989
+ "instructions",
990
+ "tools",
991
+ "tool_choice",
992
+ "include",
993
+ "max_output_tokens",
994
+ "tracing",
995
+ "truncation",
996
+ ):
997
+ if k in raw:
998
+ if k == "tools":
999
+ patch[k] = _strip_tool_strict(raw[k])
1000
+ else:
1001
+ patch[k] = raw[k]
1002
+
1003
+ # Ensure tools are cleaned even if provided only under audio or elsewhere
1004
+ if "tools" in patch:
1005
+ patch["tools"] = _strip_tool_strict(patch["tools"]) # idempotent
1006
+
1007
+ payload = {"type": "session.update", "session": patch}
1008
+ # Mark awaiting updated and store last patch
1009
+ self._last_session_patch = patch or {}
1010
+ self._session_updated_evt = asyncio.Event()
1011
+ self._awaiting_session_updated = True
1012
+ # Log payload and warn if potentially clearing/omitting critical fields
1013
+ try:
1014
+ logger.info(
1015
+ "Realtime WS: sending session.update payload=%s",
1016
+ json.dumps(self._last_session_patch, sort_keys=True),
1017
+ )
1018
+ if "instructions" in self._last_session_patch and (
1019
+ (self._last_session_patch.get("instructions") or "").strip() == ""
1020
+ ):
1021
+ logger.warning(
1022
+ "Realtime WS: session.update sets empty instructions; this clears them"
1023
+ )
1024
+ out_cfg = (self._last_session_patch.get("audio") or {}).get("output") or {}
1025
+ if "voice" in out_cfg and not out_cfg.get("voice"):
1026
+ logger.warning("Realtime WS: session.update provides empty voice")
1027
+ if "instructions" not in self._last_session_patch:
1028
+ logger.warning(
1029
+ "Realtime WS: session.update omits instructions; relying on previous instructions"
1030
+ )
1031
+ except Exception:
1032
+ pass
1033
+ await self._send(payload)
1034
+
1035
+ async def append_audio(self, pcm16_bytes: bytes) -> None: # pragma: no cover
1036
+ b64 = base64.b64encode(pcm16_bytes).decode("ascii")
1037
+ await self._send_tracked(
1038
+ {"type": "input_audio_buffer.append", "audio": b64},
1039
+ label="input_audio_buffer.append",
1040
+ )
1041
+ try:
1042
+ self._pending_input_bytes += len(pcm16_bytes)
1043
+ except Exception:
1044
+ pass
1045
+
1046
+ async def commit_input(self) -> None: # pragma: no cover
1047
+ try:
1048
+ # Skip commits while a response is active to avoid server errors
1049
+ if bool(getattr(self, "_response_active", False)):
1050
+ logger.warning("Realtime WS: skipping commit; response active")
1051
+ return
1052
+ # Avoid overlapping commits while awaiting server ack
1053
+ if bool(getattr(self, "_commit_inflight", False)):
1054
+ logger.warning("Realtime WS: skipping commit; commit in-flight")
1055
+ return
1056
+ # Avoid rapid duplicate commits
1057
+ last_commit = float(getattr(self, "_last_commit_ts", 0.0))
1058
+ if last_commit and (asyncio.get_event_loop().time() - last_commit) < 1.0:
1059
+ logger.warning("Realtime WS: skipping commit; committed recently")
1060
+ return
1061
+ # Require at least 100ms of audio (~4800 bytes at 24kHz mono 16-bit)
1062
+ min_bytes = int(0.1 * int(self.options.input_rate_hz or 24000) * 2)
1063
+ except Exception:
1064
+ min_bytes = 4800
1065
+ if int(getattr(self, "_pending_input_bytes", 0)) < min_bytes:
1066
+ try:
1067
+ logger.warning(
1068
+ "Realtime WS: skipping commit; buffer too small bytes=%d < %d",
1069
+ int(getattr(self, "_pending_input_bytes", 0)),
1070
+ min_bytes,
1071
+ )
1072
+ except Exception:
1073
+ pass
1074
+ return
1075
+ # Reset commit event before sending a new commit and mark as in-flight
1076
+ try:
1077
+ self._commit_evt = asyncio.Event()
1078
+ self._commit_inflight = True
1079
+ except Exception:
1080
+ pass
1081
+ await self._send_tracked(
1082
+ {"type": "input_audio_buffer.commit"}, label="input_audio_buffer.commit"
1083
+ )
1084
+ try:
1085
+ logger.info("Realtime WS: input_audio_buffer.commit sent")
1086
+ self._pending_input_bytes = 0
1087
+ self._last_commit_ts = asyncio.get_event_loop().time()
1088
+ except Exception:
1089
+ pass
1090
+
1091
+ async def clear_input(self) -> None: # pragma: no cover
1092
+ await self._send_tracked(
1093
+ {"type": "input_audio_buffer.clear"}, label="input_audio_buffer.clear"
1094
+ )
1095
+ # Reset last input reference and commit event to avoid stale references
1096
+ try:
1097
+ self._last_input_item_id = None
1098
+ self._commit_evt = asyncio.Event()
1099
+ except Exception:
1100
+ pass
1101
+
1102
+ async def create_response(
1103
+ self, response_patch: Optional[Dict[str, Any]] = None
1104
+ ) -> None: # pragma: no cover
1105
+ # Avoid duplicate responses: if server auto-creates after commit or one is already active, don't send.
1106
+ try:
1107
+ if getattr(self, "_response_active", False):
1108
+ logger.warning(
1109
+ "Realtime WS: response.create suppressed — response already active"
1110
+ )
1111
+ return
1112
+ auto = bool(getattr(self, "_server_auto_create_enabled", False))
1113
+ last_commit = float(getattr(self, "_last_commit_ts", 0.0))
1114
+ if auto and last_commit:
1115
+ # If we committed very recently (<1.0s), assume server will auto-create
1116
+ if (asyncio.get_event_loop().time() - last_commit) < 1.0:
1117
+ logger.info(
1118
+ "Realtime WS: response.create skipped — server auto-create expected"
1119
+ )
1120
+ return
1121
+ except Exception:
1122
+ pass
1123
+ # Wait briefly for commit event so we can reference the latest audio item when applicable
1124
+ if not self._last_input_item_id:
1125
+ try:
1126
+ await asyncio.wait_for(self._commit_evt.wait(), timeout=2.0)
1127
+ except asyncio.TimeoutError:
1128
+ pass
1129
+ # Ensure the latest session.update (if any) has been applied before responding
1130
+ if self._awaiting_session_updated:
1131
+ # Prefer an explicit session.updated; if absent, accept session.created
1132
+ try:
1133
+ if self._session_updated_evt.is_set():
1134
+ logger.info(
1135
+ "Realtime WS: response.create proceeding after session.updated (pre-set)"
1136
+ )
1137
+ else:
1138
+ await asyncio.wait_for(
1139
+ self._session_updated_evt.wait(), timeout=2.5
1140
+ )
1141
+ logger.info(
1142
+ "Realtime WS: response.create proceeding after session.updated"
1143
+ )
1144
+ except asyncio.TimeoutError:
1145
+ if self._session_created_evt.is_set():
1146
+ logger.info(
1147
+ "Realtime WS: response.create proceeding after session.created (no session.updated observed)"
1148
+ )
1149
+ # Best-effort: resend last session.update once more to apply voice/instructions
1150
+ try:
1151
+ if self._last_session_patch:
1152
+ logger.info(
1153
+ "Realtime WS: resending session.update to apply config before response"
1154
+ )
1155
+ # Reset awaiting flag and wait briefly again
1156
+ self._session_updated_evt = asyncio.Event()
1157
+ self._awaiting_session_updated = True
1158
+ # Ensure required session.type on retry
1159
+ _sess = dict(self._last_session_patch or {})
1160
+ _sess["type"] = "realtime"
1161
+ await self._send(
1162
+ {"type": "session.update", "session": _sess}
1163
+ )
1164
+ try:
1165
+ await asyncio.wait_for(
1166
+ self._session_updated_evt.wait(), timeout=1.0
1167
+ )
1168
+ logger.info(
1169
+ "Realtime WS: proceeding after retry session.updated"
1170
+ )
1171
+ except asyncio.TimeoutError:
1172
+ logger.warning(
1173
+ "Realtime WS: retry session.update did not yield session.updated in time"
1174
+ )
1175
+ except Exception:
1176
+ pass
1177
+ else:
1178
+ try:
1179
+ await asyncio.wait_for(
1180
+ self._session_created_evt.wait(), timeout=2.5
1181
+ )
1182
+ logger.info(
1183
+ "Realtime WS: response.create proceeding after session.created"
1184
+ )
1185
+ except asyncio.TimeoutError:
1186
+ logger.warning(
1187
+ "Realtime WS: neither session.updated nor session.created received in time; proceeding"
1188
+ )
1189
+
1190
+ # Then, create main response
1191
+ payload: Dict[str, Any] = {"type": "response.create"}
1192
+ if response_patch:
1193
+ payload["response"] = response_patch
1194
+ # Ensure response object exists; rely on session defaults for modalities/audio
1195
+ if "response" not in payload:
1196
+ payload["response"] = {}
1197
+ rp = payload["response"]
1198
+ # Sanitize unsupported fields that servers may reject
1199
+ try:
1200
+ rp.pop("modalities", None)
1201
+ rp.pop("audio", None)
1202
+ except Exception:
1203
+ pass
1204
+ rp.setdefault("metadata", {"type": "response"})
1205
+ # Attach input reference so the model links this response to last audio
1206
+ if self._last_input_item_id and "input" not in rp:
1207
+ rp["input"] = [{"type": "item_reference", "id": self._last_input_item_id}]
1208
+ try:
1209
+ has_ref = bool(self._last_input_item_id)
1210
+ logger.info(
1211
+ "Realtime WS: sending response.create (input_ref=%s)",
1212
+ has_ref,
1213
+ )
1214
+ except Exception:
1215
+ pass
1216
+ # Increment response generation when we intentionally start a new response
1217
+ try:
1218
+ if not getattr(self, "_response_active", False):
1219
+ self._response_generation = (
1220
+ int(getattr(self, "_response_generation", 0)) + 1
1221
+ )
1222
+ except Exception:
1223
+ pass
1224
+ await self._send_tracked(payload, label="response.create")
1225
+ try:
1226
+ self._response_active = True
1227
+ except Exception:
1228
+ pass
1229
+
1230
+ # --- Streams ---
1231
+ async def _iter_queue(self, q) -> AsyncGenerator[Any, None]:
1232
+ while True:
1233
+ item = await q.get()
1234
+ if item is None:
1235
+ break
1236
+ yield item
1237
+
1238
+ def iter_events(self) -> AsyncGenerator[Dict[str, Any], None]: # pragma: no cover
1239
+ return self._iter_queue(self._event_queue)
1240
+
1241
+ def iter_output_audio(self) -> AsyncGenerator[bytes, None]: # pragma: no cover
1242
+ return self._iter_queue(self._audio_queue)
1243
+
1244
+ def iter_input_transcript(self) -> AsyncGenerator[str, None]: # pragma: no cover
1245
+ return self._iter_queue(self._in_tr_queue)
1246
+
1247
+ def iter_output_transcript(self) -> AsyncGenerator[str, None]: # pragma: no cover
1248
+ return self._iter_queue(self._out_tr_queue)
1249
+
1250
+ def set_tool_executor(self, executor): # pragma: no cover
1251
+ self._tool_executor = executor
1252
+
1253
+ # Expose whether a function/tool call is currently pending
1254
+ def has_pending_tool_call(self) -> bool: # pragma: no cover
1255
+ try:
1256
+ return (
1257
+ bool(getattr(self, "_pending_calls", {}))
1258
+ or int(getattr(self, "_active_tool_calls", 0)) > 0
1259
+ )
1260
+ except Exception:
1261
+ return False
1262
+
1263
+ # --- Internal helpers for GA tool execution ---
1264
+ async def _execute_pending_call(self, call_id: Optional[str]) -> None:
1265
+ if not call_id:
1266
+ return
1267
+ # Peek without popping so we remain in a "pending/active" state
1268
+ pc = getattr(self, "_pending_calls", {}).get(call_id)
1269
+ if not pc or not self._tool_executor or not pc.get("name"):
1270
+ return
1271
+ try:
1272
+ # Drop if this call was bound to a previous response generation
1273
+ try:
1274
+ call_gen = int(pc.get("gen", 0))
1275
+ cur_gen = int(getattr(self, "_response_generation", 0))
1276
+ if call_gen and call_gen != cur_gen:
1277
+ logger.warning(
1278
+ "Skipping stale tool call: id=%s name=%s call_gen=%d cur_gen=%d",
1279
+ call_id,
1280
+ pc.get("name"),
1281
+ call_gen,
1282
+ cur_gen,
1283
+ )
1284
+ return
1285
+ except Exception:
1286
+ pass
1287
+ # Mark as active to keep timeouts from firing while tool runs
1288
+ try:
1289
+ self._active_tool_calls += 1
1290
+ except Exception:
1291
+ pass
1292
+ args_preview_len = len((pc.get("args") or ""))
1293
+ logger.info(
1294
+ "Executing tool: id=%s name=%s args_len=%d",
1295
+ call_id,
1296
+ pc.get("name"),
1297
+ args_preview_len,
1298
+ )
1299
+ args = pc.get("args") or "{}"
1300
+ try:
1301
+ parsed = json.loads(args)
1302
+ except Exception:
1303
+ parsed = {}
1304
+ start_ts = asyncio.get_event_loop().time()
1305
+ timeout_s = float(getattr(self.options, "tool_timeout_s", 300.0) or 300.0)
1306
+ try:
1307
+ result = await asyncio.wait_for(
1308
+ self._tool_executor(pc["name"], parsed), timeout=timeout_s
1309
+ )
1310
+ except asyncio.TimeoutError:
1311
+ logger.warning(
1312
+ "Tool timeout: id=%s name=%s exceeded %.1fs",
1313
+ call_id,
1314
+ pc.get("name"),
1315
+ timeout_s,
1316
+ )
1317
+ result = {"error": "tool_timeout"}
1318
+ dur = asyncio.get_event_loop().time() - start_ts
1319
+ try:
1320
+ result_summary = (
1321
+ f"keys={list(result.keys())[:5]}"
1322
+ if isinstance(result, dict)
1323
+ else type(result).__name__
1324
+ )
1325
+ except Exception:
1326
+ result_summary = "<unavailable>"
1327
+ logger.info(
1328
+ "Tool done: id=%s name=%s dur=%.2fs result=%s",
1329
+ call_id,
1330
+ pc.get("name"),
1331
+ dur,
1332
+ result_summary,
1333
+ )
1334
+ # Ensure the server has created the function_call item before we post output
1335
+ try:
1336
+ ev = self._call_ready_events.get(call_id)
1337
+ if ev:
1338
+ try:
1339
+ await asyncio.wait_for(ev.wait(), timeout=1.5)
1340
+ except asyncio.TimeoutError:
1341
+ logger.debug(
1342
+ "Call ready wait timed out; proceeding anyway: call_id=%s",
1343
+ call_id,
1344
+ )
1345
+ # Tiny jitter to help ordering on the server
1346
+ await asyncio.sleep(0.03)
1347
+ except Exception:
1348
+ pass
1349
+
1350
+ # Send tool result via conversation.item.create, then trigger response.create (per docs)
1351
+ try:
1352
+ # Derive a valid call_id and avoid sending item_id as call_id
1353
+ derived_call_id = (
1354
+ pc.get("call_id")
1355
+ or self._item_call_ids.get(pc.get("item_id"))
1356
+ or (
1357
+ call_id
1358
+ if isinstance(call_id, str) and call_id.startswith("call_")
1359
+ else None
1360
+ )
1361
+ )
1362
+ if not derived_call_id:
1363
+ logger.error(
1364
+ "Cannot send function_call_output: missing call_id (id=%s name=%s)",
1365
+ call_id,
1366
+ pc.get("name"),
1367
+ )
1368
+ else:
1369
+ await self._send_tracked(
1370
+ {
1371
+ "type": "conversation.item.create",
1372
+ "item": {
1373
+ "type": "function_call_output",
1374
+ "call_id": derived_call_id,
1375
+ "output": json.dumps(result),
1376
+ },
1377
+ },
1378
+ label="conversation.item.create:function_call_output",
1379
+ )
1380
+ logger.info(
1381
+ "conversation.item.create(function_call_output) sent call_id=%s",
1382
+ derived_call_id,
1383
+ )
1384
+ except Exception:
1385
+ logger.exception(
1386
+ "Failed to send function_call_output for call_id=%s", call_id
1387
+ )
1388
+ try:
1389
+ await asyncio.sleep(0.02)
1390
+ await self._send_tracked(
1391
+ {
1392
+ "type": "response.create",
1393
+ "response": {"metadata": {"type": "response"}},
1394
+ },
1395
+ label="response.create:after_tool",
1396
+ )
1397
+ logger.info("response.create sent after tool output")
1398
+ except Exception:
1399
+ logger.exception(
1400
+ "Failed to send follow-up response.create after tool output"
1401
+ )
1402
+ # Cleanup readiness event
1403
+ try:
1404
+ self._call_ready_events.pop(call_id, None)
1405
+ except Exception:
1406
+ pass
1407
+ except Exception:
1408
+ logger.exception(
1409
+ "Tool execution raised unexpectedly for call_id=%s", call_id
1410
+ )
1411
+ finally:
1412
+ # Clear pending state and decrement active count
1413
+ try:
1414
+ getattr(self, "_pending_calls", {}).pop(call_id, None)
1415
+ except Exception:
1416
+ pass
1417
+ try:
1418
+ self._active_tool_calls = max(0, int(self._active_tool_calls) - 1)
1419
+ logger.debug(
1420
+ "Pending tool calls decremented; active=%d", self._active_tool_calls
1421
+ )
1422
+ except Exception:
1423
+ pass
1424
+
1425
+
1426
class OpenAITranscriptionWebSocketSession(BaseRealtimeSession):
    """OpenAI Realtime Transcription WebSocket session.

    This session is transcription-only per GA docs. It accepts PCM16 input and emits
    conversation.item.input_audio_transcription.* events.

    Transcript deltas and completed transcripts are both pushed onto the input
    transcript stream; every decoded server event is also published on the
    raw event stream. Both streams end with a ``None`` sentinel when the
    receive loop exits.
    """

    def __init__(
        self,
        api_key: str,
        url: str = "wss://api.openai.com/v1/realtime",
        options: Optional[RealtimeSessionOptions] = None,
    ) -> None:
        """Store connection settings and create the (unconnected) session state.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            url: Base WebSocket URL; the model is appended as a query parameter.
            options: Session options; a default ``RealtimeSessionOptions()`` is
                used when omitted.
        """
        self.api_key = api_key
        self.url = url
        self.options = options or RealtimeSessionOptions()
        self._ws = None
        self._event_queue = asyncio.Queue()
        self._in_tr_queue = asyncio.Queue()
        self._recv_task = None
        self._last_input_item_id: Optional[str] = None

    async def connect(self) -> None:  # pragma: no cover
        """Open the WebSocket, start the receive loop and configure transcription."""
        headers = [("Authorization", f"Bearer {self.api_key}")]
        # Model is for TTS session; transcription model is set in session update
        model = self.options.model or "gpt-realtime"
        uri = f"{self.url}?model={model}"
        logger.info("Transcription WS connecting: uri=%s", uri)
        self._ws = await websockets.connect(
            uri, additional_headers=headers, max_size=None
        )
        self._recv_task = asyncio.create_task(self._recv_loop())

        # Optional transcription settings are only sent when set on the options.
        transcription: Dict[str, Any] = {
            "model": getattr(self.options, "transcribe_model", None)
            or "gpt-4o-mini-transcribe",
        }
        prompt = getattr(self.options, "transcribe_prompt", None)
        if prompt is not None:
            transcription["prompt"] = prompt
        language = getattr(self.options, "transcribe_language", None)
        if language is not None:
            transcription["language"] = language

        # Transcription session config per GA
        ts_payload: Dict[str, Any] = {
            "type": "transcription_session.update",
            "input_audio_format": "pcm16",
            "input_audio_transcription": transcription,
            # None disables server-side turn detection.
            "turn_detection": (
                {"type": "server_vad"} if self.options.vad_enabled else None
            ),
            # Optionally include extra properties (e.g., logprobs)
            # "include": ["item.input_audio_transcription.logprobs"],
        }
        logger.info("Transcription WS: sending transcription_session.update")
        await self._send(ts_payload)

    async def close(self) -> None:  # pragma: no cover
        """Close the socket and cancel the receive task.

        The receive task is cancelled and both references are cleared even if
        closing the socket raises, so a failed close cannot leak the task or
        leave a dead socket attached to the session.
        """
        ws, self._ws = self._ws, None
        task, self._recv_task = self._recv_task, None
        try:
            if ws:
                await ws.close()
        finally:
            if task:
                task.cancel()

    async def _send(self, payload: Dict[str, Any]) -> None:  # pragma: no cover
        """Serialize *payload* as JSON and send it.

        Raises:
            RuntimeError: If the session is not connected.
        """
        if not self._ws:
            raise RuntimeError("WebSocket not connected")
        await self._ws.send(json.dumps(payload))

    async def _recv_loop(self) -> None:  # pragma: no cover
        """Receive server events and fan them out to the stream queues.

        Messages that fail to decode or process are skipped. When the loop
        ends for any reason, a ``None`` sentinel is queued on both streams.
        """
        assert self._ws is not None
        try:
            async for raw in self._ws:
                try:
                    data = json.loads(raw)
                    etype = data.get("type")
                    # Temporarily log at INFO to diagnose missing events
                    logger.info("Transcription WS recv: %s", etype)
                    is_named = isinstance(etype, str)
                    if etype == "input_audio_buffer.committed":
                        self._last_input_item_id = data.get("item_id") or data.get("id")
                        if self._last_input_item_id:
                            logger.info(
                                "Transcription WS: input_audio_buffer committed: item_id=%s",
                                self._last_input_item_id,
                            )
                    # Suffix match covers the conversation.item.*, response.*
                    # and un-prefixed spellings of the event name.
                    elif is_named and etype.endswith("input_audio_transcription.delta"):
                        delta = data.get("delta") or ""
                        if delta:
                            self._in_tr_queue.put_nowait(delta)
                            logger.info("Transcription delta: %r", delta[:120])
                    elif is_named and etype.endswith(
                        "input_audio_transcription.completed"
                    ):
                        tr = data.get("transcript") or ""
                        if tr:
                            self._in_tr_queue.put_nowait(tr)
                            logger.debug("Transcription completed: %r", tr[:120])
                    # Always publish raw events
                    try:
                        self._event_queue.put_nowait(data)
                    except Exception:
                        pass
                except Exception:
                    continue
        except Exception:
            logger.exception("Transcription WS receive loop error")
        finally:
            # Wake up any consumers so their iterators terminate.
            for q in (self._in_tr_queue, self._event_queue):
                try:
                    q.put_nowait(None)  # type: ignore
                except Exception:
                    pass

    # --- Client events ---
    async def update_session(
        self, session_patch: Dict[str, Any]
    ) -> None:  # pragma: no cover
        """Send a ``transcription_session.update`` carrying *session_patch*.

        The patch keys are merged at the top level of the event; a ``type``
        key in the patch overrides the event type.
        """
        # Allow updating transcription session fields
        patch = {"type": "transcription_session.update", **session_patch}
        await self._send(patch)

    async def append_audio(self, pcm16_bytes: bytes) -> None:  # pragma: no cover
        """Base64-encode *pcm16_bytes* and append it to the input audio buffer."""
        b64 = base64.b64encode(pcm16_bytes).decode("ascii")
        await self._send({"type": "input_audio_buffer.append", "audio": b64})
        logger.info("Transcription WS: appended bytes=%d", len(pcm16_bytes))

    async def commit_input(self) -> None:  # pragma: no cover
        """Send ``input_audio_buffer.commit``."""
        await self._send({"type": "input_audio_buffer.commit"})
        logger.info("Transcription WS: input_audio_buffer.commit sent")

    async def clear_input(self) -> None:  # pragma: no cover
        """Send ``input_audio_buffer.clear``."""
        await self._send({"type": "input_audio_buffer.clear"})

    async def create_response(
        self, response_patch: Optional[Dict[str, Any]] = None
    ) -> None:  # pragma: no cover
        """Do nothing: a transcription session produces no responses."""
        # No responses in transcription session
        return

    # --- Streams ---
    async def _iter_queue(self, q) -> AsyncGenerator[Any, None]:
        """Yield items from *q* until a ``None`` sentinel is dequeued."""
        while True:
            item = await q.get()
            if item is None:
                break
            yield item

    @staticmethod
    async def _empty_stream() -> AsyncGenerator[Any, None]:
        """Async generator that finishes immediately without yielding."""
        return
        yield  # pragma: no cover - unreachable; makes this an async generator

    def iter_events(self) -> AsyncGenerator[Dict[str, Any], None]:  # pragma: no cover
        """Return an async iterator over raw server events."""
        return self._iter_queue(self._event_queue)

    def iter_output_audio(self) -> AsyncGenerator[bytes, None]:  # pragma: no cover
        """Return an empty async iterator."""
        # No audio in transcription session
        return self._empty_stream()

    def iter_input_transcript(self) -> AsyncGenerator[str, None]:  # pragma: no cover
        """Return an async iterator over transcript deltas and completed transcripts."""
        return self._iter_queue(self._in_tr_queue)

    def iter_output_transcript(self) -> AsyncGenerator[str, None]:  # pragma: no cover
        """Return an empty async iterator."""
        # No assistant transcript in transcription-only mode
        return self._empty_stream()

    def set_tool_executor(self, executor):  # pragma: no cover
        """Accept and ignore *executor*; tools do not apply to transcription."""
        # Not applicable for transcription-only
        return