openai-agents 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release. This version of openai-agents might be problematic.
- agents/__init__.py +16 -0
- agents/_run_impl.py +56 -6
- agents/agent.py +25 -0
- agents/extensions/visualization.py +137 -0
- agents/mcp/__init__.py +21 -0
- agents/mcp/server.py +301 -0
- agents/mcp/util.py +115 -0
- agents/models/openai_chatcompletions.py +1 -1
- agents/models/openai_provider.py +13 -0
- agents/models/openai_responses.py +6 -2
- agents/py.typed +1 -0
- agents/run.py +45 -7
- agents/tracing/__init__.py +16 -0
- agents/tracing/create.py +150 -1
- agents/tracing/processors.py +26 -7
- agents/tracing/span_data.py +128 -2
- agents/tracing/util.py +5 -0
- agents/voice/__init__.py +51 -0
- agents/voice/events.py +47 -0
- agents/voice/exceptions.py +8 -0
- agents/voice/imports.py +11 -0
- agents/voice/input.py +88 -0
- agents/voice/model.py +193 -0
- agents/voice/models/__init__.py +0 -0
- agents/voice/models/openai_model_provider.py +97 -0
- agents/voice/models/openai_stt.py +456 -0
- agents/voice/models/openai_tts.py +54 -0
- agents/voice/pipeline.py +151 -0
- agents/voice/pipeline_config.py +46 -0
- agents/voice/result.py +287 -0
- agents/voice/utils.py +37 -0
- agents/voice/workflow.py +93 -0
- {openai_agents-0.0.5.dist-info → openai_agents-0.0.7.dist-info}/METADATA +9 -1
- {openai_agents-0.0.5.dist-info → openai_agents-0.0.7.dist-info}/RECORD +36 -16
- {openai_agents-0.0.5.dist-info → openai_agents-0.0.7.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.5.dist-info → openai_agents-0.0.7.dist-info}/licenses/LICENSE +0 -0
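
The bulk of the new surface area in this 0.0.5 → 0.0.7 range is the agents.voice package (plus MCP support and an agent-visualization extension); the hunks below reproduce three of the added voice files: the streaming speech-to-text model, the text-to-speech model, and the pipeline that ties them to a workflow. As a rough orientation only — agents/voice/__init__.py is not reproduced here, so the import paths and model names in this sketch are assumptions inferred from the file list — the two OpenAI model classes shown below are constructed from a model name and an AsyncOpenAI client:

    # Sketch based on the constructors shown in the hunks below.
    # Import paths are assumed from the file list; model names are placeholders.
    from openai import AsyncOpenAI

    from agents.voice.models.openai_stt import OpenAISTTModel
    from agents.voice.models.openai_tts import OpenAITTSModel

    client = AsyncOpenAI()
    stt = OpenAISTTModel(model="gpt-4o-transcribe", openai_client=client)
    tts = OpenAITTSModel(model="gpt-4o-mini-tts", openai_client=client)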
agents/voice/models/openai_stt.py
ADDED

@@ -0,0 +1,456 @@

from __future__ import annotations

import asyncio
import base64
import json
import time
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, cast

from openai import AsyncOpenAI

from ... import _debug
from ...exceptions import AgentsException
from ...logger import logger
from ...tracing import Span, SpanError, TranscriptionSpanData, transcription_span
from ..exceptions import STTWebsocketConnectionError
from ..imports import np, npt, websockets
from ..input import AudioInput, StreamedAudioInput
from ..model import StreamedTranscriptionSession, STTModel, STTModelSettings

EVENT_INACTIVITY_TIMEOUT = 1000  # Timeout for inactivity in event processing
SESSION_CREATION_TIMEOUT = 10  # Timeout waiting for session.created event
SESSION_UPDATE_TIMEOUT = 10  # Timeout waiting for session.updated event

DEFAULT_TURN_DETECTION = {"type": "semantic_vad"}


@dataclass
class ErrorSentinel:
    error: Exception


class SessionCompleteSentinel:
    pass


class WebsocketDoneSentinel:
    pass


def _audio_to_base64(audio_data: list[npt.NDArray[np.int16 | np.float32]]) -> str:
    concatenated_audio = np.concatenate(audio_data)
    if concatenated_audio.dtype == np.float32:
        # convert to int16
        concatenated_audio = np.clip(concatenated_audio, -1.0, 1.0)
        concatenated_audio = (concatenated_audio * 32767).astype(np.int16)
    audio_bytes = concatenated_audio.tobytes()
    return base64.b64encode(audio_bytes).decode("utf-8")


async def _wait_for_event(
    event_queue: asyncio.Queue[dict[str, Any]], expected_types: list[str], timeout: float
):
    """
    Wait for an event from event_queue whose type is in expected_types within the specified timeout.
    """
    start_time = time.time()
    while True:
        remaining = timeout - (time.time() - start_time)
        if remaining <= 0:
            raise TimeoutError(f"Timeout waiting for event(s): {expected_types}")
        evt = await asyncio.wait_for(event_queue.get(), timeout=remaining)
        evt_type = evt.get("type", "")
        if evt_type in expected_types:
            return evt
        elif evt_type == "error":
            raise Exception(f"Error event: {evt.get('error')}")


class OpenAISTTTranscriptionSession(StreamedTranscriptionSession):
    """A transcription session for OpenAI's STT model."""

    def __init__(
        self,
        input: StreamedAudioInput,
        client: AsyncOpenAI,
        model: str,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ):
        self.connected: bool = False
        self._client = client
        self._model = model
        self._settings = settings
        self._turn_detection = settings.turn_detection or DEFAULT_TURN_DETECTION
        self._trace_include_sensitive_data = trace_include_sensitive_data
        self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data

        self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
        self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
            asyncio.Queue()
        )
        self._websocket: websockets.ClientConnection | None = None
        self._event_queue: asyncio.Queue[dict[str, Any] | WebsocketDoneSentinel] = asyncio.Queue()
        self._state_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._turn_audio_buffer: list[npt.NDArray[np.int16 | np.float32]] = []
        self._tracing_span: Span[TranscriptionSpanData] | None = None

        # tasks
        self._listener_task: asyncio.Task[Any] | None = None
        self._process_events_task: asyncio.Task[Any] | None = None
        self._stream_audio_task: asyncio.Task[Any] | None = None
        self._connection_task: asyncio.Task[Any] | None = None
        self._stored_exception: Exception | None = None

    def _start_turn(self) -> None:
        self._tracing_span = transcription_span(
            model=self._model,
            model_config={
                "temperature": self._settings.temperature,
                "language": self._settings.language,
                "prompt": self._settings.prompt,
                "turn_detection": self._turn_detection,
            },
        )
        self._tracing_span.start()

    def _end_turn(self, _transcript: str) -> None:
        if len(_transcript) < 1:
            return

        if self._tracing_span:
            if self._trace_include_sensitive_audio_data:
                self._tracing_span.span_data.input = _audio_to_base64(self._turn_audio_buffer)

            self._tracing_span.span_data.input_format = "pcm"

            if self._trace_include_sensitive_data:
                self._tracing_span.span_data.output = _transcript

            self._tracing_span.finish()
            self._turn_audio_buffer = []
            self._tracing_span = None

    async def _event_listener(self) -> None:
        assert self._websocket is not None, "Websocket not initialized"

        async for message in self._websocket:
            try:
                event = json.loads(message)

                if event.get("type") == "error":
                    raise STTWebsocketConnectionError(f"Error event: {event.get('error')}")

                if event.get("type") in [
                    "session.updated",
                    "transcription_session.updated",
                    "session.created",
                    "transcription_session.created",
                ]:
                    await self._state_queue.put(event)

                await self._event_queue.put(event)
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise STTWebsocketConnectionError("Error parsing events") from e
        await self._event_queue.put(WebsocketDoneSentinel())

    async def _configure_session(self) -> None:
        assert self._websocket is not None, "Websocket not initialized"
        await self._websocket.send(
            json.dumps(
                {
                    "type": "transcription_session.update",
                    "session": {
                        "input_audio_format": "pcm16",
                        "input_audio_transcription": {"model": self._model},
                        "turn_detection": self._turn_detection,
                    },
                }
            )
        )

    async def _setup_connection(self, ws: websockets.ClientConnection) -> None:
        self._websocket = ws
        self._listener_task = asyncio.create_task(self._event_listener())

        try:
            event = await _wait_for_event(
                self._state_queue,
                ["session.created", "transcription_session.created"],
                SESSION_CREATION_TIMEOUT,
            )
        except TimeoutError as e:
            wrapped_err = STTWebsocketConnectionError(
                "Timeout waiting for transcription_session.created event"
            )
            await self._output_queue.put(ErrorSentinel(wrapped_err))
            raise wrapped_err from e
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise e

        await self._configure_session()

        try:
            event = await _wait_for_event(
                self._state_queue,
                ["session.updated", "transcription_session.updated"],
                SESSION_UPDATE_TIMEOUT,
            )
            if _debug.DONT_LOG_MODEL_DATA:
                logger.debug("Session updated")
            else:
                logger.debug(f"Session updated: {event}")
        except TimeoutError as e:
            wrapped_err = STTWebsocketConnectionError(
                "Timeout waiting for transcription_session.updated event"
            )
            await self._output_queue.put(ErrorSentinel(wrapped_err))
            raise wrapped_err from e
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise

    async def _handle_events(self) -> None:
        while True:
            try:
                event = await asyncio.wait_for(
                    self._event_queue.get(), timeout=EVENT_INACTIVITY_TIMEOUT
                )
                if isinstance(event, WebsocketDoneSentinel):
                    # processed all events and websocket is done
                    break

                event_type = event.get("type", "unknown")
                if event_type == "conversation.item.input_audio_transcription.completed":
                    transcript = cast(str, event.get("transcript", ""))
                    if len(transcript) > 0:
                        self._end_turn(transcript)
                        self._start_turn()
                        await self._output_queue.put(transcript)
                await asyncio.sleep(0)  # yield control
            except asyncio.TimeoutError:
                # No new events for a while. Assume the session is done.
                break
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise e
        await self._output_queue.put(SessionCompleteSentinel())

    async def _stream_audio(
        self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
    ) -> None:
        assert self._websocket is not None, "Websocket not initialized"
        self._start_turn()
        while True:
            buffer = await audio_queue.get()
            if buffer is None:
                break

            self._turn_audio_buffer.append(buffer)
            try:
                await self._websocket.send(
                    json.dumps(
                        {
                            "type": "input_audio_buffer.append",
                            "audio": base64.b64encode(buffer.tobytes()).decode("utf-8"),
                        }
                    )
                )
            except websockets.ConnectionClosed:
                break
            except Exception as e:
                await self._output_queue.put(ErrorSentinel(e))
                raise e

            await asyncio.sleep(0)  # yield control

    async def _process_websocket_connection(self) -> None:
        try:
            async with websockets.connect(
                "wss://api.openai.com/v1/realtime?intent=transcription",
                additional_headers={
                    "Authorization": f"Bearer {self._client.api_key}",
                    "OpenAI-Beta": "realtime=v1",
                    "OpenAI-Log-Session": "1",
                },
            ) as ws:
                await self._setup_connection(ws)
                self._process_events_task = asyncio.create_task(self._handle_events())
                self._stream_audio_task = asyncio.create_task(self._stream_audio(self._input_queue))
                self.connected = True
                if self._listener_task:
                    await self._listener_task
                else:
                    logger.error("Listener task not initialized")
                    raise AgentsException("Listener task not initialized")
        except Exception as e:
            await self._output_queue.put(ErrorSentinel(e))
            raise e

    def _check_errors(self) -> None:
        if self._connection_task and self._connection_task.done():
            exc = self._connection_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

        if self._process_events_task and self._process_events_task.done():
            exc = self._process_events_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

        if self._stream_audio_task and self._stream_audio_task.done():
            exc = self._stream_audio_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

        if self._listener_task and self._listener_task.done():
            exc = self._listener_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

    def _cleanup_tasks(self) -> None:
        if self._listener_task and not self._listener_task.done():
            self._listener_task.cancel()

        if self._process_events_task and not self._process_events_task.done():
            self._process_events_task.cancel()

        if self._stream_audio_task and not self._stream_audio_task.done():
            self._stream_audio_task.cancel()

        if self._connection_task and not self._connection_task.done():
            self._connection_task.cancel()

    async def transcribe_turns(self) -> AsyncIterator[str]:
        self._connection_task = asyncio.create_task(self._process_websocket_connection())

        while True:
            try:
                turn = await self._output_queue.get()
            except asyncio.CancelledError:
                break

            if (
                turn is None
                or isinstance(turn, ErrorSentinel)
                or isinstance(turn, SessionCompleteSentinel)
            ):
                self._output_queue.task_done()
                break
            yield turn
            self._output_queue.task_done()

        if self._tracing_span:
            self._end_turn("")

        if self._websocket:
            await self._websocket.close()

        self._check_errors()
        if self._stored_exception:
            raise self._stored_exception

    async def close(self) -> None:
        if self._websocket:
            await self._websocket.close()

        self._cleanup_tasks()


class OpenAISTTModel(STTModel):
    """A speech-to-text model for OpenAI."""

    def __init__(
        self,
        model: str,
        openai_client: AsyncOpenAI,
    ):
        """Create a new OpenAI speech-to-text model.

        Args:
            model: The name of the model to use.
            openai_client: The OpenAI client to use.
        """
        self.model = model
        self._client = openai_client

    @property
    def model_name(self) -> str:
        return self.model

    def _non_null_or_not_given(self, value: Any) -> Any:
        return value if value is not None else None  # NOT_GIVEN

    async def transcribe(
        self,
        input: AudioInput,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ) -> str:
        """Transcribe an audio input.

        Args:
            input: The audio input to transcribe.
            settings: The settings to use for the transcription.

        Returns:
            The transcribed text.
        """
        with transcription_span(
            model=self.model,
            input=input.to_base64() if trace_include_sensitive_audio_data else "",
            input_format="pcm",
            model_config={
                "temperature": self._non_null_or_not_given(settings.temperature),
                "language": self._non_null_or_not_given(settings.language),
                "prompt": self._non_null_or_not_given(settings.prompt),
            },
        ) as span:
            try:
                response = await self._client.audio.transcriptions.create(
                    model=self.model,
                    file=input.to_audio_file(),
                    prompt=self._non_null_or_not_given(settings.prompt),
                    language=self._non_null_or_not_given(settings.language),
                    temperature=self._non_null_or_not_given(settings.temperature),
                )
                if trace_include_sensitive_data:
                    span.span_data.output = response.text
                return response.text
            except Exception as e:
                span.span_data.output = ""
                span.set_error(SpanError(message=str(e), data={}))
                raise e

    async def create_session(
        self,
        input: StreamedAudioInput,
        settings: STTModelSettings,
        trace_include_sensitive_data: bool,
        trace_include_sensitive_audio_data: bool,
    ) -> StreamedTranscriptionSession:
        """Create a new transcription session.

        Args:
            input: The audio input to transcribe.
            settings: The settings to use for the transcription.
            trace_include_sensitive_data: Whether to include sensitive data in traces.
            trace_include_sensitive_audio_data: Whether to include sensitive audio data in traces.

        Returns:
            A new transcription session.
        """
        return OpenAISTTTranscriptionSession(
            input,
            self._client,
            self.model,
            settings,
            trace_include_sensitive_data,
            trace_include_sensitive_audio_data,
        )

agents/voice/models/openai_tts.py
ADDED

@@ -0,0 +1,54 @@

from collections.abc import AsyncIterator
from typing import Literal

from openai import AsyncOpenAI

from ..model import TTSModel, TTSModelSettings

DEFAULT_VOICE: Literal["ash"] = "ash"


class OpenAITTSModel(TTSModel):
    """A text-to-speech model for OpenAI."""

    def __init__(
        self,
        model: str,
        openai_client: AsyncOpenAI,
    ):
        """Create a new OpenAI text-to-speech model.

        Args:
            model: The name of the model to use.
            openai_client: The OpenAI client to use.
        """
        self.model = model
        self._client = openai_client

    @property
    def model_name(self) -> str:
        return self.model

    async def run(self, text: str, settings: TTSModelSettings) -> AsyncIterator[bytes]:
        """Run the text-to-speech model.

        Args:
            text: The text to convert to speech.
            settings: The settings to use for the text-to-speech model.

        Returns:
            An iterator of audio chunks.
        """
        response = self._client.audio.speech.with_streaming_response.create(
            model=self.model,
            voice=settings.voice or DEFAULT_VOICE,
            input=text,
            response_format="pcm",
            extra_body={
                "instructions": settings.instructions,
            },
        )

        async with response as stream:
            async for chunk in stream.iter_bytes(chunk_size=1024):
                yield chunk
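
A small usage sketch for the TTS model above: run() is an async generator of raw PCM chunks, streamed 1024 bytes at a time (see iter_bytes(chunk_size=1024)). TTSModelSettings comes from agents/voice/model.py, which is not reproduced here, so constructing it with defaults is an assumption, as is the placeholder model name.

    import asyncio

    from openai import AsyncOpenAI

    from agents.voice.model import TTSModelSettings            # assumed path
    from agents.voice.models.openai_tts import OpenAITTSModel

    async def main() -> None:
        tts = OpenAITTSModel(model="gpt-4o-mini-tts", openai_client=AsyncOpenAI())
        pcm = bytearray()
        # Collect the streamed PCM chunks; a real app would feed them to an audio device.
        async for chunk in tts.run("Hello from the voice pipeline.", TTSModelSettings()):
            pcm.extend(chunk)
        with open("speech.pcm", "wb") as f:
            f.write(pcm)

    asyncio.run(main())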
agents/voice/pipeline.py
ADDED

@@ -0,0 +1,151 @@

from __future__ import annotations

import asyncio

from .._run_impl import TraceCtxManager
from ..exceptions import UserError
from ..logger import logger
from .input import AudioInput, StreamedAudioInput
from .model import STTModel, TTSModel
from .pipeline_config import VoicePipelineConfig
from .result import StreamedAudioResult
from .workflow import VoiceWorkflowBase


class VoicePipeline:
    """An opinionated voice agent pipeline. It works in three steps:
    1. Transcribe audio input into text.
    2. Run the provided `workflow`, which produces a sequence of text responses.
    3. Convert the text responses into streaming audio output.
    """

    def __init__(
        self,
        *,
        workflow: VoiceWorkflowBase,
        stt_model: STTModel | str | None = None,
        tts_model: TTSModel | str | None = None,
        config: VoicePipelineConfig | None = None,
    ):
        """Create a new voice pipeline.

        Args:
            workflow: The workflow to run. See `VoiceWorkflowBase`.
            stt_model: The speech-to-text model to use. If not provided, a default OpenAI
                model will be used.
            tts_model: The text-to-speech model to use. If not provided, a default OpenAI
                model will be used.
            config: The pipeline configuration. If not provided, a default configuration will be
                used.
        """
        self.workflow = workflow
        self.stt_model = stt_model if isinstance(stt_model, STTModel) else None
        self.tts_model = tts_model if isinstance(tts_model, TTSModel) else None
        self._stt_model_name = stt_model if isinstance(stt_model, str) else None
        self._tts_model_name = tts_model if isinstance(tts_model, str) else None
        self.config = config or VoicePipelineConfig()

    async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
        """Run the voice pipeline.

        Args:
            audio_input: The audio input to process. This can either be an `AudioInput` instance,
                which is a single static buffer, or a `StreamedAudioInput` instance, which is a
                stream of audio data that you can append to.

        Returns:
            A `StreamedAudioResult` instance. You can use this object to stream audio events and
            play them out.
        """
        if isinstance(audio_input, AudioInput):
            return await self._run_single_turn(audio_input)
        elif isinstance(audio_input, StreamedAudioInput):
            return await self._run_multi_turn(audio_input)
        else:
            raise UserError(f"Unsupported audio input type: {type(audio_input)}")

    def _get_tts_model(self) -> TTSModel:
        if not self.tts_model:
            self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
        return self.tts_model

    def _get_stt_model(self) -> STTModel:
        if not self.stt_model:
            self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
        return self.stt_model

    async def _process_audio_input(self, audio_input: AudioInput) -> str:
        model = self._get_stt_model()
        return await model.transcribe(
            audio_input,
            self.config.stt_settings,
            self.config.trace_include_sensitive_data,
            self.config.trace_include_sensitive_audio_data,
        )

    async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
        # Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
        # trace
        with TraceCtxManager(
            workflow_name=self.config.workflow_name or "Voice Agent",
            trace_id=None,  # Automatically generated
            group_id=self.config.group_id,
            metadata=self.config.trace_metadata,
            disabled=self.config.tracing_disabled,
        ):
            input_text = await self._process_audio_input(audio_input)

            output = StreamedAudioResult(
                self._get_tts_model(), self.config.tts_settings, self.config
            )

            async def stream_events():
                try:
                    async for text_event in self.workflow.run(input_text):
                        await output._add_text(text_event)
                    await output._turn_done()
                    await output._done()
                except Exception as e:
                    logger.error(f"Error processing single turn: {e}")
                    await output._add_error(e)
                    raise e

            output._set_task(asyncio.create_task(stream_events()))
            return output

    async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
        with TraceCtxManager(
            workflow_name=self.config.workflow_name or "Voice Agent",
            trace_id=None,
            group_id=self.config.group_id,
            metadata=self.config.trace_metadata,
            disabled=self.config.tracing_disabled,
        ):
            output = StreamedAudioResult(
                self._get_tts_model(), self.config.tts_settings, self.config
            )

            transcription_session = await self._get_stt_model().create_session(
                audio_input,
                self.config.stt_settings,
                self.config.trace_include_sensitive_data,
                self.config.trace_include_sensitive_audio_data,
            )

            async def process_turns():
                try:
                    async for input_text in transcription_session.transcribe_turns():
                        result = self.workflow.run(input_text)
                        async for text_event in result:
                            await output._add_text(text_event)
                        await output._turn_done()
                except Exception as e:
                    logger.error(f"Error processing turns: {e}")
                    await output._add_error(e)
                    raise e
                finally:
                    await transcription_session.close()
                    await output._done()

            output._set_task(asyncio.create_task(process_turns()))
            return output
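
Putting the three files together, a single-turn run looks roughly like the sketch below. The workflow contract is inferred from _run_single_turn and _run_multi_turn: workflow.run(input_text) must be an async iterator of text chunks. AudioInput and the StreamedAudioResult streaming API come from input.py and result.py, which are not reproduced in this diff, so the buffer= argument, the re-exports from agents.voice, and the result.stream() call are assumptions.

    import asyncio
    from collections.abc import AsyncIterator

    import numpy as np

    from agents.voice import VoicePipeline, AudioInput   # assumed re-exports from agents/voice/__init__.py
    from agents.voice.workflow import VoiceWorkflowBase

    class EchoWorkflow(VoiceWorkflowBase):
        """Toy workflow: yields a single text response for each transcribed turn."""

        async def run(self, transcription: str) -> AsyncIterator[str]:
            yield f"You said: {transcription}"

    async def main() -> None:
        # Default OpenAI STT/TTS models are resolved lazily via the config's model provider.
        pipeline = VoicePipeline(workflow=EchoWorkflow())
        audio = AudioInput(buffer=np.zeros(24000, dtype=np.int16))  # assumed constructor; a second of silence
        result = await pipeline.run(audio)

        async for event in result.stream():                         # assumed StreamedAudioResult API (result.py not shown)
            ...  # e.g. play audio chunks as they arrive

    asyncio.run(main())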