solana-agent 20.1.2 (py3-none-any.whl) → 31.4.0 (py3-none-any.whl)

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (45)
  1. solana_agent/__init__.py +10 -5
  2. solana_agent/adapters/ffmpeg_transcoder.py +375 -0
  3. solana_agent/adapters/mongodb_adapter.py +15 -2
  4. solana_agent/adapters/openai_adapter.py +679 -0
  5. solana_agent/adapters/openai_realtime_ws.py +1813 -0
  6. solana_agent/adapters/pinecone_adapter.py +543 -0
  7. solana_agent/cli.py +128 -0
  8. solana_agent/client/solana_agent.py +180 -20
  9. solana_agent/domains/agent.py +13 -13
  10. solana_agent/domains/routing.py +18 -8
  11. solana_agent/factories/agent_factory.py +239 -38
  12. solana_agent/guardrails/pii.py +107 -0
  13. solana_agent/interfaces/client/client.py +95 -12
  14. solana_agent/interfaces/guardrails/guardrails.py +26 -0
  15. solana_agent/interfaces/plugins/plugins.py +2 -1
  16. solana_agent/interfaces/providers/__init__.py +0 -0
  17. solana_agent/interfaces/providers/audio.py +40 -0
  18. solana_agent/interfaces/providers/data_storage.py +9 -2
  19. solana_agent/interfaces/providers/llm.py +86 -9
  20. solana_agent/interfaces/providers/memory.py +13 -1
  21. solana_agent/interfaces/providers/realtime.py +212 -0
  22. solana_agent/interfaces/providers/vector_storage.py +53 -0
  23. solana_agent/interfaces/services/agent.py +27 -12
  24. solana_agent/interfaces/services/knowledge_base.py +59 -0
  25. solana_agent/interfaces/services/query.py +41 -8
  26. solana_agent/interfaces/services/routing.py +0 -1
  27. solana_agent/plugins/manager.py +37 -16
  28. solana_agent/plugins/registry.py +34 -19
  29. solana_agent/plugins/tools/__init__.py +0 -5
  30. solana_agent/plugins/tools/auto_tool.py +1 -0
  31. solana_agent/repositories/memory.py +332 -111
  32. solana_agent/services/__init__.py +1 -1
  33. solana_agent/services/agent.py +390 -241
  34. solana_agent/services/knowledge_base.py +768 -0
  35. solana_agent/services/query.py +1858 -153
  36. solana_agent/services/realtime.py +626 -0
  37. solana_agent/services/routing.py +104 -51
  38. solana_agent-31.4.0.dist-info/METADATA +1070 -0
  39. solana_agent-31.4.0.dist-info/RECORD +49 -0
  40. {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info}/WHEEL +1 -1
  41. solana_agent-31.4.0.dist-info/entry_points.txt +3 -0
  42. solana_agent/adapters/llm_adapter.py +0 -160
  43. solana_agent-20.1.2.dist-info/METADATA +0 -464
  44. solana_agent-20.1.2.dist-info/RECORD +0 -35
  45. {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info/licenses}/LICENSE +0 -0
solana_agent/adapters/openai_realtime_ws.py (new file)
@@ -0,0 +1,1813 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import logging
6
+ from typing import Any, AsyncGenerator, Dict, Optional
7
+ import asyncio
8
+
9
+ import websockets
10
+
11
+ from solana_agent.interfaces.providers.realtime import (
12
+ BaseRealtimeSession,
13
+ RealtimeSessionOptions,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class OpenAIRealtimeWebSocketSession(BaseRealtimeSession):
20
+ """OpenAI Realtime WebSocket session (server-to-server) with audio in/out.
21
+
22
+ Notes:
23
+ - Expects 24kHz audio.
24
+ - Emits raw PCM16 bytes for output audio deltas.
25
+ - Provides separate async generators for input/output transcripts.
26
+ - You can toggle VAD via session.update and manually commit when disabled.
27
+ """
28
+
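As a complement to the notes above, a minimal caller sketch; it assumes an asyncio context, an OPENAI_API_KEY environment variable, and default options with server VAD disabled (the silence buffer below is an illustrative stand-in for microphone audio, not part of this package):

import asyncio
import os

from solana_agent.adapters.openai_realtime_ws import OpenAIRealtimeWebSocketSession
from solana_agent.interfaces.providers.realtime import RealtimeSessionOptions

async def main() -> None:
    session = OpenAIRealtimeWebSocketSession(
        api_key=os.environ["OPENAI_API_KEY"],
        options=RealtimeSessionOptions(),      # assumed defaults: 24kHz PCM16 in/out
    )
    await session.connect()                    # performs the initial session.update

    pcm16 = b"\x00\x00" * 12000                # 0.5 s of 24kHz mono PCM16 silence (stand-in for real audio)
    await session.append_audio(pcm16)
    await session.commit_input()               # manual commit path (server VAD disabled)
    await session.create_response()

    audio = bytearray()
    async for chunk in session.iter_output_audio():   # raw PCM16 deltas until the end-of-audio sentinel
        audio.extend(chunk)
    await session.close()

asyncio.run(main())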
29
+ def __init__(
30
+ self,
31
+ api_key: str,
32
+ url: str = "wss://api.openai.com/v1/realtime",
33
+ options: Optional[RealtimeSessionOptions] = None,
34
+ ) -> None:
35
+ self.api_key = api_key
36
+ self.url = url
37
+ self.options = options or RealtimeSessionOptions()
38
+
39
+ # Queues and state
40
+ self._ws = None
41
+ self._event_queue = asyncio.Queue()
42
+ self._audio_queue = asyncio.Queue()
43
+ self._in_tr_queue = asyncio.Queue()
44
+ self._out_tr_queue = asyncio.Queue()
45
+
46
+ # Tool/function state
47
+ self._recv_task = None
48
+ self._tool_executor = None
49
+ self._pending_calls = {}
50
+ self._active_tool_calls = 0
51
+ # Map call_id or item_id -> asyncio.Event when server has created the function_call item
52
+ self._call_ready_events = {}
53
+ # Track identifiers already executed to avoid duplicates
54
+ self._executed_call_ids = set()
55
+ # Track mapping from tool call identifiers to the originating response.id
56
+ # Prefer addressing function_call_output to a response to avoid conversation lookup races
57
+ self._call_response_ids = {}
58
+ # Accumulate function call arguments when streamed
59
+ self._call_args_accum = {}
60
+ self._call_names = {}
61
+ # Map from function_call item_id -> call_id for reliable output addressing
62
+ self._item_call_ids = {}
63
+
64
+ # Input/response state
65
+ self._pending_input_bytes = 0
66
+ self._commit_evt = asyncio.Event()
67
+ self._commit_inflight = False
68
+ self._response_active = False
69
+ self._server_auto_create_enabled = False
70
+ self._last_commit_ts = 0.0
71
+
72
+ # Session/event tracking
73
+ self._session_created_evt = asyncio.Event()
74
+ self._session_updated_evt = asyncio.Event()
75
+ self._awaiting_session_updated = False
76
+ self._last_session_patch = {}
77
+ self._last_session_updated_payload = {}
78
+ self._last_input_item_id = None
79
+ # Response generation tracking to bind tool outputs to the active response
80
+ self._response_generation = 0
81
+ # Track the currently active response.id (fallback when some events omit it)
82
+ self._active_response_id = None
83
+ # Accumulate assistant output text by response.id and flush on response completion
84
+ self._out_text_buffers = {}
85
+
86
+ # Outbound event correlation
87
+ self._event_seq = 0
88
+ self._sent_events = {}
89
+
90
+ async def connect(self) -> None: # pragma: no cover
91
+ # Defensive: ensure session events exist even if __init__ didn’t set them (older builds)
92
+ if not hasattr(self, "_session_created_evt") or not isinstance(
93
+ getattr(self, "_session_created_evt", None), asyncio.Event
94
+ ):
95
+ self._session_created_evt = asyncio.Event()
96
+ if not hasattr(self, "_session_updated_evt") or not isinstance(
97
+ getattr(self, "_session_updated_evt", None), asyncio.Event
98
+ ):
99
+ self._session_updated_evt = asyncio.Event()
100
+ headers = [
101
+ ("Authorization", f"Bearer {self.api_key}"),
102
+ ]
103
+ model = self.options.model or "gpt-realtime"
104
+ uri = f"{self.url}?model={model}"
105
+
106
+ # Determine if audio output should be configured for logging
107
+ modalities = self.options.output_modalities or ["audio", "text"]
108
+ should_configure_audio_output = "audio" in modalities
109
+
110
+ if should_configure_audio_output:
111
+ logger.info(
112
+ "Realtime WS connecting: uri=%s, input=%s@%sHz, output=%s@%sHz, voice=%s, vad=%s",
113
+ uri,
114
+ self.options.input_mime,
115
+ self.options.input_rate_hz,
116
+ self.options.output_mime,
117
+ self.options.output_rate_hz,
118
+ self.options.voice,
119
+ self.options.vad_enabled,
120
+ )
121
+ else:
122
+ logger.info(
123
+ "Realtime WS connecting: uri=%s, input=%s@%sHz, text-only output, vad=%s",
124
+ uri,
125
+ self.options.input_mime,
126
+ self.options.input_rate_hz,
127
+ self.options.vad_enabled,
128
+ )
129
+ self._ws = await websockets.connect(
130
+ uri, additional_headers=headers, max_size=None
131
+ )
132
+ logger.info("Connected to OpenAI Realtime WS: %s", uri)
133
+ self._recv_task = asyncio.create_task(self._recv_loop())
134
+
135
+ # Optionally wait briefly for session.created; some servers ignore pre-created updates
136
+ try:
137
+ await asyncio.wait_for(self._session_created_evt.wait(), timeout=1.0)
138
+ logger.info(
139
+ "Realtime WS: session.created observed before first session.update"
140
+ )
141
+ except asyncio.TimeoutError:
142
+ logger.info(
143
+ "Realtime WS: no session.created within 1.0s; sending session.update anyway"
144
+ )
145
+
146
+ # Build optional prompt block
147
+ prompt_block = None
148
+ if getattr(self.options, "prompt_id", None):
149
+ prompt_block = {
150
+ "id": self.options.prompt_id,
151
+ }
152
+ if getattr(self.options, "prompt_version", None):
153
+ prompt_block["version"] = self.options.prompt_version
154
+ if getattr(self.options, "prompt_variables", None):
155
+ prompt_block["variables"] = self.options.prompt_variables
156
+
157
+ # Configure session (instructions, tools). VAD handled per-request.
158
+ # Per server schema, turn_detection belongs under audio.input; set to None to disable.
159
+ # When VAD is disabled, explicitly set create_response to False if the server honors it
160
+ td_input = (
161
+ {"type": "server_vad", "create_response": True}
162
+ if self.options.vad_enabled
163
+ else {"type": "server_vad", "create_response": False}
164
+ )
165
+
166
+ # Build session.update per docs (nested audio object)
167
+ def _strip_tool_strict(tools_val):
168
+ try:
169
+ tools_list = list(tools_val or [])
170
+ except Exception:
171
+ return tools_val
172
+ cleaned = []
173
+ for t in tools_list:
174
+ try:
175
+ t2 = dict(t)
176
+ t2.pop("strict", None)
177
+ cleaned.append(t2)
178
+ except Exception:
179
+ cleaned.append(t)
180
+ return cleaned
181
+
182
+ # Determine if audio output should be configured
183
+ modalities = self.options.output_modalities or ["audio", "text"]
184
+ should_configure_audio_output = "audio" in modalities
185
+
186
+ # Build session.update per docs (nested audio object)
187
+ session_payload: Dict[str, Any] = {
188
+ "type": "session.update",
189
+ "session": {
190
+ "type": "realtime",
191
+ "output_modalities": modalities,
192
+ "audio": {
193
+ "input": {
194
+ "format": {
195
+ "type": self.options.input_mime or "audio/pcm",
196
+ "rate": int(self.options.input_rate_hz or 24000),
197
+ },
198
+ "turn_detection": td_input,
199
+ },
200
+ **(
201
+ {
202
+ "output": {
203
+ "format": {
204
+ "type": self.options.output_mime or "audio/pcm",
205
+ "rate": int(self.options.output_rate_hz or 24000),
206
+ },
207
+ "voice": self.options.voice,
208
+ "speed": float(
209
+ getattr(self.options, "voice_speed", 1.0) or 1.0
210
+ ),
211
+ }
212
+ }
213
+ if should_configure_audio_output
214
+ else {}
215
+ ),
216
+ },
217
+ # Note: no top-level turn_detection; nested under audio.input
218
+ **({"prompt": prompt_block} if prompt_block else {}),
219
+ "instructions": self.options.instructions or "",
220
+ **(
221
+ {"tools": _strip_tool_strict(self.options.tools)}
222
+ if self.options.tools
223
+ else {}
224
+ ),
225
+ **(
226
+ {"tool_choice": self.options.tool_choice}
227
+ if getattr(self.options, "tool_choice", None)
228
+ else {}
229
+ ),
230
+ },
231
+ }
232
+ # Optional realtime transcription configuration
233
+ try:
234
+ tr_model = getattr(self.options, "transcription_model", None)
235
+ if tr_model:
236
+ audio_obj = session_payload["session"].setdefault("audio", {})
237
+ # Attach input transcription config per GA schema
238
+ transcription_cfg: Dict[str, Any] = {"model": tr_model}
239
+ lang = getattr(self.options, "transcription_language", None)
240
+ if lang:
241
+ transcription_cfg["language"] = lang
242
+ prompt_txt = getattr(self.options, "transcription_prompt", None)
243
+ if prompt_txt is not None:
244
+ transcription_cfg["prompt"] = prompt_txt
245
+ if getattr(self.options, "transcription_include_logprobs", False):
246
+ session_payload["session"].setdefault("include", []).append(
247
+ "item.input_audio_transcription.logprobs"
248
+ )
249
+ nr = getattr(self.options, "transcription_noise_reduction", None)
250
+ if nr is not None:
251
+ audio_obj["noise_reduction"] = bool(nr)
252
+ # Place under audio.input.transcription per current server conventions
253
+ audio_obj.setdefault("input", {}).setdefault(
254
+ "transcription", transcription_cfg
255
+ )
256
+ except Exception:
257
+ logger.exception("Failed to attach transcription config to session.update")
258
+ if should_configure_audio_output:
259
+ logger.info(
260
+ "Realtime WS: sending session.update (voice=%s, vad=%s, output=%s@%s)",
261
+ self.options.voice,
262
+ self.options.vad_enabled,
263
+ (self.options.output_mime or "audio/pcm"),
264
+ int(self.options.output_rate_hz or 24000),
265
+ )
266
+ else:
267
+ logger.info(
268
+ "Realtime WS: sending session.update (text-only, vad=%s)",
269
+ self.options.vad_enabled,
270
+ )
271
+ # Log exact session.update payload and mark awaiting session.updated
272
+ try:
273
+ logger.info(
274
+ "Realtime WS: sending session.update payload=%s",
275
+ json.dumps(session_payload.get("session", {}), sort_keys=True),
276
+ )
277
+ except Exception:
278
+ pass
279
+ self._last_session_patch = session_payload.get("session", {})
280
+ self._session_updated_evt = asyncio.Event()
281
+ self._awaiting_session_updated = True
282
+ # Quick sanity warnings
283
+ try:
284
+ sess = self._last_session_patch
285
+ instr = sess.get("instructions")
286
+ voice = ((sess.get("audio") or {}).get("output") or {}).get("voice")
287
+ if instr is None or (isinstance(instr, str) and instr.strip() == ""):
288
+ logger.warning(
289
+ "Realtime WS: instructions missing/empty in session.update"
290
+ )
291
+ if not voice and should_configure_audio_output:
292
+ logger.warning("Realtime WS: voice missing in session.update")
293
+ except Exception:
294
+ pass
295
+ await self._send_tracked(session_payload, label="session.update:init")
296
+
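Rendered out, the initial session.update assembled above takes roughly this shape with the default audio settings; voice and instructions are placeholders for the configured options, and tools, tool_choice, prompt, and transcription blocks are included only when set:

session_update_example = {
    "type": "session.update",
    "session": {
        "type": "realtime",
        "output_modalities": ["audio", "text"],
        "audio": {
            "input": {
                "format": {"type": "audio/pcm", "rate": 24000},
                # VAD enabled: create_response True; VAD disabled: create_response False
                "turn_detection": {"type": "server_vad", "create_response": True},
            },
            "output": {
                "format": {"type": "audio/pcm", "rate": 24000},
                "voice": "marin",   # placeholder for options.voice
                "speed": 1.0,
            },
        },
        "instructions": "You are a helpful voice assistant.",  # placeholder for options.instructions
        # "tools": [...], "tool_choice": "auto",  # only when configured
    },
}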
297
+ async def close(self) -> None: # pragma: no cover
298
+ if self._ws:
299
+ await self._ws.close()
300
+ self._ws = None
301
+ if self._recv_task:
302
+ self._recv_task.cancel()
303
+ self._recv_task = None
304
+ # Unblock any pending waiters to avoid dangling tasks
305
+ try:
306
+ self._commit_evt.set()
307
+ except Exception:
308
+ pass
309
+ try:
310
+ self._session_created_evt.set()
311
+ except Exception:
312
+ pass
313
+ try:
314
+ self._session_updated_evt.set()
315
+ except Exception:
316
+ pass
317
+
318
+ async def _send(self, payload: Dict[str, Any]) -> None: # pragma: no cover
319
+ if not self._ws:
320
+ raise RuntimeError("WebSocket not connected")
321
+ try:
322
+ await self._ws.send(json.dumps(payload))
323
+ finally:
324
+ try:
325
+ ptype = payload.get("type")
326
+ peid = payload.get("event_id")
327
+ except Exception:
328
+ ptype = str(type(payload))
329
+ peid = None
330
+ logger.debug("WS send: %s (event_id=%s)", ptype, peid)
331
+
332
+ def _next_event_id(self) -> str:
333
+ try:
334
+ self._event_seq += 1
335
+ return f"evt-{self._event_seq}"
336
+ except Exception:
337
+ # Fallback to random when counters fail
338
+ import uuid as _uuid
339
+
340
+ return str(_uuid.uuid4())
341
+
342
+ async def _send_tracked(
343
+ self, payload: Dict[str, Any], label: Optional[str] = None
344
+ ) -> str:
345
+ """Attach an event_id and retain a snapshot for correlating error events."""
346
+ try:
347
+ eid = payload.get("event_id") or self._next_event_id()
348
+ payload["event_id"] = eid
349
+ self._sent_events[eid] = {
350
+ "label": label or (payload.get("type") or "client.event"),
351
+ "type": payload.get("type"),
352
+ "payload": payload.copy(),
353
+ "ts": asyncio.get_event_loop().time(),
354
+ }
355
+ except Exception:
356
+ eid = payload.get("event_id") or ""
357
+ await self._send(payload)
358
+ return eid
359
+
360
+ async def _recv_loop(self) -> None: # pragma: no cover
361
+ assert self._ws is not None
362
+ try:
363
+ async for raw in self._ws:
364
+ try:
365
+ data = json.loads(raw)
366
+ etype = data.get("type")
367
+ eid = data.get("event_id")
368
+ logger.debug("WS recv: %s (event_id=%s)", etype, eid)
369
+ # Track active response.id from any response.* event that carries it
370
+ try:
371
+ if isinstance(etype, str) and etype.startswith("response."):
372
+ _resp = data.get("response") or {}
373
+ _rid = _resp.get("id")
374
+ if _rid:
375
+ self._active_response_id = _rid
376
+ except Exception:
377
+ pass
378
+ # Demux streams
379
+ if etype in ("response.output_audio.delta", "response.audio.delta"):
380
+ b64 = data.get("delta") or ""
381
+ if b64:
382
+ try:
383
+ chunk = base64.b64decode(b64)
384
+ self._audio_queue.put_nowait(chunk)
385
+ # Ownership/response tagging for diagnostics
386
+ try:
387
+ owner = getattr(self, "_owner_user_id", None)
388
+ except Exception:
389
+ owner = None
390
+ try:
391
+ rid = getattr(self, "_active_response_id", None)
392
+ except Exception:
393
+ rid = None
394
+ try:
395
+ gen = int(getattr(self, "_response_generation", 0))
396
+ except Exception:
397
+ gen = None
398
+ logger.info(
399
+ "Audio delta bytes=%d owner=%s rid=%s gen=%s",
400
+ len(chunk),
401
+ owner,
402
+ rid,
403
+ gen,
404
+ )
405
+ try:
406
+ # New response detected if we were previously inactive
407
+ if not getattr(self, "_response_active", False):
408
+ self._response_generation = (
409
+ int(
410
+ getattr(self, "_response_generation", 0)
411
+ )
412
+ + 1
413
+ )
414
+ self._response_active = True
415
+ except Exception:
416
+ pass
417
+ except Exception:
418
+ pass
419
+ elif etype == "response.text.delta":
420
+ # Some servers emit generic text deltas with metadata marking transcription
421
+ metadata = data.get("response", {}).get("metadata", {})
422
+ if metadata.get("type") == "transcription":
423
+ delta = data.get("delta") or ""
424
+ if delta:
425
+ self._in_tr_queue.put_nowait(delta)
426
+ logger.info("Input transcript delta: %r", delta[:120])
427
+ else:
428
+ logger.debug("Input transcript delta: empty")
429
+ elif etype in (
430
+ "response.function_call_arguments.delta",
431
+ "response.function_call_arguments.done",
432
+ ):
433
+ # Capture streamed function-call arguments early and mark readiness on .done
434
+ try:
435
+ resp = data.get("response") or {}
436
+ rid = resp.get("id") or getattr(
437
+ self, "_active_response_id", None
438
+ )
439
+ # Many servers include call_id at top-level for these events
440
+ call_id = data.get("call_id") or data.get("id")
441
+ # Some servers include the function call item under 'item'
442
+ if not call_id:
443
+ call_id = (data.get("item") or {}).get("call_id") or (
444
+ data.get("item") or {}
445
+ ).get("id")
446
+ # Name can appear directly or under item
447
+ name = data.get("name") or (data.get("item") or {}).get(
448
+ "name"
449
+ )
450
+ if name:
451
+ self._call_names[call_id] = name
452
+ if rid and call_id:
453
+ self._call_response_ids[call_id] = rid
454
+ if etype.endswith("delta"):
455
+ delta = data.get("delta") or ""
456
+ if call_id and delta:
457
+ self._call_args_accum[call_id] = (
458
+ self._call_args_accum.get(call_id, "") + delta
459
+ )
460
+ else: # .done
461
+ if call_id:
462
+ # Mark call ready and enqueue pending execution if not already
463
+ ev = self._call_ready_events.get(call_id)
464
+ if not ev:
465
+ ev = asyncio.Event()
466
+ self._call_ready_events[call_id] = ev
467
+ ev.set()
468
+ # Register pending call using accumulated args
469
+ if call_id not in getattr(
470
+ self, "_pending_calls", {}
471
+ ):
472
+ args_text = (
473
+ self._call_args_accum.get(call_id, "{}")
474
+ or "{}"
475
+ )
476
+ if not hasattr(self, "_pending_calls"):
477
+ self._pending_calls = {}
478
+ self._pending_calls[call_id] = {
479
+ "name": self._call_names.get(call_id),
480
+ "args": args_text,
481
+ "gen": int(
482
+ getattr(self, "_response_generation", 0)
483
+ ),
484
+ "call_id": call_id,
485
+ "item_id": None,
486
+ }
487
+ self._executed_call_ids.add(call_id)
488
+ await self._execute_pending_call(call_id)
489
+ except Exception:
490
+ pass
491
+ elif etype == "response.output_text.delta":
492
+ # Assistant textual output delta. Buffer per response.id.
493
+ # Prefer the audio transcript stream for final transcript; only use text
494
+ # deltas if no audio transcript arrives.
495
+ try:
496
+ rid = (data.get("response") or {}).get("id") or getattr(
497
+ self, "_active_response_id", None
498
+ )
499
+ delta = data.get("delta") or ""
500
+ if rid and delta:
501
+ buf = self._out_text_buffers.setdefault(
502
+ rid, {"text": "", "has_audio": False}
503
+ )
504
+ # Only accumulate text when we don't yet have audio transcript
505
+ if not bool(buf.get("has_audio")):
506
+ buf["text"] = str(buf.get("text", "")) + delta
507
+ logger.debug(
508
+ "Buffered assistant text delta (rid=%s, len=%d)",
509
+ rid,
510
+ len(delta),
511
+ )
512
+ except Exception:
513
+ pass
514
+ elif etype == "conversation.item.input_audio_transcription.delta":
515
+ delta = data.get("delta") or ""
516
+ if delta:
517
+ self._in_tr_queue.put_nowait(delta)
518
+ logger.info("Input transcript delta (GA): %r", delta[:120])
519
+ else:
520
+ logger.debug("Input transcript delta (GA): empty")
521
+ elif etype in (
522
+ "response.output_audio_transcript.delta",
523
+ "response.audio_transcript.delta",
524
+ ):
525
+ # Assistant audio transcript delta (authoritative for spoken output)
526
+ try:
527
+ rid = (data.get("response") or {}).get("id") or getattr(
528
+ self, "_active_response_id", None
529
+ )
530
+ delta = data.get("delta") or ""
531
+ if rid and delta:
532
+ buf = self._out_text_buffers.setdefault(
533
+ rid, {"text": "", "has_audio": False}
534
+ )
535
+ buf["has_audio"] = True
536
+ buf["text"] = str(buf.get("text", "")) + delta
537
+ logger.debug(
538
+ "Buffered audio transcript delta (rid=%s, len=%d)",
539
+ rid,
540
+ len(delta),
541
+ )
542
+ except Exception:
543
+ pass
544
+ elif etype == "input_audio_buffer.committed":
545
+ # Track the committed audio item id for OOB transcription referencing
546
+ item_id = data.get("item_id") or data.get("id")
547
+ if item_id:
548
+ self._last_input_item_id = item_id
549
+ try:
550
+ self._last_commit_ts = asyncio.get_event_loop().time()
551
+ except Exception:
552
+ pass
553
+ logger.info(
554
+ "Realtime WS: input_audio_buffer committed: item_id=%s",
555
+ item_id,
556
+ )
557
+ # Clear commit in-flight flag on ack and signal waiters
558
+ try:
559
+ self._commit_inflight = False
560
+ except Exception:
561
+ pass
562
+ try:
563
+ self._commit_evt.set()
564
+ except Exception:
565
+ pass
566
+ elif etype in (
567
+ "response.output_audio.done",
568
+ "response.audio.done",
569
+ ):
570
+ # End of audio stream for the response; stop audio iterator but keep WS open for transcripts
571
+ try:
572
+ owner = getattr(self, "_owner_user_id", None)
573
+ except Exception:
574
+ owner = None
575
+ try:
576
+ rid = (data.get("response") or {}).get("id") or getattr(
577
+ self, "_active_response_id", None
578
+ )
579
+ except Exception:
580
+ rid = None
581
+ try:
582
+ gen = int(getattr(self, "_response_generation", 0))
583
+ except Exception:
584
+ gen = None
585
+ logger.info(
586
+ "Realtime WS: output audio done; owner=%s rid=%s gen=%s",
587
+ owner,
588
+ rid,
589
+ gen,
590
+ )
591
+ # If we have a buffered transcript for this response, flush it now
592
+ try:
593
+ rid = (data.get("response") or {}).get("id") or getattr(
594
+ self, "_active_response_id", None
595
+ )
596
+ if rid and rid in self._out_text_buffers:
597
+ final = str(
598
+ self._out_text_buffers.get(rid, {}).get("text")
599
+ or ""
600
+ )
601
+ if final:
602
+ self._out_tr_queue.put_nowait(final)
603
+ logger.debug(
604
+ "Flushed assistant transcript on audio.done (rid=%s, len=%d)",
605
+ rid,
606
+ len(final),
607
+ )
608
+ self._out_text_buffers.pop(rid, None)
609
+ except Exception:
610
+ pass
611
+ try:
612
+ self._audio_queue.put_nowait(None)
613
+ except Exception:
614
+ pass
615
+ try:
616
+ self._response_active = False
617
+ except Exception:
618
+ pass
619
+ try:
620
+ # Clear active response id when audio for that response is done
621
+ self._active_response_id = None
622
+ except Exception:
623
+ pass
624
+ # Don't break; we still want to receive input transcription events
625
+ elif etype in (
626
+ "response.completed",
627
+ "response.complete",
628
+ "response.done",
629
+ ):
630
+ # Do not terminate audio stream here; completion may be for a function_call
631
+ metadata = data.get("response", {}).get("metadata", {})
632
+ if metadata.get("type") == "transcription":
633
+ logger.info("Realtime WS: transcription response completed")
634
+ # Flush buffered assistant transcript, if any
635
+ try:
636
+ rid = (data.get("response") or {}).get("id")
637
+ if rid and rid in self._out_text_buffers:
638
+ final = str(
639
+ self._out_text_buffers.get(rid, {}).get("text")
640
+ or ""
641
+ )
642
+ if final:
643
+ self._out_tr_queue.put_nowait(final)
644
+ logger.debug(
645
+ "Flushed assistant transcript on response.done (rid=%s, len=%d)",
646
+ rid,
647
+ len(final),
648
+ )
649
+ self._out_text_buffers.pop(rid, None)
650
+ except Exception:
651
+ pass
652
+ try:
653
+ self._response_active = False
654
+ except Exception:
655
+ pass
656
+ try:
657
+ # Response lifecycle ended; clear active id
658
+ self._active_response_id = None
659
+ except Exception:
660
+ pass
661
+ elif etype in ("response.text.done", "response.output_text.done"):
662
+ metadata = data.get("response", {}).get("metadata", {})
663
+ if metadata.get("type") == "transcription":
664
+ tr = data.get("text") or ""
665
+ if tr:
666
+ try:
667
+ self._in_tr_queue.put_nowait(tr)
668
+ logger.debug(
669
+ "Input transcript completed: %r", tr[:120]
670
+ )
671
+ except Exception:
672
+ pass
673
+ else:
674
+ # For assistant text-only completions without audio, flush buffered text
675
+ try:
676
+ rid = (data.get("response") or {}).get("id") or getattr(
677
+ self, "_active_response_id", None
678
+ )
679
+ if rid and rid in self._out_text_buffers:
680
+ final = str(
681
+ self._out_text_buffers.get(rid, {}).get("text")
682
+ or ""
683
+ )
684
+ if final:
685
+ self._out_tr_queue.put_nowait(final)
686
+ logger.debug(
687
+ "Flushed assistant transcript on text.done (rid=%s, len=%d)",
688
+ rid,
689
+ len(final),
690
+ )
691
+ self._out_text_buffers.pop(rid, None)
692
+ # Always terminate the output transcript stream for this response when text-only.
693
+ try:
694
+ # Only enqueue sentinel when no audio modality is configured
695
+ modalities = (
696
+ getattr(self.options, "output_modalities", None)
697
+ or []
698
+ )
699
+ if "audio" not in modalities:
700
+ self._out_tr_queue.put_nowait(None)
701
+ logger.debug(
702
+ "Enqueued transcript termination sentinel (text-only response)"
703
+ )
704
+ except Exception:
705
+ pass
706
+ except Exception:
707
+ pass
708
+ elif (
709
+ etype == "conversation.item.input_audio_transcription.completed"
710
+ ):
711
+ tr = data.get("transcript") or ""
712
+ if tr:
713
+ try:
714
+ self._in_tr_queue.put_nowait(tr)
715
+ logger.debug(
716
+ "Input transcript completed (GA): %r", tr[:120]
717
+ )
718
+ except Exception:
719
+ pass
720
+ elif etype == "session.updated":
721
+ sess = data.get("session", {})
722
+ self._last_session_updated_payload = sess
723
+ # Track server auto-create enablement (turn_detection.create_response)
724
+ try:
725
+ td_root = sess.get("turn_detection")
726
+ td_input = (
727
+ (sess.get("audio") or {})
728
+ .get("input", {})
729
+ .get("turn_detection")
730
+ )
731
+ td = td_root if td_root is not None else td_input
732
+ # If local VAD is disabled, force auto-create disabled
733
+ if not bool(getattr(self.options, "vad_enabled", False)):
734
+ self._server_auto_create_enabled = False
735
+ else:
736
+ self._server_auto_create_enabled = bool(
737
+ isinstance(td, dict)
738
+ and bool(td.get("create_response"))
739
+ )
740
+ logger.debug(
741
+ "Realtime WS: server_auto_create_enabled=%s",
742
+ self._server_auto_create_enabled,
743
+ )
744
+ except Exception:
745
+ pass
746
+ # Mark that the latest update has been applied
747
+ try:
748
+ self._awaiting_session_updated = False
749
+ self._session_updated_evt.set()
750
+ except Exception:
751
+ pass
752
+ try:
753
+ logger.info(
754
+ "Realtime WS: session.updated payload=%s",
755
+ json.dumps(sess, sort_keys=True),
756
+ )
757
+ except Exception:
758
+ logger.info(
759
+ "Realtime WS: session updated (payload dump failed)"
760
+ )
761
+ # Extra guardrails: warn if key fields absent/empty
762
+ try:
763
+ instr = sess.get("instructions")
764
+ voice = sess.get("voice") or (
765
+ (sess.get("audio") or {}).get("output", {}).get("voice")
766
+ )
767
+ if instr is None or (
768
+ isinstance(instr, str) and instr.strip() == ""
769
+ ):
770
+ logger.warning(
771
+ "Realtime WS: session.updated has empty/missing instructions"
772
+ )
773
+ if not voice:
774
+ logger.warning(
775
+ "Realtime WS: session.updated missing voice"
776
+ )
777
+ except Exception:
778
+ pass
779
+ elif etype == "session.created":
780
+ logger.info("Realtime WS: session created")
781
+ try:
782
+ self._session_created_evt.set()
783
+ except Exception:
784
+ pass
785
+ elif etype == "error" or (
786
+ isinstance(etype, str)
787
+ and (
788
+ etype.endswith(".error")
789
+ or ("error" in etype)
790
+ or ("failed" in etype)
791
+ )
792
+ ):
793
+ # Surface server errors explicitly
794
+ try:
795
+ logger.error(
796
+ "Realtime WS error event: %s",
797
+ json.dumps(data, sort_keys=True),
798
+ )
799
+ except Exception:
800
+ logger.error(
801
+ "Realtime WS error event (payload dump failed)"
802
+ )
803
+ # Correlate to the originating client event by event_id
804
+ try:
805
+ eid = data.get("event_id") or data.get("error", {}).get(
806
+ "event_id"
807
+ )
808
+ if eid and eid in self._sent_events:
809
+ sent = self._sent_events.get(eid) or {}
810
+ logger.error(
811
+ "Realtime WS error correlated: event_id=%s sent_label=%s sent_type=%s",
812
+ eid,
813
+ sent.get("label"),
814
+ sent.get("type"),
815
+ )
816
+ except Exception:
817
+ pass
818
+ # No legacy fallback; rely on current config/state.
819
+ # Always also publish raw events
820
+ try:
821
+ self._event_queue.put_nowait(data)
822
+ except Exception:
823
+ pass
824
+ # Handle tool/function calls (GA-only triggers)
825
+ if etype in (
826
+ "conversation.item.created",
827
+ "conversation.item.added",
828
+ ):
829
+ try:
830
+ item = data.get("item") or {}
831
+ if item.get("type") == "function_call":
832
+ call_id = item.get("call_id")
833
+ item_id = item.get("id")
834
+ # Mark readiness for both identifiers if available
835
+ for cid in filter(None, [call_id, item_id]):
836
+ ev = self._call_ready_events.get(cid)
837
+ if not ev:
838
+ ev = asyncio.Event()
839
+ self._call_ready_events[cid] = ev
840
+ ev.set()
841
+ logger.debug(
842
+ "Call ready via item.%s: call_id=%s item_id=%s",
843
+ "created" if etype.endswith("created") else "added",
844
+ call_id,
845
+ item_id,
846
+ )
847
+ except Exception:
848
+ pass
849
+ elif etype in (
850
+ "response.output_item.added",
851
+ "response.output_item.created",
852
+ ):
853
+ # Map response-scoped function_call items to response_id and ack function_call_output
854
+ try:
855
+ resp = data.get("response") or {}
856
+ rid = resp.get("id") or getattr(
857
+ self, "_active_response_id", None
858
+ )
859
+ item = data.get("item") or {}
860
+ itype = item.get("type")
861
+ if itype == "function_call":
862
+ call_id = item.get("call_id")
863
+ item_id = item.get("id")
864
+ name = item.get("name")
865
+ if name and call_id:
866
+ self._call_names[call_id] = name
867
+ if call_id and item_id:
868
+ self._item_call_ids[item_id] = call_id
869
+ # Bind call identifiers to the response id for response-scoped outputs
870
+ for cid in filter(None, [call_id, item_id]):
871
+ if rid:
872
+ self._call_response_ids[cid] = rid
873
+ # Signal readiness for these identifiers
874
+ for cid in filter(None, [call_id, item_id]):
875
+ ev = self._call_ready_events.get(cid)
876
+ if not ev:
877
+ ev = asyncio.Event()
878
+ self._call_ready_events[cid] = ev
879
+ ev.set()
880
+ logger.debug(
881
+ "Mapped function_call via response.%s: call_id=%s item_id=%s response_id=%s",
882
+ "created" if etype.endswith("created") else "added",
883
+ call_id,
884
+ item_id,
885
+ rid,
886
+ )
887
+ except Exception:
888
+ pass
889
+ elif etype in (
890
+ "response.done",
891
+ "response.completed",
892
+ "response.complete",
893
+ ):
894
+ # Also detect GA function_call items in response.output
895
+ try:
896
+ resp = data.get("response", {})
897
+ rid = resp.get("id")
898
+ out_items = resp.get("output", [])
899
+ for item in out_items:
900
+ if item.get("type") == "function_call":
901
+ call_id = item.get("call_id")
902
+ item_id = item.get("id")
903
+ # Prefer a canonical id to reference later; try item_id first, then call_id
904
+ canonical_id = item_id or call_id
905
+ name = item.get("name")
906
+ args = item.get("arguments") or "{}"
907
+ status = item.get("status")
908
+ if call_id and item_id:
909
+ self._item_call_ids[item_id] = call_id
910
+ # Map this call to the response id for response-scoped outputs
911
+ for cid in filter(
912
+ None, [call_id, item_id, canonical_id]
913
+ ):
914
+ if rid:
915
+ self._call_response_ids[cid] = rid
916
+ if (
917
+ canonical_id in self._executed_call_ids
918
+ or (
919
+ call_id
920
+ and call_id in self._executed_call_ids
921
+ )
922
+ or (
923
+ item_id
924
+ and item_id in self._executed_call_ids
925
+ )
926
+ ):
927
+ continue
928
+ if status and str(status).lower() not in (
929
+ "completed",
930
+ "complete",
931
+ "done",
932
+ ):
933
+ # Wait for completion in a later event
934
+ continue
935
+ if not hasattr(self, "_pending_calls"):
936
+ self._pending_calls = {}
937
+ self._pending_calls[canonical_id] = {
938
+ "name": name,
939
+ "args": args,
940
+ # Bind this call to the current response generation
941
+ "gen": int(
942
+ getattr(self, "_response_generation", 0)
943
+ ),
944
+ # Keep both identifiers for fallback when posting outputs
945
+ "call_id": call_id,
946
+ "item_id": item_id,
947
+ }
948
+ # Mark this call as ready (server has produced the item)
949
+ for cid in filter(
950
+ None, [canonical_id, call_id, item_id]
951
+ ):
952
+ ev = self._call_ready_events.get(cid)
953
+ if not ev:
954
+ ev = asyncio.Event()
955
+ self._call_ready_events[cid] = ev
956
+ ev.set()
957
+ logger.debug(
958
+ "Call ready via response.done: call_id=%s item_id=%s canonical_id=%s",
959
+ call_id,
960
+ item_id,
961
+ canonical_id,
962
+ )
963
+ for cid in filter(
964
+ None, [canonical_id, call_id, item_id]
965
+ ):
966
+ self._executed_call_ids.add(cid)
967
+ await self._execute_pending_call(canonical_id)
968
+ except Exception:
969
+ pass
970
+
971
+ # No ack/error event mapping; server emits concrete lifecycle events
972
+ except Exception:
973
+ continue
974
+ except Exception:
975
+ logger.exception("Realtime WS receive loop error")
976
+ finally:
977
+ # Close queues gracefully
978
+ for q in (
979
+ self._audio_queue,
980
+ self._in_tr_queue,
981
+ self._out_tr_queue,
982
+ self._event_queue,
983
+ ):
984
+ try:
985
+ q.put_nowait(None) # type: ignore
986
+ except Exception:
987
+ pass
988
+
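Every event handled above is also republished on the raw event queue, so a caller can observe server traffic without touching the typed streams; a small monitoring sketch, assuming a connected session object:

import json

async def log_events(session) -> None:
    # iter_events() yields parsed server events until the None sentinel on close.
    async for event in session.iter_events():
        etype = str(event.get("type", ""))
        if "error" in etype or "failed" in etype:
            print("server error:", json.dumps(event, sort_keys=True))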
989
+ # --- Client event helpers ---
990
+ async def update_session(
991
+ self, session_patch: Dict[str, Any]
992
+ ) -> None: # pragma: no cover
993
+ # Build nested session.update per docs. Only include provided fields.
994
+ raw = dict(session_patch or {})
995
+ patch: Dict[str, Any] = {}
996
+ audio_patch: Dict[str, Any] = {}
997
+
998
+ try:
999
+ audio = dict(raw.get("audio") or {})
1000
+ # Normalize turn_detection to audio.input per server schema
1001
+ include_td = False
1002
+ turn_det = None
1003
+ inp = dict(audio.get("input") or {})
1004
+ if "turn_detection" in inp:
1005
+ include_td = True
1006
+ turn_det = inp.get("turn_detection")
1007
+ if "turn_detection" in audio:
1008
+ include_td = True
1009
+ turn_det = audio.get("turn_detection")
1010
+ if "turn_detection" in raw:
1011
+ include_td = True
1012
+ turn_det = raw.get("turn_detection")
1013
+
1014
+ inp = dict(audio.get("input") or {})
1015
+ out = dict(audio.get("output") or {})
1016
+
1017
+ # Input format
1018
+ fmt_in = inp.get("format", raw.get("input_audio_format"))
1019
+ if fmt_in is not None:
1020
+ if isinstance(fmt_in, dict):
1021
+ audio_patch.setdefault("input", {})["format"] = fmt_in
1022
+ else:
1023
+ ftype = str(fmt_in)
1024
+ if ftype == "pcm16":
1025
+ ftype = "audio/pcm"
1026
+ audio_patch.setdefault("input", {})["format"] = {
1027
+ "type": ftype,
1028
+ "rate": int(self.options.input_rate_hz or 24000),
1029
+ }
1030
+
1031
+ # Optional input extras
1032
+ for key in ("noise_reduction", "transcription"):
1033
+ if key in audio:
1034
+ audio_patch[key] = audio.get(key)
1035
+
1036
+ # Apply turn_detection under audio.input if provided (allow None to disable)
1037
+ if include_td:
1038
+ audio_patch.setdefault("input", {})["turn_detection"] = turn_det
1039
+
1040
+ # Output format/voice/speed
1041
+ op: Dict[str, Any] = {}
1042
+ fmt_out = out.get("format", raw.get("output_audio_format"))
1043
+ if fmt_out is not None:
1044
+ if isinstance(fmt_out, dict):
1045
+ op["format"] = fmt_out
1046
+ else:
1047
+ ftype = str(fmt_out)
1048
+ if ftype == "pcm16":
1049
+ ftype = "audio/pcm"
1050
+ op["format"] = {
1051
+ "type": ftype,
1052
+ "rate": int(self.options.output_rate_hz or 24000),
1053
+ }
1054
+ if "voice" in out:
1055
+ op["voice"] = out.get("voice")
1056
+ if "speed" in out:
1057
+ op["speed"] = out.get("speed")
1058
+ # Convenience: allow top-level overrides
1059
+ if "voice" in raw and "voice" not in op:
1060
+ op["voice"] = raw.get("voice")
1061
+ if "speed" in raw and "speed" not in op:
1062
+ op["speed"] = raw.get("speed")
1063
+ if op:
1064
+ audio_patch.setdefault("output", {}).update(op)
1065
+ except Exception:
1066
+ pass
1067
+
1068
+ if audio_patch:
1069
+ patch["audio"] = audio_patch
1070
+
1071
+ # No top-level turn_detection
1072
+
1073
+ def _strip_tool_strict(tools_val):
1074
+ try:
1075
+ tools_list = list(tools_val or [])
1076
+ except Exception:
1077
+ return tools_val
1078
+ cleaned = []
1079
+ for t in tools_list:
1080
+ try:
1081
+ t2 = dict(t)
1082
+ t2.pop("strict", None)
1083
+ cleaned.append(t2)
1084
+ except Exception:
1085
+ cleaned.append(t)
1086
+ return cleaned
1087
+
1088
+ # Pass through other documented fields if present
1089
+ for k in (
1090
+ "model",
1091
+ "output_modalities",
1092
+ "prompt",
1093
+ "instructions",
1094
+ "tools",
1095
+ "tool_choice",
1096
+ "include",
1097
+ "max_output_tokens",
1098
+ "tracing",
1099
+ "truncation",
1100
+ ):
1101
+ if k in raw:
1102
+ if k == "tools":
1103
+ patch[k] = _strip_tool_strict(raw[k])
1104
+ else:
1105
+ patch[k] = raw[k]
1106
+
1107
+ # --- Inject realtime transcription config if options were updated after initial connect ---
1108
+ try:
1109
+ tr_model = getattr(self.options, "transcription_model", None)
1110
+ if tr_model and isinstance(patch, dict):
1111
+ # Ensure audio/input containers exist without overwriting caller provided fields
1112
+ aud = patch.setdefault("audio", {})
1113
+ inp = aud.setdefault("input", {})
1114
+ # Only add if not explicitly provided in this patch
1115
+ if "transcription" not in inp:
1116
+ transcription_cfg: Dict[str, Any] = {"model": tr_model}
1117
+ lang = getattr(self.options, "transcription_language", None)
1118
+ if lang:
1119
+ transcription_cfg["language"] = lang
1120
+ prompt_txt = getattr(self.options, "transcription_prompt", None)
1121
+ if prompt_txt is not None:
1122
+ transcription_cfg["prompt"] = prompt_txt
1123
+ nr = getattr(self.options, "transcription_noise_reduction", None)
1124
+ if nr is not None:
1125
+ aud["noise_reduction"] = bool(nr)
1126
+ if getattr(self.options, "transcription_include_logprobs", False):
1127
+ patch.setdefault("include", [])
1128
+ if (
1129
+ "item.input_audio_transcription.logprobs"
1130
+ not in patch["include"]
1131
+ ):
1132
+ patch["include"].append(
1133
+ "item.input_audio_transcription.logprobs"
1134
+ )
1135
+ inp["transcription"] = transcription_cfg
1136
+ try:
1137
+ logger.debug(
1138
+ "Realtime WS: update_session injected transcription config model=%s",
1139
+ tr_model,
1140
+ )
1141
+ except Exception:
1142
+ pass
1143
+ except Exception:
1144
+ logger.exception(
1145
+ "Realtime WS: failed injecting transcription config in update_session"
1146
+ )
1147
+
1148
+ # Ensure tools are cleaned even if provided only under audio or elsewhere
1149
+ if "tools" in patch:
1150
+ patch["tools"] = _strip_tool_strict(patch["tools"]) # idempotent
1151
+
1152
+ # Per server requirements, always include session.type and output_modalities
1153
+ try:
1154
+ patch["type"] = "realtime"
1155
+ # Preserve caller-provided output_modalities if present, otherwise default to configured modalities
1156
+ if "output_modalities" not in patch:
1157
+ patch["output_modalities"] = self.options.output_modalities or [
1158
+ "audio",
1159
+ "text",
1160
+ ]
1161
+ except Exception:
1162
+ pass
1163
+
1164
+ payload = {"type": "session.update", "session": patch}
1165
+ # Mark awaiting updated and store last patch
1166
+ self._last_session_patch = patch or {}
1167
+ self._session_updated_evt = asyncio.Event()
1168
+ self._awaiting_session_updated = True
1169
+ # Log payload and warn if potentially clearing/omitting critical fields
1170
+ try:
1171
+ logger.info(
1172
+ "Realtime WS: sending session.update payload=%s",
1173
+ json.dumps(self._last_session_patch, sort_keys=True),
1174
+ )
1175
+ if "instructions" in self._last_session_patch and (
1176
+ (self._last_session_patch.get("instructions") or "").strip() == ""
1177
+ ):
1178
+ logger.warning(
1179
+ "Realtime WS: session.update sets empty instructions; this clears them"
1180
+ )
1181
+ out_cfg = (self._last_session_patch.get("audio") or {}).get("output") or {}
1182
+ if "voice" in out_cfg and not out_cfg.get("voice"):
1183
+ logger.warning("Realtime WS: session.update provides empty voice")
1184
+ if "instructions" not in self._last_session_patch:
1185
+ logger.warning(
1186
+ "Realtime WS: session.update omits instructions; relying on previous instructions"
1187
+ )
1188
+ except Exception:
1189
+ pass
1190
+ # Use tracked send to attach an event_id and improve diagnostics
1191
+ await self._send_tracked(payload, label="session.update:patch")
1192
+
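A representative patch accepted by the normalization above, assuming a connected session and example values for voice and speed:

async def reconfigure(session) -> None:
    await session.update_session(
        {
            "audio": {
                "input": {"turn_detection": None},           # None disables server VAD for later turns
                "output": {"voice": "marin", "speed": 1.1},   # example values
            },
            "instructions": "Keep answers to one or two sentences.",
        }
    )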
1193
+ async def append_audio(self, pcm16_bytes: bytes) -> None: # pragma: no cover
1194
+ b64 = base64.b64encode(pcm16_bytes).decode("ascii")
1195
+ await self._send_tracked(
1196
+ {"type": "input_audio_buffer.append", "audio": b64},
1197
+ label="input_audio_buffer.append",
1198
+ )
1199
+ try:
1200
+ self._pending_input_bytes += len(pcm16_bytes)
1201
+ except Exception:
1202
+ pass
1203
+
1204
+ async def commit_input(self) -> None: # pragma: no cover
1205
+ try:
1206
+ # If a previous response is still marked active, wait briefly, then proceed.
1207
+ # Skipping commits here can cause new turns to reference old audio and repeat answers.
1208
+ if bool(getattr(self, "_response_active", False)):
1209
+ logger.warning(
1210
+ "Realtime WS: response active at commit; waiting briefly before proceeding"
1211
+ )
1212
+ for _ in range(5): # up to ~0.5s
1213
+ await asyncio.sleep(0.1)
1214
+ if not bool(getattr(self, "_response_active", False)):
1215
+ break
1216
+ # Avoid overlapping commits while awaiting server ack
1217
+ if bool(getattr(self, "_commit_inflight", False)):
1218
+ logger.warning("Realtime WS: skipping commit; commit in-flight")
1219
+ return
1220
+ # Avoid rapid duplicate commits
1221
+ last_commit = float(getattr(self, "_last_commit_ts", 0.0))
1222
+ if last_commit and (asyncio.get_event_loop().time() - last_commit) < 1.0:
1223
+ logger.warning("Realtime WS: skipping commit; committed recently")
1224
+ return
1225
+ # Require at least 100ms of audio (~4800 bytes at 24kHz mono 16-bit)
1226
+ min_bytes = int(0.1 * int(self.options.input_rate_hz or 24000) * 2)
1227
+ except Exception:
1228
+ min_bytes = 4800
1229
+ if int(getattr(self, "_pending_input_bytes", 0)) < min_bytes:
1230
+ try:
1231
+ logger.warning(
1232
+ "Realtime WS: skipping commit; buffer too small bytes=%d < %d",
1233
+ int(getattr(self, "_pending_input_bytes", 0)),
1234
+ min_bytes,
1235
+ )
1236
+ except Exception:
1237
+ pass
1238
+ return
1239
+ # Reset commit event before sending a new commit and mark as in-flight
1240
+ try:
1241
+ self._commit_evt = asyncio.Event()
1242
+ self._commit_inflight = True
1243
+ except Exception:
1244
+ pass
1245
+ await self._send_tracked(
1246
+ {"type": "input_audio_buffer.commit"}, label="input_audio_buffer.commit"
1247
+ )
1248
+ try:
1249
+ logger.info("Realtime WS: input_audio_buffer.commit sent")
1250
+ self._pending_input_bytes = 0
1251
+ self._last_commit_ts = asyncio.get_event_loop().time()
1252
+ except Exception:
1253
+ pass
1254
+
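The buffer guard above amounts to the following arithmetic for the default 24 kHz mono PCM16 input:

rate_hz = 24000            # options.input_rate_hz default
bytes_per_sample = 2       # 16-bit mono PCM
min_bytes = int(0.1 * rate_hz * bytes_per_sample)   # 100 ms of audio
assert min_bytes == 4800   # matches the hard-coded fallback used when options are unavailable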
1255
+ async def clear_input(self) -> None: # pragma: no cover
1256
+ await self._send_tracked(
1257
+ {"type": "input_audio_buffer.clear"}, label="input_audio_buffer.clear"
1258
+ )
1259
+ # Reset last input reference and commit event to avoid stale references
1260
+ try:
1261
+ self._last_input_item_id = None
1262
+ self._commit_evt = asyncio.Event()
1263
+ except Exception:
1264
+ pass
1265
+
1266
+ async def create_conversation_item(
1267
+ self, item: Dict[str, Any]
1268
+ ) -> None: # pragma: no cover
1269
+ """Create a conversation item (e.g., for text input)."""
1270
+ payload = {"type": "conversation.item.create", "item": item}
1271
+ await self._send_tracked(payload, label="conversation.item.create")
1272
+
1273
+ async def create_response(
1274
+ self, response_patch: Optional[Dict[str, Any]] = None
1275
+ ) -> None: # pragma: no cover
1276
+ # Avoid duplicate responses: if server auto-creates after commit or one is already active, don't send.
1277
+ try:
1278
+ if getattr(self, "_response_active", False):
1279
+ logger.warning(
1280
+ "Realtime WS: response.create suppressed — response already active"
1281
+ )
1282
+ return
1283
+ auto = bool(getattr(self, "_server_auto_create_enabled", False))
1284
+ last_commit = float(getattr(self, "_last_commit_ts", 0.0))
1285
+ if auto and last_commit:
1286
+ # If we committed very recently (<1.0s), assume server will auto-create
1287
+ if (asyncio.get_event_loop().time() - last_commit) < 1.0:
1288
+ logger.info(
1289
+ "Realtime WS: response.create skipped — server auto-create expected"
1290
+ )
1291
+ return
1292
+ except Exception:
1293
+ pass
1294
+ # Wait briefly for commit event so we can reference the latest audio item when applicable
1295
+ if not self._last_input_item_id:
1296
+ try:
1297
+ await asyncio.wait_for(self._commit_evt.wait(), timeout=2.0)
1298
+ except asyncio.TimeoutError:
1299
+ pass
1300
+ # Ensure the latest session.update (if any) has been applied before responding
1301
+ if self._awaiting_session_updated:
1302
+ # Prefer an explicit session.updated; if absent, accept session.created
1303
+ try:
1304
+ if self._session_updated_evt.is_set():
1305
+ logger.info(
1306
+ "Realtime WS: response.create proceeding after session.updated (pre-set)"
1307
+ )
1308
+ else:
1309
+ await asyncio.wait_for(
1310
+ self._session_updated_evt.wait(), timeout=2.5
1311
+ )
1312
+ logger.info(
1313
+ "Realtime WS: response.create proceeding after session.updated"
1314
+ )
1315
+ except asyncio.TimeoutError:
1316
+ if self._session_created_evt.is_set():
1317
+ logger.info(
1318
+ "Realtime WS: response.create proceeding after session.created (no session.updated observed)"
1319
+ )
1320
+ # Best-effort: resend last session.update once more to apply voice/instructions
1321
+ try:
1322
+ if self._last_session_patch:
1323
+ logger.info(
1324
+ "Realtime WS: resending session.update to apply config before response"
1325
+ )
1326
+ # Reset awaiting flag and wait briefly again
1327
+ self._session_updated_evt = asyncio.Event()
1328
+ self._awaiting_session_updated = True
1329
+ # Ensure required session.type on retry
1330
+ _sess = dict(self._last_session_patch or {})
1331
+ _sess["type"] = "realtime"
1332
+ await self._send(
1333
+ {"type": "session.update", "session": _sess}
1334
+ )
1335
+ try:
1336
+ await asyncio.wait_for(
1337
+ self._session_updated_evt.wait(), timeout=1.0
1338
+ )
1339
+ logger.info(
1340
+ "Realtime WS: proceeding after retry session.updated"
1341
+ )
1342
+ except asyncio.TimeoutError:
1343
+ logger.warning(
1344
+ "Realtime WS: retry session.update did not yield session.updated in time"
1345
+ )
1346
+ except Exception:
1347
+ pass
1348
+ else:
1349
+ try:
1350
+ await asyncio.wait_for(
1351
+ self._session_created_evt.wait(), timeout=2.5
1352
+ )
1353
+ logger.info(
1354
+ "Realtime WS: response.create proceeding after session.created"
1355
+ )
1356
+ except asyncio.TimeoutError:
1357
+ logger.warning(
1358
+ "Realtime WS: neither session.updated nor session.created received in time; proceeding"
1359
+ )
1360
+
1361
+ # Then, create main response
1362
+ payload: Dict[str, Any] = {"type": "response.create"}
1363
+ if response_patch:
1364
+ payload["response"] = response_patch
1365
+ # Ensure response object exists; rely on session defaults for modalities/audio
1366
+ if "response" not in payload:
1367
+ payload["response"] = {}
1368
+ rp = payload["response"]
1369
+ # Sanitize unsupported fields that servers may reject
1370
+ try:
1371
+ rp.pop("modalities", None)
1372
+ rp.pop("audio", None)
1373
+ except Exception:
1374
+ pass
1375
+ rp.setdefault("metadata", {"type": "response"})
1376
+ # Attach input reference so the model links this response to last audio
1377
+ if self._last_input_item_id and "input" not in rp:
1378
+ rp["input"] = [{"type": "item_reference", "id": self._last_input_item_id}]
1379
+ try:
1380
+ has_ref = bool(self._last_input_item_id)
1381
+ logger.info(
1382
+ "Realtime WS: sending response.create (input_ref=%s)",
1383
+ has_ref,
1384
+ )
1385
+ except Exception:
1386
+ pass
1387
+ # Increment response generation when we intentionally start a new response
1388
+ try:
1389
+ if not getattr(self, "_response_active", False):
1390
+ self._response_generation = (
1391
+ int(getattr(self, "_response_generation", 0)) + 1
1392
+ )
1393
+ except Exception:
1394
+ pass
1395
+ await self._send_tracked(payload, label="response.create")
1396
+ try:
1397
+ self._response_active = True
1398
+ except Exception:
1399
+ pass
1400
+
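With a committed input item available, the response.create sent above reduces to roughly the following; the item id is a placeholder:

response_create_example = {
    "type": "response.create",
    "response": {
        "metadata": {"type": "response"},
        # Added only when input_audio_buffer.committed reported an item_id
        "input": [{"type": "item_reference", "id": "item_abc123"}],  # placeholder id
    },
}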
1401
+ # --- Streams ---
1402
+ async def _iter_queue(self, q) -> AsyncGenerator[Any, None]:
1403
+ while True:
1404
+ item = await q.get()
1405
+ if item is None:
1406
+ break
1407
+ yield item
1408
+
1409
+ def iter_events(self) -> AsyncGenerator[Dict[str, Any], None]: # pragma: no cover
1410
+ return self._iter_queue(self._event_queue)
1411
+
1412
+ def iter_output_audio(self) -> AsyncGenerator[bytes, None]: # pragma: no cover
1413
+ return self._iter_queue(self._audio_queue)
1414
+
1415
+ def iter_input_transcript(self) -> AsyncGenerator[str, None]: # pragma: no cover
1416
+ return self._iter_queue(self._in_tr_queue)
1417
+
1418
+ def iter_output_transcript(self) -> AsyncGenerator[str, None]: # pragma: no cover
1419
+ return self._iter_queue(self._out_tr_queue)
1420
+
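The accessors above wrap independent queues, so audio and transcripts can be drained concurrently; a sketch assuming a connected session that has already started a response:

import asyncio

async def drain(session) -> bytes:
    audio = bytearray()

    async def read_audio() -> None:
        async for chunk in session.iter_output_audio():
            audio.extend(chunk)                 # PCM16 deltas, ends at the None sentinel

    async def read_transcript() -> None:
        async for text in session.iter_output_transcript():
            print("assistant said:", text)

    await asyncio.gather(read_audio(), read_transcript())
    return bytes(audio)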
1421
+ def set_tool_executor(self, executor): # pragma: no cover
1422
+ self._tool_executor = executor
1423
+
1424
+ def reset_output_stream(self) -> None: # pragma: no cover
1425
+ """Drain any queued output audio and clear per-response text buffers.
1426
+ This avoids replaying stale audio if the client failed to consume previous chunks."""
1427
+ try:
1428
+ while True:
1429
+ try:
1430
+ _ = self._audio_queue.get_nowait()
1431
+ except asyncio.QueueEmpty:
1432
+ break
1433
+ except Exception:
1434
+ break
1435
+ try:
1436
+ self._out_text_buffers.clear()
1437
+ except Exception:
1438
+ pass
1439
+ except Exception:
1440
+ pass
1441
+
1442
+ # Expose whether a function/tool call is currently pending
1443
+ def has_pending_tool_call(self) -> bool: # pragma: no cover
1444
+ try:
1445
+ return (
1446
+ bool(getattr(self, "_pending_calls", {}))
1447
+ or int(getattr(self, "_active_tool_calls", 0)) > 0
1448
+ )
1449
+ except Exception:
1450
+ return False
1451
+
1452
+ # --- Internal helpers for GA tool execution ---
1453
+ async def _execute_pending_call(self, call_id: Optional[str]) -> None:
1454
+ if not call_id:
1455
+ return
1456
+ # Peek without popping so we remain in a "pending/active" state
1457
+ pc = getattr(self, "_pending_calls", {}).get(call_id)
1458
+ if not pc or not self._tool_executor or not pc.get("name"):
1459
+ return
1460
+ try:
1461
+ # Drop if this call was bound to a previous response generation
1462
+ try:
1463
+ call_gen = int(pc.get("gen", 0))
1464
+ cur_gen = int(getattr(self, "_response_generation", 0))
1465
+ if call_gen and call_gen != cur_gen:
1466
+ logger.warning(
1467
+ "Skipping stale tool call: id=%s name=%s call_gen=%d cur_gen=%d",
1468
+ call_id,
1469
+ pc.get("name"),
1470
+ call_gen,
1471
+ cur_gen,
1472
+ )
1473
+ return
1474
+ except Exception:
1475
+ pass
1476
+ # Mark as active to keep timeouts from firing while tool runs
1477
+ try:
1478
+ self._active_tool_calls += 1
1479
+ except Exception:
1480
+ pass
1481
+ args_preview_len = len((pc.get("args") or ""))
1482
+ logger.info(
1483
+ "Executing tool: id=%s name=%s args_len=%d",
1484
+ call_id,
1485
+ pc.get("name"),
1486
+ args_preview_len,
1487
+ )
1488
+ args = pc.get("args") or "{}"
1489
+ try:
1490
+ parsed = json.loads(args)
1491
+ except Exception:
1492
+ parsed = {}
1493
+ start_ts = asyncio.get_event_loop().time()
1494
+ timeout_s = float(getattr(self.options, "tool_timeout_s", 300.0) or 300.0)
1495
+ try:
1496
+ result = await asyncio.wait_for(
1497
+ self._tool_executor(pc["name"], parsed), timeout=timeout_s
1498
+ )
1499
+ except asyncio.TimeoutError:
1500
+ logger.warning(
1501
+ "Tool timeout: id=%s name=%s exceeded %.1fs",
1502
+ call_id,
1503
+ pc.get("name"),
1504
+ timeout_s,
1505
+ )
1506
+ result = {"error": "tool_timeout"}
1507
+ dur = asyncio.get_event_loop().time() - start_ts
1508
+ try:
1509
+ result_summary = (
1510
+ f"keys={list(result.keys())[:5]}"
1511
+ if isinstance(result, dict)
1512
+ else type(result).__name__
1513
+ )
1514
+ except Exception:
1515
+ result_summary = "<unavailable>"
1516
+ logger.info(
1517
+ "Tool done: id=%s name=%s dur=%.2fs result=%s",
1518
+ call_id,
1519
+ pc.get("name"),
1520
+ dur,
1521
+ result_summary,
1522
+ )
1523
+ # Ensure the server has created the function_call item before we post output
1524
+ try:
1525
+ ev = self._call_ready_events.get(call_id)
1526
+ if ev:
1527
+ try:
1528
+ await asyncio.wait_for(ev.wait(), timeout=1.5)
1529
+ except asyncio.TimeoutError:
1530
+ logger.debug(
1531
+ "Call ready wait timed out; proceeding anyway: call_id=%s",
1532
+ call_id,
1533
+ )
1534
+ # Tiny jitter to help ordering on the server
1535
+ await asyncio.sleep(0.03)
1536
+ except Exception:
1537
+ pass
1538
+
1539
+ # Send tool result via conversation.item.create, then trigger response.create (per docs)
1540
+ try:
1541
+ # Derive a valid call_id and avoid sending item_id as call_id
1542
+ derived_call_id = (
1543
+ pc.get("call_id")
1544
+ or self._item_call_ids.get(pc.get("item_id"))
1545
+ or (
1546
+ call_id
1547
+ if isinstance(call_id, str) and call_id.startswith("call_")
1548
+ else None
1549
+ )
1550
+ )
1551
+ if not derived_call_id:
1552
+ logger.error(
1553
+ "Cannot send function_call_output: missing call_id (id=%s name=%s)",
1554
+ call_id,
1555
+ pc.get("name"),
1556
+ )
1557
+ else:
1558
+ await self._send_tracked(
1559
+ {
1560
+ "type": "conversation.item.create",
1561
+ "item": {
1562
+ "type": "function_call_output",
1563
+ "call_id": derived_call_id,
1564
+ "output": json.dumps(result),
1565
+ },
1566
+ },
1567
+ label="conversation.item.create:function_call_output",
1568
+ )
1569
+ logger.info(
1570
+ "conversation.item.create(function_call_output) sent call_id=%s",
1571
+ derived_call_id,
1572
+ )
1573
+ except Exception:
1574
+ logger.exception(
1575
+ "Failed to send function_call_output for call_id=%s", call_id
1576
+ )
1577
+ try:
1578
+ await asyncio.sleep(0.02)
1579
+ await self._send_tracked(
1580
+ {
1581
+ "type": "response.create",
1582
+ "response": {"metadata": {"type": "response"}},
1583
+ },
1584
+ label="response.create:after_tool",
1585
+ )
1586
+ logger.info("response.create sent after tool output")
1587
+ except Exception:
1588
+ logger.exception(
1589
+ "Failed to send follow-up response.create after tool output"
1590
+ )
1591
+ # Cleanup readiness event
1592
+ try:
1593
+ self._call_ready_events.pop(call_id, None)
1594
+ except Exception:
1595
+ pass
1596
+ except Exception:
1597
+ logger.exception(
1598
+ "Tool execution raised unexpectedly for call_id=%s", call_id
1599
+ )
1600
+ finally:
1601
+ # Clear pending state and decrement active count
1602
+ try:
1603
+ getattr(self, "_pending_calls", {}).pop(call_id, None)
1604
+ except Exception:
1605
+ pass
1606
+ try:
1607
+ self._active_tool_calls = max(0, int(self._active_tool_calls) - 1)
1608
+ logger.debug(
1609
+ "Pending tool calls decremented; active=%d", self._active_tool_calls
1610
+ )
1611
+ except Exception:
1612
+ pass
1613
+
1614
+
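The executor registered with set_tool_executor is awaited as executor(name, parsed_args) and its return value is JSON-encoded into the function_call_output item above; a compatible callable might look like the following sketch, where the tool name and payload are hypothetical:

from datetime import datetime, timezone
from typing import Any, Dict

async def tool_executor(name: str, args: Dict[str, Any]) -> Dict[str, Any]:
    # Dispatch on the function name requested by the model.
    if name == "get_time":                      # hypothetical tool name
        return {"now": datetime.now(timezone.utc).isoformat()}
    return {"error": f"unknown tool: {name}"}

# session.set_tool_executor(tool_executor)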
1615
+ class OpenAITranscriptionWebSocketSession(BaseRealtimeSession):
1616
+ """OpenAI Realtime Transcription WebSocket session.
1617
+
1618
+ This session is transcription-only per GA docs. It accepts PCM16 input and emits
1619
+ conversation.item.input_audio_transcription.* events.
1620
+ """
1621
+
1622
+ def __init__(
1623
+ self,
1624
+ api_key: str,
1625
+ url: str = "wss://api.openai.com/v1/realtime",
1626
+ options: Optional[RealtimeSessionOptions] = None,
1627
+ ) -> None:
1628
+ self.api_key = api_key
1629
+ self.url = url
1630
+ self.options = options or RealtimeSessionOptions()
1631
+ self._ws = None
1632
+ self._event_queue = asyncio.Queue()
1633
+ self._in_tr_queue = asyncio.Queue()
1634
+ self._recv_task = None
1635
+ self._last_input_item_id: Optional[str] = None
1636
+
1637
+ async def connect(self) -> None: # pragma: no cover
1638
+ headers = [("Authorization", f"Bearer {self.api_key}")]
1639
+         # The ?model= query param selects the realtime model for the socket; the actual transcription model is configured below via transcription_session.update
1640
+ model = self.options.model or "gpt-realtime"
1641
+ uri = f"{self.url}?model={model}"
1642
+ logger.info("Transcription WS connecting: uri=%s", uri)
1643
+ self._ws = await websockets.connect(
1644
+ uri, additional_headers=headers, max_size=None
1645
+ )
1646
+ self._recv_task = asyncio.create_task(self._recv_loop())
1647
+
1648
+ # Transcription session config per GA
1649
+ ts_payload: Dict[str, Any] = {
1650
+ "type": "transcription_session.update",
1651
+ "input_audio_format": "pcm16",
1652
+ "input_audio_transcription": {
1653
+ "model": getattr(self.options, "transcribe_model", None)
1654
+ or "gpt-4o-mini-transcribe",
1655
+ **(
1656
+ {"prompt": getattr(self.options, "transcribe_prompt", "")}
1657
+ if getattr(self.options, "transcribe_prompt", None) is not None
1658
+ else {}
1659
+ ),
1660
+ **(
1661
+ {"language": getattr(self.options, "transcribe_language", "en")}
1662
+ if getattr(self.options, "transcribe_language", None) is not None
1663
+ else {}
1664
+ ),
1665
+ },
1666
+ "turn_detection": (
1667
+ {"type": "server_vad"} if self.options.vad_enabled else None
1668
+ ),
1669
+ # Optionally include extra properties (e.g., logprobs)
1670
+ # "include": ["item.input_audio_transcription.logprobs"],
1671
+ }
1672
+ logger.info("Transcription WS: sending transcription_session.update")
1673
+ await self._send(ts_payload)
1674
+
1675
+ async def close(self) -> None: # pragma: no cover
1676
+ if self._ws:
1677
+ await self._ws.close()
1678
+ self._ws = None
1679
+ if self._recv_task:
1680
+ self._recv_task.cancel()
1681
+ self._recv_task = None
1682
+
1683
+ async def _send(self, payload: Dict[str, Any]) -> None: # pragma: no cover
1684
+ if not self._ws:
1685
+ raise RuntimeError("WebSocket not connected")
1686
+ await self._ws.send(json.dumps(payload))
1687
+
1688
+ async def _recv_loop(self) -> None: # pragma: no cover
1689
+ assert self._ws is not None
1690
+ try:
1691
+ async for raw in self._ws:
1692
+ try:
1693
+ data = json.loads(raw)
1694
+ etype = data.get("type")
1695
+ # Temporarily log at INFO to diagnose missing events
1696
+ logger.info("Transcription WS recv: %s", etype)
1697
+ if etype == "input_audio_buffer.committed":
1698
+ self._last_input_item_id = data.get("item_id") or data.get("id")
1699
+ if self._last_input_item_id:
1700
+ logger.info(
1701
+ "Transcription WS: input_audio_buffer committed: item_id=%s",
1702
+ self._last_input_item_id,
1703
+ )
1704
+ elif etype in (
1705
+ "conversation.item.input_audio_transcription.delta",
1706
+ "input_audio_transcription.delta",
1707
+ "response.input_audio_transcription.delta",
1708
+ ) or (
1709
+ isinstance(etype, str)
1710
+ and etype.endswith("input_audio_transcription.delta")
1711
+ ):
1712
+ delta = data.get("delta") or ""
1713
+ if delta:
1714
+ self._in_tr_queue.put_nowait(delta)
1715
+ logger.info("Transcription delta: %r", delta[:120])
1716
+ elif etype in (
1717
+ "conversation.item.input_audio_transcription.completed",
1718
+ "input_audio_transcription.completed",
1719
+ "response.input_audio_transcription.completed",
1720
+ ) or (
1721
+ isinstance(etype, str)
1722
+ and etype.endswith("input_audio_transcription.completed")
1723
+ ):
1724
+ tr = data.get("transcript") or ""
1725
+ if tr:
1726
+ self._in_tr_queue.put_nowait(tr)
1727
+ logger.debug("Transcription completed: %r", tr[:120])
1728
+ # Always publish raw events
1729
+ try:
1730
+ self._event_queue.put_nowait(data)
1731
+ except Exception:
1732
+ pass
1733
+ except Exception:
1734
+ continue
1735
+ except Exception:
1736
+ logger.exception("Transcription WS receive loop error")
1737
+ finally:
1738
+ for q in (self._in_tr_queue, self._event_queue):
1739
+ try:
1740
+ q.put_nowait(None) # type: ignore
1741
+ except Exception:
1742
+ pass
1743
+
1744
+ # --- Client events ---
1745
+ async def update_session(
1746
+ self, session_patch: Dict[str, Any]
1747
+ ) -> None: # pragma: no cover
1748
+ # Allow updating transcription session fields
1749
+ patch = {"type": "transcription_session.update", **session_patch}
1750
+ await self._send(patch)
1751
+
1752
+ async def append_audio(self, pcm16_bytes: bytes) -> None: # pragma: no cover
1753
+ b64 = base64.b64encode(pcm16_bytes).decode("ascii")
1754
+ await self._send({"type": "input_audio_buffer.append", "audio": b64})
1755
+ logger.info("Transcription WS: appended bytes=%d", len(pcm16_bytes))
1756
+
1757
+ async def commit_input(self) -> None: # pragma: no cover
1758
+ await self._send({"type": "input_audio_buffer.commit"})
1759
+ logger.info("Transcription WS: input_audio_buffer.commit sent")
1760
+
1761
+ async def clear_input(self) -> None: # pragma: no cover
1762
+ await self._send({"type": "input_audio_buffer.clear"})
1763
+
1764
+ async def create_conversation_item(
1765
+ self, item: Dict[str, Any]
1766
+ ) -> None: # pragma: no cover
1767
+ """Create a conversation item (e.g., for text input)."""
1768
+ payload = {"type": "conversation.item.create", "item": item}
1769
+         await self._send(payload)
1770
+
1771
+ async def create_response(
1772
+ self, response_patch: Optional[Dict[str, Any]] = None
1773
+ ) -> None: # pragma: no cover
1774
+ # No responses in transcription session
1775
+ return
1776
+
1777
+ # --- Streams ---
1778
+ async def _iter_queue(self, q) -> AsyncGenerator[Any, None]:
1779
+ while True:
1780
+ item = await q.get()
1781
+ if item is None:
1782
+ break
1783
+ yield item
1784
+
1785
+ def iter_events(self) -> AsyncGenerator[Dict[str, Any], None]: # pragma: no cover
1786
+ return self._iter_queue(self._event_queue)
1787
+
1788
+ def iter_output_audio(self) -> AsyncGenerator[bytes, None]: # pragma: no cover
1789
+ # No audio in transcription session
1790
+ async def _empty():
1791
+ if False:
1792
+ yield b""
1793
+
1794
+ return _empty()
1795
+
1796
+ def iter_input_transcript(self) -> AsyncGenerator[str, None]: # pragma: no cover
1797
+ return self._iter_queue(self._in_tr_queue)
1798
+
1799
+ def iter_output_transcript(self) -> AsyncGenerator[str, None]: # pragma: no cover
1800
+ # No assistant transcript in transcription-only mode
1801
+ async def _empty():
1802
+ if False:
1803
+ yield ""
1804
+
1805
+ return _empty()
1806
+
1807
+ def set_tool_executor(self, executor): # pragma: no cover
1808
+ # Not applicable for transcription-only
1809
+ return
1810
+
1811
+ def reset_output_stream(self) -> None: # pragma: no cover
1812
+ # No audio output stream to reset
1813
+ return
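For context, a minimal usage sketch of the transcription-only session defined above, assuming a valid OPENAI_API_KEY environment variable and an iterable of raw mono PCM16 chunks at the sample rate the session expects; the `pcm_chunks` and `main` names are illustrative and not part of the package:

import asyncio
import os

from solana_agent.adapters.openai_realtime_ws import (
    OpenAITranscriptionWebSocketSession,
)

async def main(pcm_chunks) -> None:
    session = OpenAITranscriptionWebSocketSession(api_key=os.environ["OPENAI_API_KEY"])
    await session.connect()
    try:
        for chunk in pcm_chunks:
            await session.append_audio(chunk)  # input_audio_buffer.append
        await session.commit_input()           # input_audio_buffer.commit
        # iter_input_transcript() yields transcript text (deltas and/or completed
        # transcripts) until the socket closes.
        async for text in session.iter_input_transcript():
            print(text, end="", flush=True)
    finally:
        await session.close()

# asyncio.run(main(pcm_chunks))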