openai-agents 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic. Click here for more details.
- agents/agent.py +1 -28
- agents/agent_output.py +1 -1
- agents/extensions/models/litellm_model.py +14 -1
- agents/function_schema.py +3 -3
- agents/guardrail.py +9 -4
- agents/items.py +2 -1
- agents/model_settings.py +2 -1
- agents/models/chatcmpl_converter.py +12 -1
- agents/models/chatcmpl_stream_handler.py +17 -14
- agents/realtime/__init__.py +4 -0
- agents/realtime/_default_tracker.py +47 -0
- agents/realtime/_util.py +9 -0
- agents/realtime/events.py +18 -0
- agents/realtime/model.py +94 -0
- agents/realtime/model_events.py +28 -0
- agents/realtime/openai_realtime.py +97 -29
- agents/realtime/session.py +37 -10
- agents/tool.py +5 -0
- agents/tracing/create.py +1 -2
- agents/tracing/processors.py +4 -5
- agents/tracing/traces.py +1 -1
- agents/usage.py +2 -1
- {openai_agents-0.2.2.dist-info → openai_agents-0.2.4.dist-info}/METADATA +116 -112
- {openai_agents-0.2.2.dist-info → openai_agents-0.2.4.dist-info}/RECORD +26 -24
- {openai_agents-0.2.2.dist-info → openai_agents-0.2.4.dist-info}/WHEEL +0 -0
- {openai_agents-0.2.2.dist-info → openai_agents-0.2.4.dist-info}/licenses/LICENSE +0 -0
agents/agent.py
CHANGED
|
@@ -214,7 +214,7 @@ class Agent(AgentBase, Generic[TContext]):
|
|
|
214
214
|
calls result in a final output.
|
|
215
215
|
|
|
216
216
|
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
|
|
217
|
-
web search, etc are always processed by the LLM.
|
|
217
|
+
web search, etc. are always processed by the LLM.
|
|
218
218
|
"""
|
|
219
219
|
|
|
220
220
|
reset_tool_choice: bool = True
|
|
@@ -289,30 +289,3 @@ class Agent(AgentBase, Generic[TContext]):
|
|
|
289
289
|
) -> ResponsePromptParam | None:
|
|
290
290
|
"""Get the prompt for the agent."""
|
|
291
291
|
return await PromptUtil.to_model_input(self.prompt, run_context, self)
|
|
292
|
-
|
|
293
|
-
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
|
|
294
|
-
"""Fetches the available tools from the MCP servers."""
|
|
295
|
-
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
|
296
|
-
return await MCPUtil.get_all_function_tools(
|
|
297
|
-
self.mcp_servers, convert_schemas_to_strict, run_context, self
|
|
298
|
-
)
|
|
299
|
-
|
|
300
|
-
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
|
301
|
-
"""All agent tools, including MCP tools and function tools."""
|
|
302
|
-
mcp_tools = await self.get_mcp_tools(run_context)
|
|
303
|
-
|
|
304
|
-
async def _check_tool_enabled(tool: Tool) -> bool:
|
|
305
|
-
if not isinstance(tool, FunctionTool):
|
|
306
|
-
return True
|
|
307
|
-
|
|
308
|
-
attr = tool.is_enabled
|
|
309
|
-
if isinstance(attr, bool):
|
|
310
|
-
return attr
|
|
311
|
-
res = attr(run_context, self)
|
|
312
|
-
if inspect.isawaitable(res):
|
|
313
|
-
return bool(await res)
|
|
314
|
-
return bool(res)
|
|
315
|
-
|
|
316
|
-
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
|
|
317
|
-
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
|
|
318
|
-
return [*mcp_tools, *enabled]
|
agents/agent_output.py
CHANGED
|
@@ -116,7 +116,7 @@ class AgentOutputSchema(AgentOutputSchemaBase):
|
|
|
116
116
|
raise UserError(
|
|
117
117
|
"Strict JSON schema is enabled, but the output type is not valid. "
|
|
118
118
|
"Either make the output type strict, "
|
|
119
|
-
"or wrap your type with AgentOutputSchema(
|
|
119
|
+
"or wrap your type with AgentOutputSchema(YourType, strict_json_schema=False)"
|
|
120
120
|
) from e
|
|
121
121
|
|
|
122
122
|
def is_plain_text(self) -> bool:
|
|
@@ -45,6 +45,14 @@ from ...tracing.spans import Span
|
|
|
45
45
|
from ...usage import Usage
|
|
46
46
|
|
|
47
47
|
|
|
48
|
+
class InternalChatCompletionMessage(ChatCompletionMessage):
|
|
49
|
+
"""
|
|
50
|
+
An internal subclass to carry reasoning_content without modifying the original model.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
reasoning_content: str
|
|
54
|
+
|
|
55
|
+
|
|
48
56
|
class LitellmModel(Model):
|
|
49
57
|
"""This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
|
|
50
58
|
Anthropic, Gemini, Mistral, and many other models.
|
|
@@ -364,13 +372,18 @@ class LitellmConverter:
|
|
|
364
372
|
provider_specific_fields.get("refusal", None) if provider_specific_fields else None
|
|
365
373
|
)
|
|
366
374
|
|
|
367
|
-
|
|
375
|
+
reasoning_content = ""
|
|
376
|
+
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
|
377
|
+
reasoning_content = message.reasoning_content
|
|
378
|
+
|
|
379
|
+
return InternalChatCompletionMessage(
|
|
368
380
|
content=message.content,
|
|
369
381
|
refusal=refusal,
|
|
370
382
|
role="assistant",
|
|
371
383
|
annotations=cls.convert_annotations_to_openai(message),
|
|
372
384
|
audio=message.get("audio", None), # litellm deletes audio if not present
|
|
373
385
|
tool_calls=tool_calls,
|
|
386
|
+
reasoning_content=reasoning_content,
|
|
374
387
|
)
|
|
375
388
|
|
|
376
389
|
@classmethod
|
agents/function_schema.py
CHANGED
|
@@ -76,7 +76,7 @@ class FuncSchema:
|
|
|
76
76
|
|
|
77
77
|
@dataclass
|
|
78
78
|
class FuncDocumentation:
|
|
79
|
-
"""Contains metadata about a
|
|
79
|
+
"""Contains metadata about a Python function, extracted from its docstring."""
|
|
80
80
|
|
|
81
81
|
name: str
|
|
82
82
|
"""The name of the function, via `__name__`."""
|
|
@@ -194,7 +194,7 @@ def function_schema(
|
|
|
194
194
|
strict_json_schema: bool = True,
|
|
195
195
|
) -> FuncSchema:
|
|
196
196
|
"""
|
|
197
|
-
Given a
|
|
197
|
+
Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
|
|
198
198
|
parameter descriptions, and other metadata.
|
|
199
199
|
|
|
200
200
|
Args:
|
|
@@ -208,7 +208,7 @@ def function_schema(
|
|
|
208
208
|
descriptions.
|
|
209
209
|
strict_json_schema: Whether the JSON schema is in strict mode. If True, we'll ensure that
|
|
210
210
|
the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
|
|
211
|
-
recommend setting this to True, as it increases the likelihood of the LLM
|
|
211
|
+
recommend setting this to True, as it increases the likelihood of the LLM producing
|
|
212
212
|
correct JSON input.
|
|
213
213
|
|
|
214
214
|
Returns:
|
agents/guardrail.py
CHANGED
|
@@ -78,8 +78,9 @@ class InputGuardrail(Generic[TContext]):
|
|
|
78
78
|
You can use the `@input_guardrail()` decorator to turn a function into an `InputGuardrail`, or
|
|
79
79
|
create an `InputGuardrail` manually.
|
|
80
80
|
|
|
81
|
-
Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`,
|
|
82
|
-
execution will immediately stop and
|
|
81
|
+
Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`,
|
|
82
|
+
the agent's execution will immediately stop, and
|
|
83
|
+
an `InputGuardrailTripwireTriggered` exception will be raised
|
|
83
84
|
"""
|
|
84
85
|
|
|
85
86
|
guardrail_function: Callable[
|
|
@@ -132,7 +133,7 @@ class OutputGuardrail(Generic[TContext]):
|
|
|
132
133
|
You can use the `@output_guardrail()` decorator to turn a function into an `OutputGuardrail`,
|
|
133
134
|
or create an `OutputGuardrail` manually.
|
|
134
135
|
|
|
135
|
-
Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`,
|
|
136
|
+
Guardrails return a `GuardrailResult`. If `result.tripwire_triggered` is `True`, an
|
|
136
137
|
`OutputGuardrailTripwireTriggered` exception will be raised.
|
|
137
138
|
"""
|
|
138
139
|
|
|
@@ -314,7 +315,11 @@ def output_guardrail(
|
|
|
314
315
|
def decorator(
|
|
315
316
|
f: _OutputGuardrailFuncSync[TContext_co] | _OutputGuardrailFuncAsync[TContext_co],
|
|
316
317
|
) -> OutputGuardrail[TContext_co]:
|
|
317
|
-
return OutputGuardrail(
|
|
318
|
+
return OutputGuardrail(
|
|
319
|
+
guardrail_function=f,
|
|
320
|
+
# Guardrail name defaults to function's name when not specified (None).
|
|
321
|
+
name=name if name else f.__name__,
|
|
322
|
+
)
|
|
318
323
|
|
|
319
324
|
if func is not None:
|
|
320
325
|
# Decorator was used without parentheses
|
agents/items.py
CHANGED
|
@@ -5,6 +5,7 @@ import copy
|
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union
|
|
7
7
|
|
|
8
|
+
import pydantic
|
|
8
9
|
from openai.types.responses import (
|
|
9
10
|
Response,
|
|
10
11
|
ResponseComputerToolCall,
|
|
@@ -212,7 +213,7 @@ RunItem: TypeAlias = Union[
|
|
|
212
213
|
"""An item generated by an agent."""
|
|
213
214
|
|
|
214
215
|
|
|
215
|
-
@dataclass
|
|
216
|
+
@pydantic.dataclasses.dataclass
|
|
216
217
|
class ModelResponse:
|
|
217
218
|
output: list[TResponseOutputItem]
|
|
218
219
|
"""A list of outputs (messages, tool calls, etc) generated by the model"""
|
agents/model_settings.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import dataclasses
|
|
4
4
|
from collections.abc import Mapping
|
|
5
|
-
from dataclasses import
|
|
5
|
+
from dataclasses import fields, replace
|
|
6
6
|
from typing import Annotated, Any, Literal, Union
|
|
7
7
|
|
|
8
8
|
from openai import Omit as _Omit
|
|
@@ -10,6 +10,7 @@ from openai._types import Body, Query
|
|
|
10
10
|
from openai.types.responses import ResponseIncludable
|
|
11
11
|
from openai.types.shared import Reasoning
|
|
12
12
|
from pydantic import BaseModel, GetCoreSchemaHandler
|
|
13
|
+
from pydantic.dataclasses import dataclass
|
|
13
14
|
from pydantic_core import core_schema
|
|
14
15
|
from typing_extensions import TypeAlias
|
|
15
16
|
|
|
@@ -36,6 +36,7 @@ from openai.types.responses import (
|
|
|
36
36
|
ResponseOutputRefusal,
|
|
37
37
|
ResponseOutputText,
|
|
38
38
|
ResponseReasoningItem,
|
|
39
|
+
ResponseReasoningItemParam,
|
|
39
40
|
)
|
|
40
41
|
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
|
|
41
42
|
from openai.types.responses.response_reasoning_item import Summary
|
|
@@ -210,6 +211,12 @@ class Converter:
|
|
|
210
211
|
return cast(ResponseOutputMessageParam, item)
|
|
211
212
|
return None
|
|
212
213
|
|
|
214
|
+
@classmethod
|
|
215
|
+
def maybe_reasoning_message(cls, item: Any) -> ResponseReasoningItemParam | None:
|
|
216
|
+
if isinstance(item, dict) and item.get("type") == "reasoning":
|
|
217
|
+
return cast(ResponseReasoningItemParam, item)
|
|
218
|
+
return None
|
|
219
|
+
|
|
213
220
|
@classmethod
|
|
214
221
|
def extract_text_content(
|
|
215
222
|
cls, content: str | Iterable[ResponseInputContentParam]
|
|
@@ -459,7 +466,11 @@ class Converter:
|
|
|
459
466
|
f"Encountered an item_reference, which is not supported: {item_ref}"
|
|
460
467
|
)
|
|
461
468
|
|
|
462
|
-
# 7)
|
|
469
|
+
# 7) reasoning message => not handled
|
|
470
|
+
elif cls.maybe_reasoning_message(item):
|
|
471
|
+
pass
|
|
472
|
+
|
|
473
|
+
# 8) If we haven't recognized it => fail or ignore
|
|
463
474
|
else:
|
|
464
475
|
raise UserError(f"Unhandled item type or structure: {item}")
|
|
465
476
|
|
|
@@ -198,6 +198,7 @@ class ChatCmplStreamHandler:
|
|
|
198
198
|
is not None, # fixed 0 -> 0 or 1
|
|
199
199
|
type="response.output_text.delta",
|
|
200
200
|
sequence_number=sequence_number.get_and_increment(),
|
|
201
|
+
logprobs=[],
|
|
201
202
|
)
|
|
202
203
|
# Accumulate the text into the response part
|
|
203
204
|
state.text_content_index_and_output[1].text += delta.content
|
|
@@ -288,10 +289,11 @@ class ChatCmplStreamHandler:
|
|
|
288
289
|
function_call = state.function_calls[tc_delta.index]
|
|
289
290
|
|
|
290
291
|
# Start streaming as soon as we have function name and call_id
|
|
291
|
-
if (
|
|
292
|
-
|
|
293
|
-
function_call.
|
|
294
|
-
|
|
292
|
+
if (
|
|
293
|
+
not state.function_call_streaming[tc_delta.index]
|
|
294
|
+
and function_call.name
|
|
295
|
+
and function_call.call_id
|
|
296
|
+
):
|
|
295
297
|
# Calculate the output index for this function call
|
|
296
298
|
function_call_starting_index = 0
|
|
297
299
|
if state.reasoning_content_index_and_output:
|
|
@@ -308,9 +310,9 @@ class ChatCmplStreamHandler:
|
|
|
308
310
|
|
|
309
311
|
# Mark this function call as streaming and store its output index
|
|
310
312
|
state.function_call_streaming[tc_delta.index] = True
|
|
311
|
-
state.function_call_output_idx[
|
|
312
|
-
|
|
313
|
-
|
|
313
|
+
state.function_call_output_idx[tc_delta.index] = (
|
|
314
|
+
function_call_starting_index
|
|
315
|
+
)
|
|
314
316
|
|
|
315
317
|
# Send initial function call added event
|
|
316
318
|
yield ResponseOutputItemAddedEvent(
|
|
@@ -327,10 +329,11 @@ class ChatCmplStreamHandler:
|
|
|
327
329
|
)
|
|
328
330
|
|
|
329
331
|
# Stream arguments if we've started streaming this function call
|
|
330
|
-
if (
|
|
331
|
-
|
|
332
|
-
tc_function
|
|
333
|
-
|
|
332
|
+
if (
|
|
333
|
+
state.function_call_streaming.get(tc_delta.index, False)
|
|
334
|
+
and tc_function
|
|
335
|
+
and tc_function.arguments
|
|
336
|
+
):
|
|
334
337
|
output_index = state.function_call_output_idx[tc_delta.index]
|
|
335
338
|
yield ResponseFunctionCallArgumentsDeltaEvent(
|
|
336
339
|
delta=tc_function.arguments,
|
|
@@ -493,9 +496,9 @@ class ChatCmplStreamHandler:
|
|
|
493
496
|
final_response.output = outputs
|
|
494
497
|
final_response.usage = (
|
|
495
498
|
ResponseUsage(
|
|
496
|
-
input_tokens=usage.prompt_tokens,
|
|
497
|
-
output_tokens=usage.completion_tokens,
|
|
498
|
-
total_tokens=usage.total_tokens,
|
|
499
|
+
input_tokens=usage.prompt_tokens or 0,
|
|
500
|
+
output_tokens=usage.completion_tokens or 0,
|
|
501
|
+
total_tokens=usage.total_tokens or 0,
|
|
499
502
|
output_tokens_details=OutputTokensDetails(
|
|
500
503
|
reasoning_tokens=usage.completion_tokens_details.reasoning_tokens
|
|
501
504
|
if usage.completion_tokens_details
|
agents/realtime/__init__.py
CHANGED
|
@@ -47,6 +47,8 @@ from .model import (
|
|
|
47
47
|
RealtimeModel,
|
|
48
48
|
RealtimeModelConfig,
|
|
49
49
|
RealtimeModelListener,
|
|
50
|
+
RealtimePlaybackState,
|
|
51
|
+
RealtimePlaybackTracker,
|
|
50
52
|
)
|
|
51
53
|
from .model_events import (
|
|
52
54
|
RealtimeConnectionStatus,
|
|
@@ -139,6 +141,8 @@ __all__ = [
|
|
|
139
141
|
"RealtimeModel",
|
|
140
142
|
"RealtimeModelConfig",
|
|
141
143
|
"RealtimeModelListener",
|
|
144
|
+
"RealtimePlaybackTracker",
|
|
145
|
+
"RealtimePlaybackState",
|
|
142
146
|
# Model Events
|
|
143
147
|
"RealtimeConnectionStatus",
|
|
144
148
|
"RealtimeModelAudioDoneEvent",
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
from ._util import calculate_audio_length_ms
|
|
7
|
+
from .config import RealtimeAudioFormat
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
|
|
11
|
+
class ModelAudioState:
|
|
12
|
+
initial_received_time: datetime
|
|
13
|
+
audio_length_ms: float
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ModelAudioTracker:
|
|
17
|
+
def __init__(self) -> None:
|
|
18
|
+
# (item_id, item_content_index) -> ModelAudioState
|
|
19
|
+
self._states: dict[tuple[str, int], ModelAudioState] = {}
|
|
20
|
+
self._last_audio_item: tuple[str, int] | None = None
|
|
21
|
+
|
|
22
|
+
def set_audio_format(self, format: RealtimeAudioFormat) -> None:
|
|
23
|
+
"""Called when the model wants to set the audio format."""
|
|
24
|
+
self._format = format
|
|
25
|
+
|
|
26
|
+
def on_audio_delta(self, item_id: str, item_content_index: int, audio_bytes: bytes) -> None:
|
|
27
|
+
"""Called when an audio delta is received from the model."""
|
|
28
|
+
ms = calculate_audio_length_ms(self._format, audio_bytes)
|
|
29
|
+
new_key = (item_id, item_content_index)
|
|
30
|
+
|
|
31
|
+
self._last_audio_item = new_key
|
|
32
|
+
if new_key not in self._states:
|
|
33
|
+
self._states[new_key] = ModelAudioState(datetime.now(), ms)
|
|
34
|
+
else:
|
|
35
|
+
self._states[new_key].audio_length_ms += ms
|
|
36
|
+
|
|
37
|
+
def on_interrupted(self) -> None:
|
|
38
|
+
"""Called when the audio playback has been interrupted."""
|
|
39
|
+
self._last_audio_item = None
|
|
40
|
+
|
|
41
|
+
def get_state(self, item_id: str, item_content_index: int) -> ModelAudioState | None:
|
|
42
|
+
"""Called when the model wants to get the current playback state."""
|
|
43
|
+
return self._states.get((item_id, item_content_index))
|
|
44
|
+
|
|
45
|
+
def get_last_audio_item(self) -> tuple[str, int] | None:
|
|
46
|
+
"""Called when the model wants to get the last audio item ID and content index."""
|
|
47
|
+
return self._last_audio_item
|
agents/realtime/_util.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .config import RealtimeAudioFormat
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def calculate_audio_length_ms(format: RealtimeAudioFormat | None, audio_bytes: bytes) -> float:
|
|
7
|
+
if format and format.startswith("g711"):
|
|
8
|
+
return (len(audio_bytes) / 8000) * 1000
|
|
9
|
+
return (len(audio_bytes) / 24 / 2) * 1000
|
agents/realtime/events.py
CHANGED
|
@@ -115,6 +115,12 @@ class RealtimeAudioEnd:
|
|
|
115
115
|
info: RealtimeEventInfo
|
|
116
116
|
"""Common info for all events, such as the context."""
|
|
117
117
|
|
|
118
|
+
item_id: str
|
|
119
|
+
"""The ID of the item containing audio."""
|
|
120
|
+
|
|
121
|
+
content_index: int
|
|
122
|
+
"""The index of the audio content in `item.content`"""
|
|
123
|
+
|
|
118
124
|
type: Literal["audio_end"] = "audio_end"
|
|
119
125
|
|
|
120
126
|
|
|
@@ -125,6 +131,12 @@ class RealtimeAudio:
|
|
|
125
131
|
audio: RealtimeModelAudioEvent
|
|
126
132
|
"""The audio event from the model layer."""
|
|
127
133
|
|
|
134
|
+
item_id: str
|
|
135
|
+
"""The ID of the item containing audio."""
|
|
136
|
+
|
|
137
|
+
content_index: int
|
|
138
|
+
"""The index of the audio content in `item.content`"""
|
|
139
|
+
|
|
128
140
|
info: RealtimeEventInfo
|
|
129
141
|
"""Common info for all events, such as the context."""
|
|
130
142
|
|
|
@@ -140,6 +152,12 @@ class RealtimeAudioInterrupted:
|
|
|
140
152
|
info: RealtimeEventInfo
|
|
141
153
|
"""Common info for all events, such as the context."""
|
|
142
154
|
|
|
155
|
+
item_id: str
|
|
156
|
+
"""The ID of the item containing audio."""
|
|
157
|
+
|
|
158
|
+
content_index: int
|
|
159
|
+
"""The index of the audio content in `item.content`"""
|
|
160
|
+
|
|
143
161
|
type: Literal["audio_interrupted"] = "audio_interrupted"
|
|
144
162
|
|
|
145
163
|
|
agents/realtime/model.py
CHANGED
|
@@ -6,13 +6,95 @@ from typing import Callable
|
|
|
6
6
|
from typing_extensions import NotRequired, TypedDict
|
|
7
7
|
|
|
8
8
|
from ..util._types import MaybeAwaitable
|
|
9
|
+
from ._util import calculate_audio_length_ms
|
|
9
10
|
from .config import (
|
|
11
|
+
RealtimeAudioFormat,
|
|
10
12
|
RealtimeSessionModelSettings,
|
|
11
13
|
)
|
|
12
14
|
from .model_events import RealtimeModelEvent
|
|
13
15
|
from .model_inputs import RealtimeModelSendEvent
|
|
14
16
|
|
|
15
17
|
|
|
18
|
+
class RealtimePlaybackState(TypedDict):
|
|
19
|
+
current_item_id: str | None
|
|
20
|
+
"""The item ID of the current item being played."""
|
|
21
|
+
|
|
22
|
+
current_item_content_index: int | None
|
|
23
|
+
"""The index of the current item content being played."""
|
|
24
|
+
|
|
25
|
+
elapsed_ms: float | None
|
|
26
|
+
"""The number of milliseconds of audio that have been played."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RealtimePlaybackTracker:
|
|
30
|
+
"""If you have custom playback logic or expect that audio is played with delays or at different
|
|
31
|
+
speeds, create an instance of RealtimePlaybackTracker and pass it to the session. You are
|
|
32
|
+
responsible for tracking the audio playback progress and calling `on_play_bytes` or
|
|
33
|
+
`on_play_ms` when the user has played some audio."""
|
|
34
|
+
|
|
35
|
+
def __init__(self) -> None:
|
|
36
|
+
self._format: RealtimeAudioFormat | None = None
|
|
37
|
+
# (item_id, item_content_index)
|
|
38
|
+
self._current_item: tuple[str, int] | None = None
|
|
39
|
+
self._elapsed_ms: float | None = None
|
|
40
|
+
|
|
41
|
+
def on_play_bytes(self, item_id: str, item_content_index: int, bytes: bytes) -> None:
|
|
42
|
+
"""Called by you when you have played some audio.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
item_id: The item ID of the audio being played.
|
|
46
|
+
item_content_index: The index of the audio content in `item.content`
|
|
47
|
+
bytes: The audio bytes that have been fully played.
|
|
48
|
+
"""
|
|
49
|
+
ms = calculate_audio_length_ms(self._format, bytes)
|
|
50
|
+
self.on_play_ms(item_id, item_content_index, ms)
|
|
51
|
+
|
|
52
|
+
def on_play_ms(self, item_id: str, item_content_index: int, ms: float) -> None:
|
|
53
|
+
"""Called by you when you have played some audio.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
item_id: The item ID of the audio being played.
|
|
57
|
+
item_content_index: The index of the audio content in `item.content`
|
|
58
|
+
ms: The number of milliseconds of audio that have been played.
|
|
59
|
+
"""
|
|
60
|
+
if self._current_item != (item_id, item_content_index):
|
|
61
|
+
self._current_item = (item_id, item_content_index)
|
|
62
|
+
self._elapsed_ms = ms
|
|
63
|
+
else:
|
|
64
|
+
assert self._elapsed_ms is not None
|
|
65
|
+
self._elapsed_ms += ms
|
|
66
|
+
|
|
67
|
+
def on_interrupted(self) -> None:
|
|
68
|
+
"""Called by the model when the audio playback has been interrupted."""
|
|
69
|
+
self._current_item = None
|
|
70
|
+
self._elapsed_ms = None
|
|
71
|
+
|
|
72
|
+
def set_audio_format(self, format: RealtimeAudioFormat) -> None:
|
|
73
|
+
"""Will be called by the model to set the audio format.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
format: The audio format to use.
|
|
77
|
+
"""
|
|
78
|
+
self._format = format
|
|
79
|
+
|
|
80
|
+
def get_state(self) -> RealtimePlaybackState:
|
|
81
|
+
"""Will be called by the model to get the current playback state."""
|
|
82
|
+
if self._current_item is None:
|
|
83
|
+
return {
|
|
84
|
+
"current_item_id": None,
|
|
85
|
+
"current_item_content_index": None,
|
|
86
|
+
"elapsed_ms": None,
|
|
87
|
+
}
|
|
88
|
+
assert self._elapsed_ms is not None
|
|
89
|
+
|
|
90
|
+
item_id, item_content_index = self._current_item
|
|
91
|
+
return {
|
|
92
|
+
"current_item_id": item_id,
|
|
93
|
+
"current_item_content_index": item_content_index,
|
|
94
|
+
"elapsed_ms": self._elapsed_ms,
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
|
|
16
98
|
class RealtimeModelListener(abc.ABC):
|
|
17
99
|
"""A listener for realtime transport events."""
|
|
18
100
|
|
|
@@ -39,6 +121,18 @@ class RealtimeModelConfig(TypedDict):
|
|
|
39
121
|
initial_model_settings: NotRequired[RealtimeSessionModelSettings]
|
|
40
122
|
"""The initial model settings to use when connecting."""
|
|
41
123
|
|
|
124
|
+
playback_tracker: NotRequired[RealtimePlaybackTracker]
|
|
125
|
+
"""The playback tracker to use when tracking audio playback progress. If not set, the model will
|
|
126
|
+
use a default implementation that assumes audio is played immediately, at realtime speed.
|
|
127
|
+
|
|
128
|
+
A playback tracker is useful for interruptions. The model generates audio much faster than
|
|
129
|
+
realtime playback speed. So if there's an interruption, its useful for the model to know how
|
|
130
|
+
much of the audio has been played by the user. In low-latency scenarios, it's fine to assume
|
|
131
|
+
that audio is played back immediately at realtime speed. But in scenarios like phone calls or
|
|
132
|
+
other remote interactions, you can set a playback tracker that lets the model know when audio
|
|
133
|
+
is played to the user.
|
|
134
|
+
"""
|
|
135
|
+
|
|
42
136
|
|
|
43
137
|
class RealtimeModel(abc.ABC):
|
|
44
138
|
"""Interface for connecting to a realtime model and sending/receiving events."""
|
agents/realtime/model_events.py
CHANGED
|
@@ -40,6 +40,12 @@ class RealtimeModelAudioEvent:
|
|
|
40
40
|
data: bytes
|
|
41
41
|
response_id: str
|
|
42
42
|
|
|
43
|
+
item_id: str
|
|
44
|
+
"""The ID of the item containing audio."""
|
|
45
|
+
|
|
46
|
+
content_index: int
|
|
47
|
+
"""The index of the audio content in `item.content`"""
|
|
48
|
+
|
|
43
49
|
type: Literal["audio"] = "audio"
|
|
44
50
|
|
|
45
51
|
|
|
@@ -47,6 +53,12 @@ class RealtimeModelAudioEvent:
|
|
|
47
53
|
class RealtimeModelAudioInterruptedEvent:
|
|
48
54
|
"""Audio interrupted."""
|
|
49
55
|
|
|
56
|
+
item_id: str
|
|
57
|
+
"""The ID of the item containing audio."""
|
|
58
|
+
|
|
59
|
+
content_index: int
|
|
60
|
+
"""The index of the audio content in `item.content`"""
|
|
61
|
+
|
|
50
62
|
type: Literal["audio_interrupted"] = "audio_interrupted"
|
|
51
63
|
|
|
52
64
|
|
|
@@ -54,6 +66,12 @@ class RealtimeModelAudioInterruptedEvent:
|
|
|
54
66
|
class RealtimeModelAudioDoneEvent:
|
|
55
67
|
"""Audio done."""
|
|
56
68
|
|
|
69
|
+
item_id: str
|
|
70
|
+
"""The ID of the item containing audio."""
|
|
71
|
+
|
|
72
|
+
content_index: int
|
|
73
|
+
"""The index of the audio content in `item.content`"""
|
|
74
|
+
|
|
57
75
|
type: Literal["audio_done"] = "audio_done"
|
|
58
76
|
|
|
59
77
|
|
|
@@ -138,6 +156,15 @@ class RealtimeModelExceptionEvent:
|
|
|
138
156
|
type: Literal["exception"] = "exception"
|
|
139
157
|
|
|
140
158
|
|
|
159
|
+
@dataclass
|
|
160
|
+
class RealtimeModelRawServerEvent:
|
|
161
|
+
"""Raw events forwarded from the server."""
|
|
162
|
+
|
|
163
|
+
data: Any
|
|
164
|
+
|
|
165
|
+
type: Literal["raw_server_event"] = "raw_server_event"
|
|
166
|
+
|
|
167
|
+
|
|
141
168
|
# TODO (rm) Add usage events
|
|
142
169
|
|
|
143
170
|
|
|
@@ -156,4 +183,5 @@ RealtimeModelEvent: TypeAlias = Union[
|
|
|
156
183
|
RealtimeModelTurnEndedEvent,
|
|
157
184
|
RealtimeModelOtherEvent,
|
|
158
185
|
RealtimeModelExceptionEvent,
|
|
186
|
+
RealtimeModelRawServerEvent,
|
|
159
187
|
]
|