openai-agents 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of openai-agents has been flagged as a potentially problematic release.

Files changed (50)
  1. agents/__init__.py +22 -5
  2. agents/_run_impl.py +101 -22
  3. agents/agent.py +55 -7
  4. agents/agent_output.py +4 -4
  5. agents/function_schema.py +4 -0
  6. agents/guardrail.py +1 -1
  7. agents/handoffs.py +4 -4
  8. agents/items.py +4 -2
  9. agents/models/openai_chatcompletions.py +6 -1
  10. agents/models/openai_provider.py +13 -0
  11. agents/result.py +7 -0
  12. agents/run.py +10 -10
  13. agents/tool.py +34 -10
  14. agents/tracing/__init__.py +12 -0
  15. agents/tracing/create.py +122 -2
  16. agents/tracing/processors.py +2 -2
  17. agents/tracing/scope.py +1 -1
  18. agents/tracing/setup.py +1 -1
  19. agents/tracing/span_data.py +98 -2
  20. agents/tracing/spans.py +1 -1
  21. agents/tracing/traces.py +1 -1
  22. agents/tracing/util.py +5 -0
  23. agents/util/__init__.py +0 -0
  24. agents/util/_coro.py +2 -0
  25. agents/util/_error_tracing.py +16 -0
  26. agents/util/_json.py +31 -0
  27. agents/util/_pretty_print.py +56 -0
  28. agents/util/_transforms.py +11 -0
  29. agents/util/_types.py +7 -0
  30. agents/voice/__init__.py +51 -0
  31. agents/voice/events.py +47 -0
  32. agents/voice/exceptions.py +8 -0
  33. agents/voice/imports.py +11 -0
  34. agents/voice/input.py +88 -0
  35. agents/voice/model.py +193 -0
  36. agents/voice/models/__init__.py +0 -0
  37. agents/voice/models/openai_model_provider.py +97 -0
  38. agents/voice/models/openai_stt.py +457 -0
  39. agents/voice/models/openai_tts.py +54 -0
  40. agents/voice/pipeline.py +151 -0
  41. agents/voice/pipeline_config.py +46 -0
  42. agents/voice/result.py +287 -0
  43. agents/voice/utils.py +37 -0
  44. agents/voice/workflow.py +93 -0
  45. {openai_agents-0.0.4.dist-info → openai_agents-0.0.6.dist-info}/METADATA +9 -4
  46. openai_agents-0.0.6.dist-info/RECORD +70 -0
  47. agents/_utils.py +0 -61
  48. openai_agents-0.0.4.dist-info/RECORD +0 -49
  49. {openai_agents-0.0.4.dist-info → openai_agents-0.0.6.dist-info}/WHEEL +0 -0
  50. {openai_agents-0.0.4.dist-info → openai_agents-0.0.6.dist-info}/licenses/LICENSE +0 -0
agents/voice/models/openai_stt.py (new file)

@@ -0,0 +1,457 @@
+ from __future__ import annotations
+
+ import asyncio
+ import base64
+ import json
+ import time
+ from collections.abc import AsyncIterator
+ from dataclasses import dataclass
+ from typing import Any, cast
+
+ from openai import AsyncOpenAI
+
+ from agents.exceptions import AgentsException
+
+ from ... import _debug
+ from ...logger import logger
+ from ...tracing import Span, SpanError, TranscriptionSpanData, transcription_span
+ from ..exceptions import STTWebsocketConnectionError
+ from ..imports import np, npt, websockets
+ from ..input import AudioInput, StreamedAudioInput
+ from ..model import StreamedTranscriptionSession, STTModel, STTModelSettings
+
+ EVENT_INACTIVITY_TIMEOUT = 1000  # Timeout for inactivity in event processing
+ SESSION_CREATION_TIMEOUT = 10  # Timeout waiting for session.created event
+ SESSION_UPDATE_TIMEOUT = 10  # Timeout waiting for session.updated event
+
+ DEFAULT_TURN_DETECTION = {"type": "semantic_vad"}
+
+
+ @dataclass
+ class ErrorSentinel:
+     error: Exception
+
+
+ class SessionCompleteSentinel:
+     pass
+
+
+ class WebsocketDoneSentinel:
+     pass
+
+
+ def _audio_to_base64(audio_data: list[npt.NDArray[np.int16 | np.float32]]) -> str:
+     concatenated_audio = np.concatenate(audio_data)
+     if concatenated_audio.dtype == np.float32:
+         # convert to int16
+         concatenated_audio = np.clip(concatenated_audio, -1.0, 1.0)
+         concatenated_audio = (concatenated_audio * 32767).astype(np.int16)
+     audio_bytes = concatenated_audio.tobytes()
+     return base64.b64encode(audio_bytes).decode("utf-8")
+
+
+ async def _wait_for_event(
+     event_queue: asyncio.Queue[dict[str, Any]], expected_types: list[str], timeout: float
+ ):
+     """
+     Wait for an event from event_queue whose type is in expected_types within the specified timeout.
+     """
+     start_time = time.time()
+     while True:
+         remaining = timeout - (time.time() - start_time)
+         if remaining <= 0:
+             raise TimeoutError(f"Timeout waiting for event(s): {expected_types}")
+         evt = await asyncio.wait_for(event_queue.get(), timeout=remaining)
+         evt_type = evt.get("type", "")
+         if evt_type in expected_types:
+             return evt
+         elif evt_type == "error":
+             raise Exception(f"Error event: {evt.get('error')}")
+
+
+ class OpenAISTTTranscriptionSession(StreamedTranscriptionSession):
+     """A transcription session for OpenAI's STT model."""
+
+     def __init__(
+         self,
+         input: StreamedAudioInput,
+         client: AsyncOpenAI,
+         model: str,
+         settings: STTModelSettings,
+         trace_include_sensitive_data: bool,
+         trace_include_sensitive_audio_data: bool,
+     ):
+         self.connected: bool = False
+         self._client = client
+         self._model = model
+         self._settings = settings
+         self._turn_detection = settings.turn_detection or DEFAULT_TURN_DETECTION
+         self._trace_include_sensitive_data = trace_include_sensitive_data
+         self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data
+
+         self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
+         self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
+             asyncio.Queue()
+         )
+         self._websocket: websockets.ClientConnection | None = None
+         self._event_queue: asyncio.Queue[dict[str, Any] | WebsocketDoneSentinel] = asyncio.Queue()
+         self._state_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
+         self._turn_audio_buffer: list[npt.NDArray[np.int16 | np.float32]] = []
+         self._tracing_span: Span[TranscriptionSpanData] | None = None
+
+         # tasks
+         self._listener_task: asyncio.Task[Any] | None = None
+         self._process_events_task: asyncio.Task[Any] | None = None
+         self._stream_audio_task: asyncio.Task[Any] | None = None
+         self._connection_task: asyncio.Task[Any] | None = None
+         self._stored_exception: Exception | None = None
+
+     def _start_turn(self) -> None:
+         self._tracing_span = transcription_span(
+             model=self._model,
+             model_config={
+                 "temperature": self._settings.temperature,
+                 "language": self._settings.language,
+                 "prompt": self._settings.prompt,
+                 "turn_detection": self._turn_detection,
+             },
+         )
+         self._tracing_span.start()
+
+     def _end_turn(self, _transcript: str) -> None:
+         if len(_transcript) < 1:
+             return
+
+         if self._tracing_span:
+             if self._trace_include_sensitive_audio_data:
+                 self._tracing_span.span_data.input = _audio_to_base64(self._turn_audio_buffer)
+
+             self._tracing_span.span_data.input_format = "pcm"
+
+             if self._trace_include_sensitive_data:
+                 self._tracing_span.span_data.output = _transcript
+
+             self._tracing_span.finish()
+             self._turn_audio_buffer = []
+             self._tracing_span = None
+
+     async def _event_listener(self) -> None:
+         assert self._websocket is not None, "Websocket not initialized"
+
+         async for message in self._websocket:
+             try:
+                 event = json.loads(message)
+
+                 if event.get("type") == "error":
+                     raise STTWebsocketConnectionError(f"Error event: {event.get('error')}")
+
+                 if event.get("type") in [
+                     "session.updated",
+                     "transcription_session.updated",
+                     "session.created",
+                     "transcription_session.created",
+                 ]:
+                     await self._state_queue.put(event)
+
+                 await self._event_queue.put(event)
+             except Exception as e:
+                 await self._output_queue.put(ErrorSentinel(e))
+                 raise STTWebsocketConnectionError("Error parsing events") from e
+         await self._event_queue.put(WebsocketDoneSentinel())
+
+     async def _configure_session(self) -> None:
+         assert self._websocket is not None, "Websocket not initialized"
+         await self._websocket.send(
+             json.dumps(
+                 {
+                     "type": "transcription_session.update",
+                     "session": {
+                         "input_audio_format": "pcm16",
+                         "input_audio_transcription": {"model": self._model},
+                         "turn_detection": self._turn_detection,
+                     },
+                 }
+             )
+         )
+
+     async def _setup_connection(self, ws: websockets.ClientConnection) -> None:
+         self._websocket = ws
+         self._listener_task = asyncio.create_task(self._event_listener())
+
+         try:
+             event = await _wait_for_event(
+                 self._state_queue,
+                 ["session.created", "transcription_session.created"],
+                 SESSION_CREATION_TIMEOUT,
+             )
+         except TimeoutError as e:
+             wrapped_err = STTWebsocketConnectionError(
+                 "Timeout waiting for transcription_session.created event"
+             )
+             await self._output_queue.put(ErrorSentinel(wrapped_err))
+             raise wrapped_err from e
+         except Exception as e:
+             await self._output_queue.put(ErrorSentinel(e))
+             raise e
+
+         await self._configure_session()
+
+         try:
+             event = await _wait_for_event(
+                 self._state_queue,
+                 ["session.updated", "transcription_session.updated"],
+                 SESSION_UPDATE_TIMEOUT,
+             )
+             if _debug.DONT_LOG_MODEL_DATA:
+                 logger.debug("Session updated")
+             else:
+                 logger.debug(f"Session updated: {event}")
+         except TimeoutError as e:
+             wrapped_err = STTWebsocketConnectionError(
+                 "Timeout waiting for transcription_session.updated event"
+             )
+             await self._output_queue.put(ErrorSentinel(wrapped_err))
+             raise wrapped_err from e
+         except Exception as e:
+             await self._output_queue.put(ErrorSentinel(e))
+             raise
+
+     async def _handle_events(self) -> None:
+         while True:
+             try:
+                 event = await asyncio.wait_for(
+                     self._event_queue.get(), timeout=EVENT_INACTIVITY_TIMEOUT
+                 )
+                 if isinstance(event, WebsocketDoneSentinel):
+                     # processed all events and websocket is done
+                     break
+
+                 event_type = event.get("type", "unknown")
+                 if event_type == "conversation.item.input_audio_transcription.completed":
+                     transcript = cast(str, event.get("transcript", ""))
+                     if len(transcript) > 0:
+                         self._end_turn(transcript)
+                         self._start_turn()
+                         await self._output_queue.put(transcript)
+                 await asyncio.sleep(0)  # yield control
+             except asyncio.TimeoutError:
+                 # No new events for a while. Assume the session is done.
+                 break
+             except Exception as e:
+                 await self._output_queue.put(ErrorSentinel(e))
+                 raise e
+         await self._output_queue.put(SessionCompleteSentinel())
+
+     async def _stream_audio(
+         self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
+     ) -> None:
+         assert self._websocket is not None, "Websocket not initialized"
+         self._start_turn()
+         while True:
+             buffer = await audio_queue.get()
+             if buffer is None:
+                 break
+
+             self._turn_audio_buffer.append(buffer)
+             try:
+                 await self._websocket.send(
+                     json.dumps(
+                         {
+                             "type": "input_audio_buffer.append",
+                             "audio": base64.b64encode(buffer.tobytes()).decode("utf-8"),
+                         }
+                     )
+                 )
+             except websockets.ConnectionClosed:
+                 break
+             except Exception as e:
+                 await self._output_queue.put(ErrorSentinel(e))
+                 raise e
+
+             await asyncio.sleep(0)  # yield control
+
+     async def _process_websocket_connection(self) -> None:
+         try:
+             async with websockets.connect(
+                 "wss://api.openai.com/v1/realtime?intent=transcription",
+                 additional_headers={
+                     "Authorization": f"Bearer {self._client.api_key}",
+                     "OpenAI-Beta": "realtime=v1",
+                     "OpenAI-Log-Session": "1",
+                 },
+             ) as ws:
+                 await self._setup_connection(ws)
+                 self._process_events_task = asyncio.create_task(self._handle_events())
+                 self._stream_audio_task = asyncio.create_task(self._stream_audio(self._input_queue))
+                 self.connected = True
+                 if self._listener_task:
+                     await self._listener_task
+                 else:
+                     logger.error("Listener task not initialized")
+                     raise AgentsException("Listener task not initialized")
+         except Exception as e:
+             await self._output_queue.put(ErrorSentinel(e))
+             raise e
+
+     def _check_errors(self) -> None:
+         if self._connection_task and self._connection_task.done():
+             exc = self._connection_task.exception()
+             if exc and isinstance(exc, Exception):
+                 self._stored_exception = exc
+
+         if self._process_events_task and self._process_events_task.done():
+             exc = self._process_events_task.exception()
+             if exc and isinstance(exc, Exception):
+                 self._stored_exception = exc
+
+         if self._stream_audio_task and self._stream_audio_task.done():
+             exc = self._stream_audio_task.exception()
+             if exc and isinstance(exc, Exception):
+                 self._stored_exception = exc
+
+         if self._listener_task and self._listener_task.done():
+             exc = self._listener_task.exception()
+             if exc and isinstance(exc, Exception):
+                 self._stored_exception = exc
+
+     def _cleanup_tasks(self) -> None:
+         if self._listener_task and not self._listener_task.done():
+             self._listener_task.cancel()
+
+         if self._process_events_task and not self._process_events_task.done():
+             self._process_events_task.cancel()
+
+         if self._stream_audio_task and not self._stream_audio_task.done():
+             self._stream_audio_task.cancel()
+
+         if self._connection_task and not self._connection_task.done():
+             self._connection_task.cancel()
+
+     async def transcribe_turns(self) -> AsyncIterator[str]:
+         self._connection_task = asyncio.create_task(self._process_websocket_connection())
+
+         while True:
+             try:
+                 turn = await self._output_queue.get()
+             except asyncio.CancelledError:
+                 break
+
+             if (
+                 turn is None
+                 or isinstance(turn, ErrorSentinel)
+                 or isinstance(turn, SessionCompleteSentinel)
+             ):
+                 self._output_queue.task_done()
+                 break
+             yield turn
+             self._output_queue.task_done()
+
+         if self._tracing_span:
+             self._end_turn("")
+
+         if self._websocket:
+             await self._websocket.close()
+
+         self._check_errors()
+         if self._stored_exception:
+             raise self._stored_exception
+
+     async def close(self) -> None:
+         if self._websocket:
+             await self._websocket.close()
+
+         self._cleanup_tasks()
+
+
+ class OpenAISTTModel(STTModel):
+     """A speech-to-text model for OpenAI."""
+
+     def __init__(
+         self,
+         model: str,
+         openai_client: AsyncOpenAI,
+     ):
+         """Create a new OpenAI speech-to-text model.
+
+         Args:
+             model: The name of the model to use.
+             openai_client: The OpenAI client to use.
+         """
+         self.model = model
+         self._client = openai_client
+
+     @property
+     def model_name(self) -> str:
+         return self.model
+
+     def _non_null_or_not_given(self, value: Any) -> Any:
+         return value if value is not None else None  # NOT_GIVEN
+
+     async def transcribe(
+         self,
+         input: AudioInput,
+         settings: STTModelSettings,
+         trace_include_sensitive_data: bool,
+         trace_include_sensitive_audio_data: bool,
+     ) -> str:
+         """Transcribe an audio input.
+
+         Args:
+             input: The audio input to transcribe.
+             settings: The settings to use for the transcription.
+
+         Returns:
+             The transcribed text.
+         """
+         with transcription_span(
+             model=self.model,
+             input=input.to_base64() if trace_include_sensitive_audio_data else "",
+             input_format="pcm",
+             model_config={
+                 "temperature": self._non_null_or_not_given(settings.temperature),
+                 "language": self._non_null_or_not_given(settings.language),
+                 "prompt": self._non_null_or_not_given(settings.prompt),
+             },
+         ) as span:
+             try:
+                 response = await self._client.audio.transcriptions.create(
+                     model=self.model,
+                     file=input.to_audio_file(),
+                     prompt=self._non_null_or_not_given(settings.prompt),
+                     language=self._non_null_or_not_given(settings.language),
+                     temperature=self._non_null_or_not_given(settings.temperature),
+                 )
+                 if trace_include_sensitive_data:
+                     span.span_data.output = response.text
+                 return response.text
+             except Exception as e:
+                 span.span_data.output = ""
+                 span.set_error(SpanError(message=str(e), data={}))
+                 raise e
+
+     async def create_session(
+         self,
+         input: StreamedAudioInput,
+         settings: STTModelSettings,
+         trace_include_sensitive_data: bool,
+         trace_include_sensitive_audio_data: bool,
+     ) -> StreamedTranscriptionSession:
+         """Create a new transcription session.
+
+         Args:
+             input: The audio input to transcribe.
+             settings: The settings to use for the transcription.
+             trace_include_sensitive_data: Whether to include sensitive data in traces.
+             trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.
+
+         Returns:
+             A new transcription session.
+         """
+         return OpenAISTTTranscriptionSession(
+             input,
+             self._client,
+             self.model,
+             settings,
+             trace_include_sensitive_data,
+             trace_include_sensitive_audio_data,
+         )
agents/voice/models/openai_tts.py (new file)

@@ -0,0 +1,54 @@
+ from collections.abc import AsyncIterator
+ from typing import Literal
+
+ from openai import AsyncOpenAI
+
+ from ..model import TTSModel, TTSModelSettings
+
+ DEFAULT_VOICE: Literal["ash"] = "ash"
+
+
+ class OpenAITTSModel(TTSModel):
+     """A text-to-speech model for OpenAI."""
+
+     def __init__(
+         self,
+         model: str,
+         openai_client: AsyncOpenAI,
+     ):
+         """Create a new OpenAI text-to-speech model.
+
+         Args:
+             model: The name of the model to use.
+             openai_client: The OpenAI client to use.
+         """
+         self.model = model
+         self._client = openai_client
+
+     @property
+     def model_name(self) -> str:
+         return self.model
+
+     async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
+         """Run the text-to-speech model.
+
+         Args:
+             text: The text to convert to speech.
+             settings: The settings to use for the text-to-speech model.
+
+         Returns:
+             An iterator of audio chunks.
+         """
+         response = self._client.audio.speech.with_streaming_response.create(
+             model=self.model,
+             voice=settings.voice or DEFAULT_VOICE,
+             input=text,
+             response_format="pcm",
+             extra_body={
+                 "instructions": settings.instructions,
+             },
+         )
+
+         async with response as stream:
+             async for chunk in stream.iter_bytes(chunk_size=1024):
+                 yield chunk
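A similar sketch (not part of the package) for streaming PCM audio out of the new TTS model. The model name is illustrative, and `TTSModelSettings()` is assumed to have usable defaults so that the voice falls back to `DEFAULT_VOICE`; an API key is assumed to be available via the environment.

import asyncio

from openai import AsyncOpenAI

from agents.voice.model import TTSModelSettings
from agents.voice.models.openai_tts import OpenAITTSModel


async def main() -> None:
    tts = OpenAITTSModel(model="gpt-4o-mini-tts", openai_client=AsyncOpenAI())

    total = 0
    async for chunk in tts.run("Hello from the voice pipeline.", TTSModelSettings()):
        total += len(chunk)  # in a real app, write these PCM bytes to an audio device
    print(f"received {total} bytes of PCM audio")


asyncio.run(main())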
agents/voice/pipeline.py (new file)

@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ import asyncio
+
+ from .._run_impl import TraceCtxManager
+ from ..exceptions import UserError
+ from ..logger import logger
+ from .input import AudioInput, StreamedAudioInput
+ from .model import STTModel, TTSModel
+ from .pipeline_config import VoicePipelineConfig
+ from .result import StreamedAudioResult
+ from .workflow import VoiceWorkflowBase
+
+
+ class VoicePipeline:
+     """An opinionated voice agent pipeline. It works in three steps:
+     1. Transcribe audio input into text.
+     2. Run the provided `workflow`, which produces a sequence of text responses.
+     3. Convert the text responses into streaming audio output.
+     """
+
+     def __init__(
+         self,
+         *,
+         workflow: VoiceWorkflowBase,
+         stt_model: STTModel | str | None = None,
+         tts_model: TTSModel | str | None = None,
+         config: VoicePipelineConfig | None = None,
+     ):
+         """Create a new voice pipeline.
+
+         Args:
+             workflow: The workflow to run. See `VoiceWorkflowBase`.
+             stt_model: The speech-to-text model to use. If not provided, a default OpenAI
+                 model will be used.
+             tts_model: The text-to-speech model to use. If not provided, a default OpenAI
+                 model will be used.
+             config: The pipeline configuration. If not provided, a default configuration will be
+                 used.
+         """
+         self.workflow = workflow
+         self.stt_model = stt_model if isinstance(stt_model, STTModel) else None
+         self.tts_model = tts_model if isinstance(tts_model, TTSModel) else None
+         self._stt_model_name = stt_model if isinstance(stt_model, str) else None
+         self._tts_model_name = tts_model if isinstance(tts_model, str) else None
+         self.config = config or VoicePipelineConfig()
+
+     async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
+         """Run the voice pipeline.
+
+         Args:
+             audio_input: The audio input to process. This can either be an `AudioInput` instance,
+                 which is a single static buffer, or a `StreamedAudioInput` instance, which is a
+                 stream of audio data that you can append to.
+
+         Returns:
+             A `StreamedAudioResult` instance. You can use this object to stream audio events and
+             play them out.
+         """
+         if isinstance(audio_input, AudioInput):
+             return await self._run_single_turn(audio_input)
+         elif isinstance(audio_input, StreamedAudioInput):
+             return await self._run_multi_turn(audio_input)
+         else:
+             raise UserError(f"Unsupported audio input type: {type(audio_input)}")
+
+     def _get_tts_model(self) -> TTSModel:
+         if not self.tts_model:
+             self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
+         return self.tts_model
+
+     def _get_stt_model(self) -> STTModel:
+         if not self.stt_model:
+             self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
+         return self.stt_model
+
+     async def _process_audio_input(self, audio_input: AudioInput) -> str:
+         model = self._get_stt_model()
+         return await model.transcribe(
+             audio_input,
+             self.config.stt_settings,
+             self.config.trace_include_sensitive_data,
+             self.config.trace_include_sensitive_audio_data,
+         )
+
+     async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
+         # Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
+         # trace
+         with TraceCtxManager(
+             workflow_name=self.config.workflow_name or "Voice Agent",
+             trace_id=None,  # Automatically generated
+             group_id=self.config.group_id,
+             metadata=self.config.trace_metadata,
+             disabled=self.config.tracing_disabled,
+         ):
+             input_text = await self._process_audio_input(audio_input)
+
+             output = StreamedAudioResult(
+                 self._get_tts_model(), self.config.tts_settings, self.config
+             )
+
+             async def stream_events():
+                 try:
+                     async for text_event in self.workflow.run(input_text):
+                         await output._add_text(text_event)
+                     await output._turn_done()
+                     await output._done()
+                 except Exception as e:
+                     logger.error(f"Error processing single turn: {e}")
+                     await output._add_error(e)
+                     raise e
+
+             output._set_task(asyncio.create_task(stream_events()))
+             return output
+
+     async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
+         with TraceCtxManager(
+             workflow_name=self.config.workflow_name or "Voice Agent",
+             trace_id=None,
+             group_id=self.config.group_id,
+             metadata=self.config.trace_metadata,
+             disabled=self.config.tracing_disabled,
+         ):
+             output = StreamedAudioResult(
+                 self._get_tts_model(), self.config.tts_settings, self.config
+             )
+
+             transcription_session = await self._get_stt_model().create_session(
+                 audio_input,
+                 self.config.stt_settings,
+                 self.config.trace_include_sensitive_data,
+                 self.config.trace_include_sensitive_audio_data,
+             )
+
+             async def process_turns():
+                 try:
+                     async for input_text in transcription_session.transcribe_turns():
+                         result = self.workflow.run(input_text)
+                         async for text_event in result:
+                             await output._add_text(text_event)
+                         await output._turn_done()
+                 except Exception as e:
+                     logger.error(f"Error processing turns: {e}")
+                     await output._add_error(e)
+                     raise e
+                 finally:
+                     await transcription_session.close()
+                     await output._done()
+
+             output._set_task(asyncio.create_task(process_turns()))
+             return output
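Finally, a sketch (not part of the package) of wiring the pipeline end to end, matching the three steps in the `VoicePipeline` docstring. It assumes `VoiceWorkflowBase` only requires an async `run(transcription)` generator, that `AudioInput` accepts a raw numpy buffer, and that `StreamedAudioResult` exposes a `stream()` iterator; those live in other new `agents/voice` files not shown in this diff, and the default OpenAI STT/TTS models require an API key in the environment.

import asyncio
from collections.abc import AsyncIterator

import numpy as np

from agents.voice.input import AudioInput
from agents.voice.pipeline import VoicePipeline
from agents.voice.workflow import VoiceWorkflowBase


class EchoWorkflow(VoiceWorkflowBase):
    """Trivial workflow: echo the transcription back as a single text response."""

    async def run(self, transcription: str) -> AsyncIterator[str]:
        yield f"You said: {transcription}"


async def main() -> None:
    pipeline = VoicePipeline(workflow=EchoWorkflow())  # default OpenAI STT/TTS models
    audio = AudioInput(buffer=np.zeros(24000, dtype=np.int16))  # placeholder audio buffer

    result = await pipeline.run(audio)
    async for event in result.stream():  # assumed StreamedAudioResult API
        print(type(event).__name__)


asyncio.run(main())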