solana-agent 31.1.7__py3-none-any.whl → 31.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- solana_agent/adapters/ffmpeg_transcoder.py +282 -0
- solana_agent/adapters/openai_adapter.py +5 -0
- solana_agent/adapters/openai_realtime_ws.py +1613 -0
- solana_agent/client/solana_agent.py +29 -3
- solana_agent/factories/agent_factory.py +2 -1
- solana_agent/interfaces/client/client.py +18 -1
- solana_agent/interfaces/providers/audio.py +40 -0
- solana_agent/interfaces/providers/llm.py +0 -1
- solana_agent/interfaces/providers/realtime.py +100 -0
- solana_agent/interfaces/services/agent.py +0 -1
- solana_agent/interfaces/services/query.py +12 -1
- solana_agent/repositories/memory.py +184 -19
- solana_agent/services/agent.py +0 -5
- solana_agent/services/query.py +561 -6
- solana_agent/services/realtime.py +506 -0
- {solana_agent-31.1.7.dist-info → solana_agent-31.2.1.dist-info}/METADATA +40 -8
- {solana_agent-31.1.7.dist-info → solana_agent-31.2.1.dist-info}/RECORD +20 -15
- {solana_agent-31.1.7.dist-info → solana_agent-31.2.1.dist-info}/LICENSE +0 -0
- {solana_agent-31.1.7.dist-info → solana_agent-31.2.1.dist-info}/WHEEL +0 -0
- {solana_agent-31.1.7.dist-info → solana_agent-31.2.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1613 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import base64
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
from typing import Any, AsyncGenerator, Dict, Optional
|
7
|
+
import asyncio
|
8
|
+
|
9
|
+
import websockets
|
10
|
+
|
11
|
+
from solana_agent.interfaces.providers.realtime import (
|
12
|
+
BaseRealtimeSession,
|
13
|
+
RealtimeSessionOptions,
|
14
|
+
)
|
15
|
+
|
16
|
+
# Module logger (named after this module) used by the realtime session class.
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class OpenAIRealtimeWebSocketSession(BaseRealtimeSession):
|
20
|
+
"""OpenAI Realtime WebSocket session (server-to-server) with audio in/out.
|
21
|
+
|
22
|
+
Notes:
|
23
|
+
- Expects 24kHz audio.
|
24
|
+
- Emits raw PCM16 bytes for output audio deltas.
|
25
|
+
- Provides separate async generators for input/output transcripts.
|
26
|
+
- You can toggle VAD via session.update and manually commit when disabled.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(
    self,
    api_key: str,
    url: str = "wss://api.openai.com/v1/realtime",
    options: Optional[RealtimeSessionOptions] = None,
) -> None:
    """Prepare an unconnected session; ``connect()`` performs the network I/O.

    Args:
        api_key: Sent as ``Authorization: Bearer <api_key>`` when connecting.
        url: Base Realtime WebSocket URL; the model is added on connect.
        options: Session options. A fresh ``RealtimeSessionOptions()`` is used
            when this is omitted (or otherwise falsy).
    """
    self.api_key, self.url = api_key, url
    self.options = options or RealtimeSessionOptions()

    # Transport: the socket and the background receive task (set by connect()).
    self._ws = None
    self._recv_task = None

    # Consumer-facing streams filled by the receive loop.
    self._event_queue = asyncio.Queue()  # every decoded server event
    self._audio_queue = asyncio.Queue()  # decoded output audio chunks
    self._in_tr_queue = asyncio.Queue()  # input (user) transcript text
    self._out_tr_queue = asyncio.Queue()  # output (assistant) transcript text

    # Tool / function-call bookkeeping, keyed by call_id and/or item_id.
    self._tool_executor = None
    self._active_tool_calls = 0
    self._pending_calls: Dict[str, Dict[str, Any]] = {}
    # call_id or item_id -> asyncio.Event set once the server created the call item
    self._call_ready_events: Dict[str, asyncio.Event] = {}
    # identifiers already executed, to avoid running a call twice
    self._executed_call_ids: set = set()
    # call identifier -> originating response.id (addressing outputs to a
    # response avoids conversation lookup races)
    self._call_response_ids: Dict[str, str] = {}
    # streamed function-call argument fragments, concatenated per call
    self._call_args_accum: Dict[str, str] = {}
    self._call_names: Dict[str, str] = {}
    # function_call item_id -> call_id, for reliable output addressing
    self._item_call_ids: Dict[str, str] = {}

    # Input-buffer commits and response lifecycle.
    self._pending_input_bytes = 0
    self._commit_inflight = False
    self._commit_evt = asyncio.Event()
    self._last_commit_ts = 0.0
    self._response_active = False
    self._server_auto_create_enabled = False
    # bumped when a new response starts, to bind tool outputs to it
    self._response_generation = 0
    # fallback response.id for events that omit it
    self._active_response_id = None
    # response.id -> {"text": ..., "has_audio": ...}, flushed on completion
    self._out_text_buffers: Dict[str, Dict[str, Any]] = {}

    # Session negotiation.
    self._session_created_evt = asyncio.Event()
    self._session_updated_evt = asyncio.Event()
    self._awaiting_session_updated = False
    self._last_session_patch: Dict[str, Any] = {}
    self._last_session_updated_payload: Dict[str, Any] = {}
    self._last_input_item_id = None

    # Outbound event correlation (event_id counter and sent snapshots).
    self._event_seq = 0
    self._sent_events: Dict[str, Dict[str, Any]] = {}
|
89
|
+
|
90
|
+
async def connect(self) -> None:  # pragma: no cover
    """Open the Realtime WebSocket, start the receive loop and configure the session.

    Steps:
        1. Connect to ``self.url`` with ``model=<options.model or "gpt-realtime">``
           added to the query string (URL-encoded, and joined with ``&`` when
           ``self.url`` already carries a query) plus a Bearer
           ``Authorization`` header.
        2. Start ``self._recv_loop()`` as a background task.
        3. Wait up to 1.0s for ``session.created``; continue either way.
        4. Send ``session.update`` with audio input/output formats, turn
           detection, voice, speed, instructions, an optional stored prompt,
           tools (with any ``strict`` key removed) and tool_choice.

    Raises:
        Whatever ``websockets.connect`` or ``self._send_tracked`` raise. When a
        step after the socket is open fails (or the call is cancelled),
        ``self.close()`` is awaited first so the socket and the receive task
        are not leaked, then the original error is re-raised.
    """
    from urllib.parse import quote

    # Defensive: ensure session events exist even if __init__ didn't set them (older builds)
    if not isinstance(getattr(self, "_session_created_evt", None), asyncio.Event):
        self._session_created_evt = asyncio.Event()
    if not isinstance(getattr(self, "_session_updated_evt", None), asyncio.Event):
        self._session_updated_evt = asyncio.Event()

    headers = [("Authorization", f"Bearer {self.api_key}")]
    model = self.options.model or "gpt-realtime"
    # Appending "?model=" to a URL that already has a query would corrupt it.
    sep = "&" if "?" in self.url else "?"
    uri = f"{self.url}{sep}model={quote(str(model), safe='')}"
    logger.info(
        "Realtime WS connecting: uri=%s, input=%s@%sHz, output=%s@%sHz, voice=%s, vad=%s",
        uri,
        self.options.input_mime,
        self.options.input_rate_hz,
        self.options.output_mime,
        self.options.output_rate_hz,
        self.options.voice,
        self.options.vad_enabled,
    )
    self._ws = await websockets.connect(
        uri, additional_headers=headers, max_size=None
    )
    logger.info("Connected to OpenAI Realtime WS: %s", uri)
    self._recv_task = asyncio.create_task(self._recv_loop())

    try:
        # Optionally wait briefly for session.created; some servers ignore pre-created updates
        try:
            await asyncio.wait_for(self._session_created_evt.wait(), timeout=1.0)
            logger.info(
                "Realtime WS: session.created observed before first session.update"
            )
        except asyncio.TimeoutError:
            logger.info(
                "Realtime WS: no session.created within 1.0s; sending session.update anyway"
            )

        # Optional stored-prompt reference.
        prompt_block = None
        if getattr(self.options, "prompt_id", None):
            prompt_block = {"id": self.options.prompt_id}
            if getattr(self.options, "prompt_version", None):
                prompt_block["version"] = self.options.prompt_version
            if getattr(self.options, "prompt_variables", None):
                prompt_block["variables"] = self.options.prompt_variables

        # Turn detection lives under audio.input. server_vad is sent in both
        # modes; only create_response follows options.vad_enabled.
        # NOTE(review): sending None would disable turn detection entirely;
        # confirm that keeping server_vad when VAD is "disabled" is intended.
        td_input = {
            "type": "server_vad",
            "create_response": bool(self.options.vad_enabled),
        }

        def _strip_tool_strict(tools_val):
            """Return a copy of *tools_val* with any ``strict`` key dropped from each tool."""
            try:
                tools_list = list(tools_val or [])
            except TypeError:
                return tools_val
            cleaned = []
            for tool in tools_list:
                try:
                    copy = dict(tool)
                except (TypeError, ValueError):
                    cleaned.append(tool)
                    continue
                copy.pop("strict", None)
                cleaned.append(copy)
            return cleaned

        # session.update per docs (nested audio object; no top-level turn_detection)
        session_payload: Dict[str, Any] = {
            "type": "session.update",
            "session": {
                "type": "realtime",
                "output_modalities": ["audio"],
                "audio": {
                    "input": {
                        "format": {
                            "type": self.options.input_mime or "audio/pcm",
                            "rate": int(self.options.input_rate_hz or 24000),
                        },
                        "turn_detection": td_input,
                    },
                    "output": {
                        "format": {
                            "type": self.options.output_mime or "audio/pcm",
                            "rate": int(self.options.output_rate_hz or 24000),
                        },
                        "voice": self.options.voice,
                        "speed": float(
                            getattr(self.options, "voice_speed", 1.0) or 1.0
                        ),
                    },
                },
                **({"prompt": prompt_block} if prompt_block else {}),
                "instructions": self.options.instructions or "",
                **(
                    {"tools": _strip_tool_strict(self.options.tools)}
                    if self.options.tools
                    else {}
                ),
                **(
                    {"tool_choice": self.options.tool_choice}
                    if getattr(self.options, "tool_choice", None)
                    else {}
                ),
            },
        }
        logger.info(
            "Realtime WS: sending session.update (voice=%s, vad=%s, output=%s@%s)",
            self.options.voice,
            self.options.vad_enabled,
            (self.options.output_mime or "audio/pcm"),
            int(self.options.output_rate_hz or 24000),
        )
        # Log exact session.update payload and mark awaiting session.updated
        try:
            logger.info(
                "Realtime WS: sending session.update payload=%s",
                json.dumps(session_payload.get("session", {}), sort_keys=True),
            )
        except (TypeError, ValueError):
            pass
        self._last_session_patch = session_payload.get("session", {})
        self._session_updated_evt = asyncio.Event()
        self._awaiting_session_updated = True

        # Quick sanity warnings
        sess = self._last_session_patch
        instr = sess.get("instructions")
        voice = ((sess.get("audio") or {}).get("output") or {}).get("voice")
        if instr is None or (isinstance(instr, str) and instr.strip() == ""):
            logger.warning(
                "Realtime WS: instructions missing/empty in session.update"
            )
        if not voice:
            logger.warning("Realtime WS: voice missing in session.update")

        await self._send_tracked(session_payload, label="session.update:init")
    except (Exception, asyncio.CancelledError):
        # Do not leak the open socket / receive task when setup fails midway.
        try:
            await self.close()
        except Exception:
            logger.debug(
                "Realtime WS: close() after failed connect also failed",
                exc_info=True,
            )
        raise
|
239
|
+
|
240
|
+
async def close(self) -> None:  # pragma: no cover
    """Close the WebSocket, stop the receive loop and release any waiters.

    The socket and receive-task references are cleared up front, so the
    session reads as closed even if ``ws.close()`` raises. That exception
    still propagates, but only after the remaining cleanup has run.

    The cancelled receive task is awaited (for at most 5 seconds) so its
    ``finally`` block has pushed the ``None`` end-of-stream sentinels onto
    the queues before this returns. Finally, the commit and session events
    are set so nothing stays blocked on them.
    """
    ws, self._ws = self._ws, None
    task, self._recv_task = self._recv_task, None
    try:
        if ws:
            await ws.close()
    finally:
        if task:
            task.cancel()
            # Never wait on ourselves (close() invoked from inside the loop).
            if task is not asyncio.current_task():
                # asyncio.wait neither raises the task's exception nor cancels
                # it further; the timeout bounds a loop that ignores cancel.
                await asyncio.wait({task}, timeout=5.0)
        # Unblock any pending waiters to avoid dangling tasks
        for attr in ("_commit_evt", "_session_created_evt", "_session_updated_evt"):
            evt = getattr(self, attr, None)
            if evt is not None:
                evt.set()
|
260
|
+
|
261
|
+
async def _send(self, payload: Dict[str, Any]) -> None:  # pragma: no cover
    """Serialize *payload* to JSON and write it to the open WebSocket.

    Raises:
        RuntimeError: if no WebSocket is currently open.

    After the write attempt, whether it succeeded or raised, a debug line
    records the event type and id (or the payload's Python type when it has
    no ``get`` method).
    """
    ws = self._ws
    if not ws:
        raise RuntimeError("WebSocket not connected")
    try:
        await ws.send(json.dumps(payload))
    finally:
        try:
            kind, event_id = payload.get("type"), payload.get("event_id")
        except Exception:
            kind, event_id = str(type(payload)), None
        logger.debug("WS send: %s (event_id=%s)", kind, event_id)
|
274
|
+
|
275
|
+
def _next_event_id(self) -> str:
    """Return a new client-side ``event_id``.

    Normally ``"evt-<n>"``, where ``n`` is the incremented per-session counter
    ``_event_seq``. If the counter cannot be advanced (for example, it was
    never initialised), a random UUID4 string is returned instead.
    """
    try:
        advanced = self._event_seq + 1
        self._event_seq = advanced
        return "evt-{}".format(self._event_seq)
    except Exception:
        # Fallback to random when counters fail
        import uuid as _uuid

        return str(_uuid.uuid4())
|
284
|
+
|
285
|
+
async def _send_tracked(
    self, payload: Dict[str, Any], label: Optional[str] = None
) -> str:
    """Send *payload* with an ``event_id`` and keep a snapshot for error correlation.

    An ``event_id`` is assigned (a caller-supplied one is kept) and written
    into *payload* in place. A snapshot is stored in ``self._sent_events``
    under that id so a later server error naming it can be traced back. The
    snapshot holds the label, type, a shallow payload copy and a loop
    timestamp.

    The snapshot map is bounded: once it holds more than
    ``self._max_tracked_events`` entries (default 256), the oldest are
    evicted. Without the bound, long sessions would retain a copy of every
    payload ever sent.

    Args:
        payload: Client event dict; mutated to carry ``event_id``.
        label: Optional human-readable tag for logs; defaults to the
            payload's ``type`` (or ``"client.event"``).

    Returns:
        The event id used, or ``""`` if bookkeeping failed and the payload
        carried none.
    """
    try:
        eid = payload.get("event_id") or self._next_event_id()
        payload["event_id"] = eid
        # Re-inserting moves a re-sent id to the "newest" end.
        self._sent_events.pop(eid, None)
        self._sent_events[eid] = {
            "label": label or (payload.get("type") or "client.event"),
            "type": payload.get("type"),
            "payload": payload.copy(),
            "ts": asyncio.get_running_loop().time(),
        }
        limit = max(1, int(getattr(self, "_max_tracked_events", 256)))
        while len(self._sent_events) > limit:
            # dicts keep insertion order, so the first key is the oldest snapshot
            self._sent_events.pop(next(iter(self._sent_events)))
    except Exception:
        eid = payload.get("event_id") or ""
    await self._send(payload)
    return eid
|
302
|
+
|
303
|
+
async def _recv_loop(self) -> None:  # pragma: no cover
    """Receive server events until the socket closes and fan them out.

    For every JSON message read from ``self._ws`` the loop:

    * tracks ``self._active_response_id`` from ``response.id`` on any
      ``response.*`` event;
    * routes each event by type:
      - output audio deltas are base64-decoded onto ``_audio_queue``;
      - input transcription text goes to ``_in_tr_queue``;
      - assistant transcript/text deltas are buffered per response id and
        flushed as one string onto ``_out_tr_queue``;
      - commit, session and error events update their bookkeeping; errors
        are correlated with the client event named in ``error.event_id``;
    * publishes the event unchanged onto ``_event_queue``;
    * detects completed function calls, registers each in ``_pending_calls``
      and hands it to ``_execute_pending_call`` at most once per identifier
      (tracked in ``_executed_call_ids``).

    A message that cannot be decoded or handled is logged and skipped.
    Whenever the loop ends (socket closed, error, cancellation), ``None`` is
    pushed onto every queue so consumers can stop iterating.
    """
    assert self._ws is not None

    def _response_of(data: Dict[str, Any]) -> Dict[str, Any]:
        # ``response`` may be absent or JSON null on many events.
        resp = data.get("response")
        return resp if isinstance(resp, dict) else {}

    def _metadata_of(data: Dict[str, Any]) -> Dict[str, Any]:
        # ``response.metadata`` is nullable. A null value used to raise
        # AttributeError and make the whole event (e.g. response.done) be
        # skipped, including tool-call execution and state resets.
        meta = _response_of(data).get("metadata")
        return meta if isinstance(meta, dict) else {}

    def _rid_or_active(data: Dict[str, Any]) -> Optional[str]:
        # Explicit response.id, else the currently active one.
        return _response_of(data).get("id") or getattr(
            self, "_active_response_id", None
        )

    def _mark_call_ready(cid: str) -> None:
        # Create-or-reuse the readiness event for *cid* and signal it.
        ev = self._call_ready_events.get(cid)
        if not ev:
            ev = asyncio.Event()
            self._call_ready_events[cid] = ev
        ev.set()

    def _flush_out_text(rid: Optional[str], reason: str) -> None:
        # Emit the buffered assistant transcript for *rid* once, then drop it.
        if not rid or rid not in self._out_text_buffers:
            return
        final = str(self._out_text_buffers.get(rid, {}).get("text") or "")
        if final:
            self._out_tr_queue.put_nowait(final)
            logger.debug(
                "Flushed assistant transcript on %s (rid=%s, len=%d)",
                reason,
                rid,
                len(final),
            )
        self._out_text_buffers.pop(rid, None)

    try:
        async for raw in self._ws:
            try:
                data = json.loads(raw)
                etype = data.get("type")
                logger.debug(
                    "WS recv: %s (event_id=%s)", etype, data.get("event_id")
                )

                # Track active response.id from any response.* event that carries it
                if isinstance(etype, str) and etype.startswith("response."):
                    rid = _response_of(data).get("id")
                    if rid:
                        self._active_response_id = rid

                # ---- Demux streams ------------------------------------------
                if etype in ("response.output_audio.delta", "response.audio.delta"):
                    b64 = data.get("delta") or ""
                    if b64:
                        try:
                            chunk = base64.b64decode(b64)
                        except (TypeError, ValueError):
                            logger.debug(
                                "Realtime WS: undecodable audio delta dropped"
                            )
                        else:
                            self._audio_queue.put_nowait(chunk)
                            # Debug level: deltas arrive many times per second.
                            logger.debug("Audio delta bytes=%d", len(chunk))
                            # New response detected if we were previously inactive
                            if not getattr(self, "_response_active", False):
                                self._response_generation = (
                                    int(getattr(self, "_response_generation", 0)) + 1
                                )
                                self._response_active = True

                elif etype == "response.text.delta":
                    # Some servers emit generic text deltas with metadata marking transcription
                    if _metadata_of(data).get("type") == "transcription":
                        delta = data.get("delta") or ""
                        if delta:
                            self._in_tr_queue.put_nowait(delta)
                            logger.info("Input transcript delta: %r", delta[:120])
                        else:
                            logger.debug("Input transcript delta: empty")

                elif etype in (
                    "response.function_call_arguments.delta",
                    "response.function_call_arguments.done",
                ):
                    # Capture streamed function-call arguments; on .done register and run the call.
                    try:
                        item = data.get("item") or {}
                        # call_id is usually top-level; some servers nest it under 'item'
                        call_id = (
                            data.get("call_id")
                            or data.get("id")
                            or item.get("call_id")
                            or item.get("id")
                        )
                        rid = _rid_or_active(data)
                        name = data.get("name") or item.get("name")
                        if name and call_id:
                            self._call_names[call_id] = name
                        if rid and call_id:
                            self._call_response_ids[call_id] = rid
                        if etype.endswith("delta"):
                            delta = data.get("delta") or ""
                            if call_id and delta:
                                self._call_args_accum[call_id] = (
                                    self._call_args_accum.get(call_id, "") + delta
                                )
                        elif call_id:
                            _mark_call_ready(call_id)
                            if not hasattr(self, "_pending_calls"):
                                self._pending_calls = {}
                            if (
                                call_id not in self._pending_calls
                                and call_id not in self._executed_call_ids
                            ):
                                # .done carries the complete argument string;
                                # fall back to the accumulated deltas.
                                args_text = data.get("arguments")
                                if not isinstance(args_text, str) or not args_text:
                                    args_text = (
                                        self._call_args_accum.get(call_id, "{}") or "{}"
                                    )
                                self._pending_calls[call_id] = {
                                    "name": self._call_names.get(call_id),
                                    "args": args_text,
                                    "gen": int(getattr(self, "_response_generation", 0)),
                                    "call_id": call_id,
                                    "item_id": None,
                                }
                                self._executed_call_ids.add(call_id)
                                # NOTE(review): awaited inline, so no further
                                # server events are read until the call returns.
                                await self._execute_pending_call(call_id)
                    except Exception:
                        logger.exception("Realtime WS: failed to handle %s", etype)

                elif etype == "response.output_text.delta":
                    # Assistant text output, buffered per response.id. The audio
                    # transcript is preferred, so text is only kept until an
                    # audio transcript arrives for the same response.
                    rid = _rid_or_active(data)
                    delta = data.get("delta") or ""
                    if rid and delta:
                        buf = self._out_text_buffers.setdefault(
                            rid, {"text": "", "has_audio": False}
                        )
                        if not buf.get("has_audio"):
                            buf["text"] = str(buf.get("text", "")) + delta
                            logger.debug(
                                "Buffered assistant text delta (rid=%s, len=%d)",
                                rid,
                                len(delta),
                            )

                elif etype == "conversation.item.input_audio_transcription.delta":
                    delta = data.get("delta") or ""
                    if delta:
                        self._in_tr_queue.put_nowait(delta)
                        logger.info("Input transcript delta (GA): %r", delta[:120])
                    else:
                        logger.debug("Input transcript delta (GA): empty")

                elif etype in (
                    "response.output_audio_transcript.delta",
                    "response.audio_transcript.delta",
                ):
                    # Assistant audio transcript delta (authoritative for spoken output)
                    rid = _rid_or_active(data)
                    delta = data.get("delta") or ""
                    if rid and delta:
                        buf = self._out_text_buffers.setdefault(
                            rid, {"text": "", "has_audio": False}
                        )
                        buf["has_audio"] = True
                        buf["text"] = str(buf.get("text", "")) + delta
                        logger.debug(
                            "Buffered audio transcript delta (rid=%s, len=%d)",
                            rid,
                            len(delta),
                        )

                elif etype == "input_audio_buffer.committed":
                    # Remember the committed item id, then release commit waiters.
                    item_id = data.get("item_id") or data.get("id")
                    if item_id:
                        self._last_input_item_id = item_id
                    self._last_commit_ts = asyncio.get_running_loop().time()
                    logger.info(
                        "Realtime WS: input_audio_buffer committed: item_id=%s",
                        item_id,
                    )
                    self._commit_inflight = False
                    self._commit_evt.set()

                elif etype in ("response.output_audio.done", "response.audio.done"):
                    # End of this response's audio: flush its transcript and end the
                    # audio iterator, but keep the WS open for transcripts.
                    logger.info(
                        "Realtime WS: output audio done; ending audio stream"
                    )
                    _flush_out_text(_rid_or_active(data), "audio.done")
                    self._audio_queue.put_nowait(None)
                    self._response_active = False
                    self._active_response_id = None

                elif etype in (
                    "response.completed",
                    "response.complete",
                    "response.done",
                ):
                    # Do not terminate audio stream here; completion may be for a function_call
                    if _metadata_of(data).get("type") == "transcription":
                        logger.info("Realtime WS: transcription response completed")
                    _flush_out_text(_response_of(data).get("id"), "response.done")
                    self._response_active = False
                    self._active_response_id = None

                elif etype in ("response.text.done", "response.output_text.done"):
                    if _metadata_of(data).get("type") == "transcription":
                        tr = data.get("text") or ""
                        if tr:
                            self._in_tr_queue.put_nowait(tr)
                            logger.debug("Input transcript completed: %r", tr[:120])
                    else:
                        # Assistant text-only completion without audio: flush buffered text
                        _flush_out_text(_rid_or_active(data), "text.done")

                elif etype == "conversation.item.input_audio_transcription.completed":
                    tr = data.get("transcript") or ""
                    if tr:
                        self._in_tr_queue.put_nowait(tr)
                        logger.debug("Input transcript completed (GA): %r", tr[:120])

                elif etype == "session.updated":
                    sess = data.get("session") or {}
                    self._last_session_updated_payload = sess
                    # Track server auto-create enablement (turn_detection.create_response)
                    td_root = sess.get("turn_detection")
                    td_input = (
                        (sess.get("audio") or {}).get("input") or {}
                    ).get("turn_detection")
                    td = td_root if td_root is not None else td_input
                    if not getattr(self.options, "vad_enabled", False):
                        # Local VAD disabled: force auto-create disabled
                        self._server_auto_create_enabled = False
                    else:
                        self._server_auto_create_enabled = isinstance(
                            td, dict
                        ) and bool(td.get("create_response"))
                    logger.debug(
                        "Realtime WS: server_auto_create_enabled=%s",
                        self._server_auto_create_enabled,
                    )
                    # Mark that the latest update has been applied
                    self._awaiting_session_updated = False
                    self._session_updated_evt.set()
                    try:
                        logger.info(
                            "Realtime WS: session.updated payload=%s",
                            json.dumps(sess, sort_keys=True),
                        )
                    except (TypeError, ValueError):
                        logger.info(
                            "Realtime WS: session updated (payload dump failed)"
                        )
                    # Guardrails: warn if key fields are absent/empty
                    instr = sess.get("instructions")
                    voice = sess.get("voice") or (
                        ((sess.get("audio") or {}).get("output") or {}).get("voice")
                    )
                    if instr is None or (isinstance(instr, str) and instr.strip() == ""):
                        logger.warning(
                            "Realtime WS: session.updated has empty/missing instructions"
                        )
                    if not voice:
                        logger.warning("Realtime WS: session.updated missing voice")

                elif etype == "session.created":
                    logger.info("Realtime WS: session created")
                    self._session_created_evt.set()

                elif isinstance(etype, str) and ("error" in etype or "failed" in etype):
                    # Surface server errors explicitly
                    try:
                        logger.error(
                            "Realtime WS error event: %s",
                            json.dumps(data, sort_keys=True),
                        )
                    except (TypeError, ValueError):
                        logger.error("Realtime WS error event (payload dump failed)")
                    # Correlate to the originating client event. error.event_id
                    # names the client event that caused the error; the top-level
                    # event_id is the server's own id for this error event, so it
                    # is only a fallback.
                    err = data.get("error")
                    client_eid = (
                        err.get("event_id") if isinstance(err, dict) else None
                    ) or data.get("event_id")
                    if client_eid and client_eid in self._sent_events:
                        sent = self._sent_events.get(client_eid) or {}
                        logger.error(
                            "Realtime WS error correlated: event_id=%s sent_label=%s sent_type=%s",
                            client_eid,
                            sent.get("label"),
                            sent.get("type"),
                        )

                # Always also publish raw events
                self._event_queue.put_nowait(data)

                # ---- Tool/function calls (GA-only triggers) ------------------
                if etype in ("conversation.item.created", "conversation.item.added"):
                    item = data.get("item") or {}
                    if isinstance(item, dict) and item.get("type") == "function_call":
                        call_id = item.get("call_id")
                        item_id = item.get("id")
                        # Mark readiness for both identifiers if available
                        for cid in filter(None, (call_id, item_id)):
                            _mark_call_ready(cid)
                        logger.debug(
                            "Call ready via item.%s: call_id=%s item_id=%s",
                            "created" if etype.endswith("created") else "added",
                            call_id,
                            item_id,
                        )

                elif etype in (
                    "response.output_item.added",
                    "response.output_item.created",
                ):
                    # Map response-scoped function_call items to their response id.
                    rid = _rid_or_active(data)
                    item = data.get("item") or {}
                    if isinstance(item, dict) and item.get("type") == "function_call":
                        call_id = item.get("call_id")
                        item_id = item.get("id")
                        name = item.get("name")
                        if name and call_id:
                            self._call_names[call_id] = name
                        if call_id and item_id:
                            self._item_call_ids[item_id] = call_id
                        for cid in filter(None, (call_id, item_id)):
                            if rid:
                                self._call_response_ids[cid] = rid
                            _mark_call_ready(cid)
                        logger.debug(
                            "Mapped function_call via response.%s: call_id=%s item_id=%s response_id=%s",
                            "created" if etype.endswith("created") else "added",
                            call_id,
                            item_id,
                            rid,
                        )

                elif etype in (
                    "response.done",
                    "response.completed",
                    "response.complete",
                ):
                    # Also detect GA function_call items in response.output
                    resp = _response_of(data)
                    rid = resp.get("id")
                    for item in resp.get("output") or []:
                        if not isinstance(item, dict) or item.get("type") != "function_call":
                            continue
                        call_id = item.get("call_id")
                        item_id = item.get("id")
                        # Prefer a canonical id to reference later: item_id first, then call_id
                        canonical_id = item_id or call_id
                        if not canonical_id:
                            continue
                        # Isolate each call so one failure doesn't skip the others.
                        try:
                            if call_id and item_id:
                                self._item_call_ids[item_id] = call_id
                            ids = [cid for cid in (canonical_id, call_id, item_id) if cid]
                            if rid:
                                for cid in ids:
                                    self._call_response_ids[cid] = rid
                            if any(cid in self._executed_call_ids for cid in ids):
                                continue
                            status = item.get("status")
                            if status and str(status).lower() not in (
                                "completed",
                                "complete",
                                "done",
                            ):
                                # Wait for completion in a later event
                                continue
                            if not hasattr(self, "_pending_calls"):
                                self._pending_calls = {}
                            self._pending_calls[canonical_id] = {
                                "name": item.get("name"),
                                "args": item.get("arguments") or "{}",
                                # Bind this call to the current response generation
                                "gen": int(getattr(self, "_response_generation", 0)),
                                # Keep both identifiers for fallback when posting outputs
                                "call_id": call_id,
                                "item_id": item_id,
                            }
                            for cid in ids:
                                _mark_call_ready(cid)
                            logger.debug(
                                "Call ready via response.done: call_id=%s item_id=%s canonical_id=%s",
                                call_id,
                                item_id,
                                canonical_id,
                            )
                            self._executed_call_ids.update(ids)
                            await self._execute_pending_call(canonical_id)
                        except Exception:
                            logger.exception(
                                "Realtime WS: failed to run function call %s",
                                canonical_id,
                            )
            except Exception:
                logger.exception("Realtime WS: failed to process server event; skipping")
                continue
    except Exception:
        logger.exception("Realtime WS receive loop error")
    finally:
        # Close queues gracefully: None tells each consumer the stream ended
        for q in (
            self._audio_queue,
            self._in_tr_queue,
            self._out_tr_queue,
            self._event_queue,
        ):
            try:
                q.put_nowait(None)  # type: ignore
            except asyncio.QueueFull:
                pass
|
881
|
+
|
882
|
+
# --- Client event helpers ---
|
883
|
+
async def update_session(
    self, session_patch: Dict[str, Any]
) -> None:  # pragma: no cover
    """Send a GA-shaped ``session.update`` built from a loose session patch.

    Accepts both the nested layout (``audio.input`` / ``audio.output``) and
    legacy flat keys (``input_audio_format``, ``output_audio_format``, and
    top-level ``voice`` / ``speed`` / ``turn_detection``). These are normalized
    into ``session.audio.input`` and ``session.audio.output``. Only fields
    present in ``session_patch`` are sent. ``session.type`` is always
    ``"realtime"``.

    Side effects: stores the built patch in ``self._last_session_patch``,
    installs a fresh ``self._session_updated_evt`` and sets
    ``self._awaiting_session_updated = True`` before sending.

    Args:
        session_patch: Partial session configuration. ``None`` is treated as
            an empty patch.
    """
    raw = dict(session_patch or {})
    patch: Dict[str, Any] = {}
    audio_patch: Dict[str, Any] = {}

    def _format_obj(fmt: Any, rate_attr: str) -> Any:
        # Dict formats pass through untouched. Strings become {"type", "rate"},
        # and the legacy "pcm16" name is mapped to the GA "audio/pcm" type.
        if isinstance(fmt, dict):
            return fmt
        ftype = str(fmt)
        if ftype == "pcm16":
            ftype = "audio/pcm"
        return {"type": ftype, "rate": int(getattr(self.options, rate_attr) or 24000)}

    try:
        audio = dict(raw.get("audio") or {})
        inp = dict(audio.get("input") or {})
        out = dict(audio.get("output") or {})

        # turn_detection may be given under audio.input, audio, or top level.
        # Later sources win, so top level has the highest precedence. An
        # explicit None is forwarded so callers can disable turn detection.
        include_td = False
        turn_det = None
        for source in (inp, audio, raw):
            if "turn_detection" in source:
                include_td = True
                turn_det = source.get("turn_detection")

        # Input format
        fmt_in = inp.get("format", raw.get("input_audio_format"))
        if fmt_in is not None:
            audio_patch.setdefault("input", {})["format"] = _format_obj(
                fmt_in, "input_rate_hz"
            )

        # Optional input extras. The server schema places these under
        # audio.input. Read them from audio.input first, then fall back to
        # audio.* for backward compatibility with older callers.
        for key in ("noise_reduction", "transcription"):
            if key in inp:
                audio_patch.setdefault("input", {})[key] = inp.get(key)
            elif key in audio:
                audio_patch.setdefault("input", {})[key] = audio.get(key)

        if include_td:
            audio_patch.setdefault("input", {})["turn_detection"] = turn_det

        # Output format/voice/speed. Values under audio.output take priority
        # over the top-level convenience keys.
        op: Dict[str, Any] = {}
        fmt_out = out.get("format", raw.get("output_audio_format"))
        if fmt_out is not None:
            op["format"] = _format_obj(fmt_out, "output_rate_hz")
        for key in ("voice", "speed"):
            if key in out:
                op[key] = out.get(key)
            elif key in raw:
                op[key] = raw.get(key)
        if op:
            audio_patch.setdefault("output", {}).update(op)
    except Exception:
        # Best-effort: send whatever was normalized, but make the failure visible.
        logger.exception(
            "Realtime WS: failed to normalize audio fields for session.update"
        )

    if audio_patch:
        patch["audio"] = audio_patch

    # Always include session.type in updates (no top-level turn_detection).
    patch["type"] = "realtime"

    def _strip_tool_strict(tools_val):
        # Drop the "strict" key from each tool definition; entries that
        # cannot be copied as dicts are passed through unchanged.
        try:
            tools_list = list(tools_val or [])
        except Exception:
            return tools_val
        cleaned = []
        for t in tools_list:
            try:
                t2 = dict(t)
                t2.pop("strict", None)
                cleaned.append(t2)
            except Exception:
                cleaned.append(t)
        return cleaned

    # Pass through other documented fields if present
    for k in (
        "model",
        "output_modalities",
        "prompt",
        "instructions",
        "tools",
        "tool_choice",
        "include",
        "max_output_tokens",
        "tracing",
        "truncation",
    ):
        if k in raw:
            patch[k] = _strip_tool_strict(raw[k]) if k == "tools" else raw[k]

    payload = {"type": "session.update", "session": patch}
    # Mark awaiting session.updated and remember the patch for a possible resend
    self._last_session_patch = patch
    self._session_updated_evt = asyncio.Event()
    self._awaiting_session_updated = True
    # Log payload and warn if potentially clearing/omitting critical fields
    try:
        logger.info(
            "Realtime WS: sending session.update payload=%s",
            json.dumps(self._last_session_patch, sort_keys=True),
        )
        if "instructions" in self._last_session_patch and (
            (self._last_session_patch.get("instructions") or "").strip() == ""
        ):
            logger.warning(
                "Realtime WS: session.update sets empty instructions; this clears them"
            )
        out_cfg = (self._last_session_patch.get("audio") or {}).get("output") or {}
        if "voice" in out_cfg and not out_cfg.get("voice"):
            logger.warning("Realtime WS: session.update provides empty voice")
        if "instructions" not in self._last_session_patch:
            logger.warning(
                "Realtime WS: session.update omits instructions; relying on previous instructions"
            )
    except Exception:
        pass
    await self._send(payload)
|
1034
|
+
|
1035
|
+
async def append_audio(self, pcm16_bytes: bytes) -> None:  # pragma: no cover
    """Push a chunk of PCM16 audio into the server-side input buffer.

    The raw bytes are base64-encoded into an ``input_audio_buffer.append``
    event. On a best-effort basis, ``_pending_input_bytes`` is then advanced
    so ``commit_input`` can judge how much audio is buffered.
    """
    encoded = base64.b64encode(pcm16_bytes).decode("ascii")
    event = {"type": "input_audio_buffer.append", "audio": encoded}
    await self._send_tracked(event, label="input_audio_buffer.append")
    chunk_len = len(pcm16_bytes)
    try:
        self._pending_input_bytes = self._pending_input_bytes + chunk_len
    except Exception:
        pass
|
1045
|
+
|
1046
|
+
async def commit_input(self) -> None:  # pragma: no cover
    """Commit the buffered input audio unless the commit would be rejected or redundant.

    The commit is skipped, with a warning, when any of these hold:
      * a response is active;
      * a previous commit is still in flight;
      * the last commit was less than 1 s ago;
      * fewer than ~100 ms of 16-bit mono audio (``0.1 * input_rate_hz * 2``
        bytes, 4800 at 24 kHz) has been appended since the last commit.

    Before sending, ``_commit_evt`` is replaced with a fresh event and
    ``_commit_inflight`` is set. If the send fails or is cancelled, the
    in-flight flag is cleared again before the error propagates, so one failed
    send cannot block every later commit.

    NOTE(review): on success the in-flight flag is presumably cleared by the
    receive loop when the server acknowledges the commit. That code is not in
    this chunk.
    """
    # Each guard is evaluated independently so an error in one cannot bypass the others.
    if bool(getattr(self, "_response_active", False)):
        logger.warning("Realtime WS: skipping commit; response active")
        return
    if bool(getattr(self, "_commit_inflight", False)):
        logger.warning("Realtime WS: skipping commit; commit in-flight")
        return
    try:
        last_commit = float(getattr(self, "_last_commit_ts", 0.0) or 0.0)
    except (TypeError, ValueError):
        last_commit = 0.0
    if last_commit and (asyncio.get_running_loop().time() - last_commit) < 1.0:
        logger.warning("Realtime WS: skipping commit; committed recently")
        return
    try:
        min_bytes = int(0.1 * int(self.options.input_rate_hz or 24000) * 2)
    except Exception:
        min_bytes = 4800
    pending = int(getattr(self, "_pending_input_bytes", 0) or 0)
    if pending < min_bytes:
        logger.warning(
            "Realtime WS: skipping commit; buffer too small bytes=%d < %d",
            pending,
            min_bytes,
        )
        return

    # Reset commit event before sending a new commit and mark as in-flight
    self._commit_evt = asyncio.Event()
    self._commit_inflight = True
    try:
        await self._send_tracked(
            {"type": "input_audio_buffer.commit"}, label="input_audio_buffer.commit"
        )
    except BaseException:
        # The commit never reached the server; don't leave the flag stuck.
        self._commit_inflight = False
        raise
    self._pending_input_bytes = 0
    self._last_commit_ts = asyncio.get_running_loop().time()
    logger.info("Realtime WS: input_audio_buffer.commit sent")
|
1090
|
+
|
1091
|
+
async def clear_input(self) -> None:  # pragma: no cover
    """Discard any uncommitted audio held in the server's input buffer.

    After the clear event goes out, local bookkeeping is reset. The
    remembered input item id is dropped and a fresh, unset commit event
    replaces the old one, so later waiters never react to a stale commit.
    """
    clear_event = {"type": "input_audio_buffer.clear"}
    await self._send_tracked(clear_event, label="input_audio_buffer.clear")
    try:
        fresh_commit_evt = asyncio.Event()
        self._last_input_item_id = None
        self._commit_evt = fresh_commit_evt
    except Exception:
        pass
|
1101
|
+
|
1102
|
+
async def create_response(
    self, response_patch: Optional[Dict[str, Any]] = None
) -> None:  # pragma: no cover
    """Send ``response.create`` once the session is ready, avoiding duplicate responses.

    Steps:
      1. Return early if a response is already active, or if server-side
         auto-create is enabled and audio was committed less than 1 s ago
         (the server is expected to start the response itself).
      2. If no committed input item id is known yet, wait up to 2 s for the
         commit event.
      3. If a ``session.update`` is awaiting acknowledgement, wait up to
         2.5 s for ``session.updated``. If that times out but
         ``session.created`` was seen, resend the last session patch once and
         wait up to 1 s more. Otherwise wait up to 2.5 s for
         ``session.created``.
      4. Re-check ``_response_active``, because the waits above may have let a
         server-initiated response start.
      5. Build the ``response`` object from a copy of ``response_patch``:
         drop ``modalities``/``audio``, default ``metadata``, and reference the
         last input item. Then bump ``_response_generation``, send the event,
         and mark the response active.

    Args:
        response_patch: Optional fields for the ``response`` object. The
            caller's dict is never modified.
    """
    # Avoid duplicate responses: if server auto-creates after commit or one is already active, don't send.
    try:
        if getattr(self, "_response_active", False):
            logger.warning(
                "Realtime WS: response.create suppressed — response already active"
            )
            return
        auto = bool(getattr(self, "_server_auto_create_enabled", False))
        last_commit = float(getattr(self, "_last_commit_ts", 0.0))
        if auto and last_commit:
            # If we committed very recently (<1.0s), assume server will auto-create
            if (asyncio.get_running_loop().time() - last_commit) < 1.0:
                logger.info(
                    "Realtime WS: response.create skipped — server auto-create expected"
                )
                return
    except Exception:
        pass
    # Wait briefly for commit event so we can reference the latest audio item when applicable
    if not self._last_input_item_id:
        try:
            await asyncio.wait_for(self._commit_evt.wait(), timeout=2.0)
        except asyncio.TimeoutError:
            pass
    # Ensure the latest session.update (if any) has been applied before responding
    if self._awaiting_session_updated:
        # Prefer an explicit session.updated; if absent, accept session.created
        try:
            if self._session_updated_evt.is_set():
                logger.info(
                    "Realtime WS: response.create proceeding after session.updated (pre-set)"
                )
            else:
                await asyncio.wait_for(
                    self._session_updated_evt.wait(), timeout=2.5
                )
                logger.info(
                    "Realtime WS: response.create proceeding after session.updated"
                )
        except asyncio.TimeoutError:
            if self._session_created_evt.is_set():
                logger.info(
                    "Realtime WS: response.create proceeding after session.created (no session.updated observed)"
                )
                # Best-effort: resend last session.update once more to apply voice/instructions
                try:
                    if self._last_session_patch:
                        logger.info(
                            "Realtime WS: resending session.update to apply config before response"
                        )
                        # Reset awaiting flag and wait briefly again
                        self._session_updated_evt = asyncio.Event()
                        self._awaiting_session_updated = True
                        # Ensure required session.type on retry
                        _sess = dict(self._last_session_patch or {})
                        _sess["type"] = "realtime"
                        await self._send(
                            {"type": "session.update", "session": _sess}
                        )
                        try:
                            await asyncio.wait_for(
                                self._session_updated_evt.wait(), timeout=1.0
                            )
                            logger.info(
                                "Realtime WS: proceeding after retry session.updated"
                            )
                        except asyncio.TimeoutError:
                            logger.warning(
                                "Realtime WS: retry session.update did not yield session.updated in time"
                            )
                except Exception:
                    pass
            else:
                try:
                    await asyncio.wait_for(
                        self._session_created_evt.wait(), timeout=2.5
                    )
                    logger.info(
                        "Realtime WS: response.create proceeding after session.created"
                    )
                except asyncio.TimeoutError:
                    logger.warning(
                        "Realtime WS: neither session.updated nor session.created received in time; proceeding"
                    )

    # The awaits above may have let a (server-initiated) response start; a
    # second response.create would then be rejected by the server.
    if getattr(self, "_response_active", False):
        logger.warning(
            "Realtime WS: response.create suppressed — response already active"
        )
        return

    # Work on a shallow copy so the caller's dict is never mutated below.
    rp: Dict[str, Any] = dict(response_patch) if response_patch else {}
    payload: Dict[str, Any] = {"type": "response.create", "response": rp}
    # Sanitize unsupported fields that servers may reject; rely on session
    # defaults for modalities/audio.
    rp.pop("modalities", None)
    rp.pop("audio", None)
    rp.setdefault("metadata", {"type": "response"})
    # Attach input reference so the model links this response to last audio
    if self._last_input_item_id and "input" not in rp:
        rp["input"] = [{"type": "item_reference", "id": self._last_input_item_id}]
    try:
        logger.info(
            "Realtime WS: sending response.create (input_ref=%s)",
            bool(self._last_input_item_id),
        )
    except Exception:
        pass
    # Start a new response generation; tool calls bound to older generations
    # are treated as stale.
    try:
        self._response_generation = int(getattr(self, "_response_generation", 0)) + 1
    except Exception:
        pass
    await self._send_tracked(payload, label="response.create")
    self._response_active = True
|
1229
|
+
|
1230
|
+
# --- Streams ---
|
1231
|
+
async def _iter_queue(self, q) -> AsyncGenerator[Any, None]:
    """Yield items taken from ``q`` until the ``None`` end-of-stream sentinel appears.

    The sentinel itself is consumed but not yielded.
    """
    item = await q.get()
    while item is not None:
        yield item
        item = await q.get()
|
1237
|
+
|
1238
|
+
def iter_events(self) -> AsyncGenerator[Dict[str, Any], None]:  # pragma: no cover
    """Return an async iterator over the raw server-event queue.

    The iterator ends when the receive loop enqueues the ``None`` sentinel.
    """
    source = self._event_queue
    return self._iter_queue(source)
|
1240
|
+
|
1241
|
+
def iter_output_audio(self) -> AsyncGenerator[bytes, None]:  # pragma: no cover
    """Return an async iterator over the output-audio queue (bytes chunks).

    The iterator ends when the receive loop enqueues the ``None`` sentinel.
    """
    audio_q = self._audio_queue
    return self._iter_queue(audio_q)
|
1243
|
+
|
1244
|
+
def iter_input_transcript(self) -> AsyncGenerator[str, None]:  # pragma: no cover
    """Return an async iterator over the input (user speech) transcript queue.

    The iterator ends when the receive loop enqueues the ``None`` sentinel.
    """
    transcript_q = self._in_tr_queue
    return self._iter_queue(transcript_q)
|
1246
|
+
|
1247
|
+
def iter_output_transcript(self) -> AsyncGenerator[str, None]:  # pragma: no cover
    """Return an async iterator over the output (assistant) transcript queue.

    The iterator ends when the receive loop enqueues the ``None`` sentinel.
    """
    transcript_q = self._out_tr_queue
    return self._iter_queue(transcript_q)
|
1249
|
+
|
1250
|
+
def set_tool_executor(self, executor):  # pragma: no cover
    """Register the callable used to run model-requested tools.

    ``_execute_pending_call`` awaits ``executor(name, parsed_args)`` and
    posts the result back as a ``function_call_output`` item.
    """
    setattr(self, "_tool_executor", executor)
|
1252
|
+
|
1253
|
+
# Expose whether a function/tool call is currently pending
|
1254
|
+
def has_pending_tool_call(self) -> bool:  # pragma: no cover
    """Report whether any tool call is queued or currently executing.

    Returns True when ``_pending_calls`` is non-empty or
    ``_active_tool_calls`` is positive. Any error while inspecting that state
    yields False.
    """
    try:
        if getattr(self, "_pending_calls", {}):
            return True
        return int(getattr(self, "_active_tool_calls", 0)) > 0
    except Exception:
        return False
|
1262
|
+
|
1263
|
+
# --- Internal helpers for GA tool execution ---
|
1264
|
+
async def _execute_pending_call(self, call_id: Optional[str]) -> None:
    """Run a queued tool call and feed its result back to the model.

    Looks up ``self._pending_calls[call_id]``. The entry is left in place
    while the tool runs so ``has_pending_tool_call`` stays true. The call is
    skipped if it is bound to an older ``_response_generation``. Otherwise
    ``self._tool_executor(name, args)`` is awaited with
    ``options.tool_timeout_s`` (default 300 s). The result is then sent as a
    ``function_call_output`` item, followed by a ``response.create``.

    Tool failures still produce an output, so the server is never left
    waiting on an unanswered call:
      * a timeout yields ``{"error": "tool_timeout"}``;
      * an exception from the executor yields
        ``{"error": "tool_error", "message": str(exc)}``.

    Cleanup always runs, including for stale calls:
      * the pending entry is removed;
      * readiness events registered under ``call_id`` and its recorded
        ``call_id``/``item_id`` aliases are removed;
      * ``_active_tool_calls`` is decremented, but only if this invocation
        incremented it.

    Args:
        call_id: Canonical key into ``_pending_calls``. Falsy values are ignored.
    """
    if not call_id:
        return
    # Peek without popping so we remain in a "pending/active" state
    pc = getattr(self, "_pending_calls", {}).get(call_id)
    if not pc or not self._tool_executor or not pc.get("name"):
        return
    name = pc.get("name")
    counted = False  # True once this call has incremented _active_tool_calls
    try:
        # Drop if this call was bound to a previous response generation
        try:
            call_gen = int(pc.get("gen", 0))
            cur_gen = int(getattr(self, "_response_generation", 0))
        except Exception:
            call_gen = cur_gen = 0
        if call_gen and call_gen != cur_gen:
            logger.warning(
                "Skipping stale tool call: id=%s name=%s call_gen=%d cur_gen=%d",
                call_id,
                name,
                call_gen,
                cur_gen,
            )
            return
        # Mark as active to keep timeouts from firing while tool runs
        try:
            self._active_tool_calls += 1
            counted = True
        except Exception:
            pass

        raw_args = pc.get("args")
        logger.info(
            "Executing tool: id=%s name=%s args_len=%d",
            call_id,
            name,
            len(raw_args or ""),
        )
        # Arguments usually arrive as a JSON string; accept an already-parsed
        # dict as well instead of silently discarding it.
        if isinstance(raw_args, dict):
            parsed = raw_args
        else:
            try:
                parsed = json.loads(raw_args or "{}")
            except Exception:
                parsed = {}

        loop = asyncio.get_running_loop()
        start_ts = loop.time()
        timeout_s = float(getattr(self.options, "tool_timeout_s", 300.0) or 300.0)
        try:
            result = await asyncio.wait_for(
                self._tool_executor(pc["name"], parsed), timeout=timeout_s
            )
        except asyncio.TimeoutError:
            logger.warning(
                "Tool timeout: id=%s name=%s exceeded %.1fs",
                call_id,
                name,
                timeout_s,
            )
            result = {"error": "tool_timeout"}
        except Exception as exc:
            # Report the failure to the model rather than leaving the call unanswered.
            logger.exception("Tool failed: id=%s name=%s", call_id, name)
            result = {"error": "tool_error", "message": str(exc)}
        dur = loop.time() - start_ts
        try:
            result_summary = (
                f"keys={list(result.keys())[:5]}"
                if isinstance(result, dict)
                else type(result).__name__
            )
        except Exception:
            result_summary = "<unavailable>"
        logger.info(
            "Tool done: id=%s name=%s dur=%.2fs result=%s",
            call_id,
            name,
            dur,
            result_summary,
        )
        # Ensure the server has created the function_call item before we post output
        try:
            ev = self._call_ready_events.get(call_id)
            if ev:
                try:
                    await asyncio.wait_for(ev.wait(), timeout=1.5)
                except asyncio.TimeoutError:
                    logger.debug(
                        "Call ready wait timed out; proceeding anyway: call_id=%s",
                        call_id,
                    )
            # Tiny jitter to help ordering on the server
            await asyncio.sleep(0.03)
        except Exception:
            pass

        # Send tool result via conversation.item.create, then trigger response.create (per docs)
        try:
            # Derive a valid call_id and avoid sending item_id as call_id
            derived_call_id = (
                pc.get("call_id")
                or self._item_call_ids.get(pc.get("item_id"))
                or (
                    call_id
                    if isinstance(call_id, str) and call_id.startswith("call_")
                    else None
                )
            )
            if not derived_call_id:
                logger.error(
                    "Cannot send function_call_output: missing call_id (id=%s name=%s)",
                    call_id,
                    name,
                )
            else:
                await self._send_tracked(
                    {
                        "type": "conversation.item.create",
                        "item": {
                            "type": "function_call_output",
                            "call_id": derived_call_id,
                            "output": json.dumps(result),
                        },
                    },
                    label="conversation.item.create:function_call_output",
                )
                logger.info(
                    "conversation.item.create(function_call_output) sent call_id=%s",
                    derived_call_id,
                )
        except Exception:
            logger.exception(
                "Failed to send function_call_output for call_id=%s", call_id
            )
        try:
            await asyncio.sleep(0.02)
            await self._send_tracked(
                {
                    "type": "response.create",
                    "response": {"metadata": {"type": "response"}},
                },
                label="response.create:after_tool",
            )
            logger.info("response.create sent after tool output")
        except Exception:
            logger.exception(
                "Failed to send follow-up response.create after tool output"
            )
    except Exception:
        logger.exception(
            "Tool execution raised unexpectedly for call_id=%s", call_id
        )
    finally:
        # Clear pending state
        try:
            getattr(self, "_pending_calls", {}).pop(call_id, None)
        except Exception:
            pass
        # Drop readiness events for the canonical id and its aliases
        try:
            ready = self._call_ready_events
            for key in (call_id, pc.get("call_id"), pc.get("item_id")):
                if key:
                    ready.pop(key, None)
        except Exception:
            pass
        # Only undo an increment this invocation actually made; a stale skip
        # must not lower the count for other tools still running.
        if counted:
            try:
                self._active_tool_calls = max(0, int(self._active_tool_calls) - 1)
                logger.debug(
                    "Pending tool calls decremented; active=%d", self._active_tool_calls
                )
            except Exception:
                pass
|
1424
|
+
|
1425
|
+
|
1426
|
+
class OpenAITranscriptionWebSocketSession(BaseRealtimeSession):
    """OpenAI Realtime Transcription WebSocket session.

    This session is transcription-only per GA docs. It accepts PCM16 input and emits
    conversation.item.input_audio_transcription.* events.

    Transcript deltas and completed transcripts are pushed to the input
    transcript queue, and every decoded server event is pushed to the event
    queue. When the receive loop exits, ``None`` is enqueued on both queues so
    the ``iter_*`` async iterators terminate. The output-audio and
    output-transcript streams are always empty, and ``create_response`` /
    ``set_tool_executor`` are no-ops.
    """

    def __init__(
        self,
        api_key: str,
        url: str = "wss://api.openai.com/v1/realtime",
        options: Optional[RealtimeSessionOptions] = None,
    ) -> None:
        self.api_key = api_key
        self.url = url
        self.options = options or RealtimeSessionOptions()
        self._ws = None
        self._event_queue = asyncio.Queue()
        self._in_tr_queue = asyncio.Queue()
        self._recv_task = None
        self._last_input_item_id: Optional[str] = None

    async def connect(self) -> None:  # pragma: no cover
        """Open the WebSocket, start the receive loop and configure transcription.

        If the initial ``transcription_session.update`` cannot be sent, the
        socket and the receive task are torn down before the error propagates.
        """
        headers = [("Authorization", f"Bearer {self.api_key}")]
        # The URL model selects the realtime connection; the transcription
        # model itself is set in the transcription_session.update below.
        model = self.options.model or "gpt-realtime"
        uri = f"{self.url}?model={model}"
        logger.info("Transcription WS connecting: uri=%s", uri)
        self._ws = await websockets.connect(
            uri, additional_headers=headers, max_size=None
        )
        self._recv_task = asyncio.create_task(self._recv_loop())

        # Transcription session config per GA. Optional prompt/language are
        # only sent when configured.
        transcription: Dict[str, Any] = {
            "model": getattr(self.options, "transcribe_model", None)
            or "gpt-4o-mini-transcribe",
        }
        prompt = getattr(self.options, "transcribe_prompt", None)
        if prompt is not None:
            transcription["prompt"] = prompt
        language = getattr(self.options, "transcribe_language", None)
        if language is not None:
            transcription["language"] = language

        ts_payload: Dict[str, Any] = {
            "type": "transcription_session.update",
            "input_audio_format": "pcm16",
            "input_audio_transcription": transcription,
            # None when VAD is disabled in the session options.
            "turn_detection": (
                {"type": "server_vad"} if self.options.vad_enabled else None
            ),
            # Optionally include extra properties (e.g., logprobs)
            # "include": ["item.input_audio_transcription.logprobs"],
        }
        logger.info("Transcription WS: sending transcription_session.update")
        try:
            await self._send(ts_payload)
        except Exception:
            # Don't leak the socket and receive task on a failed setup.
            await self.close()
            raise

    async def close(self) -> None:  # pragma: no cover
        """Close the socket and cancel the receive loop; safe to call repeatedly."""
        ws, self._ws = self._ws, None
        task, self._recv_task = self._recv_task, None
        try:
            if ws:
                await ws.close()
        finally:
            # Cancel even if closing the socket raised.
            if task:
                task.cancel()

    async def _send(self, payload: Dict[str, Any]) -> None:  # pragma: no cover
        """Serialize ``payload`` as JSON and send it; raises if not connected."""
        if not self._ws:
            raise RuntimeError("WebSocket not connected")
        await self._ws.send(json.dumps(payload))

    async def _recv_loop(self) -> None:  # pragma: no cover
        """Decode server events and route transcripts to the input-transcript queue.

        Undecodable frames are skipped. On exit, for any reason, ``None`` is
        enqueued on both queues so consumers stop iterating.
        """
        assert self._ws is not None
        try:
            async for raw in self._ws:
                try:
                    data = json.loads(raw)
                    etype = data.get("type")
                    # Per-event logging is debug-level; it fires for every frame.
                    logger.debug("Transcription WS recv: %s", etype)
                    if etype == "input_audio_buffer.committed":
                        self._last_input_item_id = data.get("item_id") or data.get("id")
                        if self._last_input_item_id:
                            logger.info(
                                "Transcription WS: input_audio_buffer committed: item_id=%s",
                                self._last_input_item_id,
                            )
                    elif isinstance(etype, str) and etype.endswith(
                        "input_audio_transcription.delta"
                    ):
                        # Covers conversation.item.*, bare and response.* variants.
                        delta = data.get("delta") or ""
                        if delta:
                            self._in_tr_queue.put_nowait(delta)
                            logger.debug("Transcription delta: %r", delta[:120])
                    elif isinstance(etype, str) and etype.endswith(
                        "input_audio_transcription.completed"
                    ):
                        tr = data.get("transcript") or ""
                        if tr:
                            self._in_tr_queue.put_nowait(tr)
                            logger.debug("Transcription completed: %r", tr[:120])
                    # Always publish raw events
                    try:
                        self._event_queue.put_nowait(data)
                    except Exception:
                        pass
                except Exception:
                    continue
        except Exception:
            logger.exception("Transcription WS receive loop error")
        finally:
            for q in (self._in_tr_queue, self._event_queue):
                try:
                    q.put_nowait(None)  # type: ignore
                except Exception:
                    pass

    # --- Client events ---
    async def update_session(
        self, session_patch: Dict[str, Any]
    ) -> None:  # pragma: no cover
        """Send ``transcription_session.update`` carrying the given fields.

        The event ``type`` is applied last, so a ``"type"`` key inside
        ``session_patch`` cannot turn this into a different (invalid) event.
        ``None`` is treated as an empty patch.
        """
        patch = {**dict(session_patch or {}), "type": "transcription_session.update"}
        await self._send(patch)

    async def append_audio(self, pcm16_bytes: bytes) -> None:  # pragma: no cover
        """Base64-encode a PCM16 chunk and append it to the server input buffer."""
        b64 = base64.b64encode(pcm16_bytes).decode("ascii")
        await self._send({"type": "input_audio_buffer.append", "audio": b64})
        # Debug-level: this runs once per audio chunk.
        logger.debug("Transcription WS: appended bytes=%d", len(pcm16_bytes))

    async def commit_input(self) -> None:  # pragma: no cover
        """Commit the buffered input audio for transcription."""
        await self._send({"type": "input_audio_buffer.commit"})
        logger.info("Transcription WS: input_audio_buffer.commit sent")

    async def clear_input(self) -> None:  # pragma: no cover
        """Discard uncommitted input audio on the server."""
        await self._send({"type": "input_audio_buffer.clear"})

    async def create_response(
        self, response_patch: Optional[Dict[str, Any]] = None
    ) -> None:  # pragma: no cover
        # No responses in transcription session
        return

    # --- Streams ---
    async def _iter_queue(self, q) -> AsyncGenerator[Any, None]:
        """Yield items from ``q`` until the ``None`` sentinel arrives."""
        while True:
            item = await q.get()
            if item is None:
                break
            yield item

    def iter_events(self) -> AsyncGenerator[Dict[str, Any], None]:  # pragma: no cover
        return self._iter_queue(self._event_queue)

    def iter_output_audio(self) -> AsyncGenerator[bytes, None]:  # pragma: no cover
        # No audio in transcription session
        async def _empty():
            if False:
                yield b""

        return _empty()

    def iter_input_transcript(self) -> AsyncGenerator[str, None]:  # pragma: no cover
        return self._iter_queue(self._in_tr_queue)

    def iter_output_transcript(self) -> AsyncGenerator[str, None]:  # pragma: no cover
        # No assistant transcript in transcription-only mode
        async def _empty():
            if False:
                yield ""

        return _empty()

    def set_tool_executor(self, executor):  # pragma: no cover
        # Not applicable for transcription-only
        return
|