openai-agents 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of openai-agents was flagged as potentially problematic; see the registry page for details.

@@ -0,0 +1,292 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+
+from openai import AsyncStream
+from openai.types.chat import ChatCompletionChunk
+from openai.types.completion_usage import CompletionUsage
+from openai.types.responses import (
+    Response,
+    ResponseCompletedEvent,
+    ResponseContentPartAddedEvent,
+    ResponseContentPartDoneEvent,
+    ResponseCreatedEvent,
+    ResponseFunctionCallArgumentsDeltaEvent,
+    ResponseFunctionToolCall,
+    ResponseOutputItem,
+    ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseOutputMessage,
+    ResponseOutputRefusal,
+    ResponseOutputText,
+    ResponseRefusalDeltaEvent,
+    ResponseTextDeltaEvent,
+    ResponseUsage,
+)
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from ..items import TResponseStreamEvent
+from .fake_id import FAKE_RESPONSES_ID
+
+
+@dataclass
+class StreamingState:
+    started: bool = False
+    text_content_index_and_output: tuple[int, ResponseOutputText] | None = None
+    refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
+    function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
+
+
+class ChatCmplStreamHandler:
+    @classmethod
+    async def handle_stream(
+        cls,
+        response: Response,
+        stream: AsyncStream[ChatCompletionChunk],
+    ) -> AsyncIterator[TResponseStreamEvent]:
+        usage: CompletionUsage | None = None
+        state = StreamingState()
+
+        async for chunk in stream:
+            if not state.started:
+                state.started = True
+                yield ResponseCreatedEvent(
+                    response=response,
+                    type="response.created",
+                )
+
+            # This is always set by the OpenAI API, but not by others e.g. LiteLLM
+            usage = chunk.usage if hasattr(chunk, "usage") else None
+
+            if not chunk.choices or not chunk.choices[0].delta:
+                continue
+
+            delta = chunk.choices[0].delta
+
+            # Handle text
+            if delta.content:
+                if not state.text_content_index_and_output:
+                    # Initialize a content tracker for streaming text
+                    state.text_content_index_and_output = (
+                        0 if not state.refusal_content_index_and_output else 1,
+                        ResponseOutputText(
+                            text="",
+                            type="output_text",
+                            annotations=[],
+                        ),
+                    )
+                    # Start a new assistant message stream
+                    assistant_item = ResponseOutputMessage(
+                        id=FAKE_RESPONSES_ID,
+                        content=[],
+                        role="assistant",
+                        type="message",
+                        status="in_progress",
+                    )
+                    # Notify consumers of the start of a new output message + first content part
+                    yield ResponseOutputItemAddedEvent(
+                        item=assistant_item,
+                        output_index=0,
+                        type="response.output_item.added",
+                    )
+                    yield ResponseContentPartAddedEvent(
+                        content_index=state.text_content_index_and_output[0],
+                        item_id=FAKE_RESPONSES_ID,
+                        output_index=0,
+                        part=ResponseOutputText(
+                            text="",
+                            type="output_text",
+                            annotations=[],
+                        ),
+                        type="response.content_part.added",
+                    )
+                # Emit the delta for this segment of content
+                yield ResponseTextDeltaEvent(
+                    content_index=state.text_content_index_and_output[0],
+                    delta=delta.content,
+                    item_id=FAKE_RESPONSES_ID,
+                    output_index=0,
+                    type="response.output_text.delta",
+                )
+                # Accumulate the text into the response part
+                state.text_content_index_and_output[1].text += delta.content
+
+            # Handle refusals (model declines to answer)
+            # This is always set by the OpenAI API, but not by others e.g. LiteLLM
+            if hasattr(delta, "refusal") and delta.refusal:
+                if not state.refusal_content_index_and_output:
+                    # Initialize a content tracker for streaming refusal text
+                    state.refusal_content_index_and_output = (
+                        0 if not state.text_content_index_and_output else 1,
+                        ResponseOutputRefusal(refusal="", type="refusal"),
+                    )
+                    # Start a new assistant message if one doesn't exist yet (in-progress)
+                    assistant_item = ResponseOutputMessage(
+                        id=FAKE_RESPONSES_ID,
+                        content=[],
+                        role="assistant",
+                        type="message",
+                        status="in_progress",
+                    )
+                    # Notify downstream that assistant message + first content part are starting
+                    yield ResponseOutputItemAddedEvent(
+                        item=assistant_item,
+                        output_index=0,
+                        type="response.output_item.added",
+                    )
+                    yield ResponseContentPartAddedEvent(
+                        content_index=state.refusal_content_index_and_output[0],
+                        item_id=FAKE_RESPONSES_ID,
+                        output_index=0,
+                        part=ResponseOutputText(
+                            text="",
+                            type="output_text",
+                            annotations=[],
+                        ),
+                        type="response.content_part.added",
+                    )
+                # Emit the delta for this segment of refusal
+                yield ResponseRefusalDeltaEvent(
+                    content_index=state.refusal_content_index_and_output[0],
+                    delta=delta.refusal,
+                    item_id=FAKE_RESPONSES_ID,
+                    output_index=0,
+                    type="response.refusal.delta",
+                )
+                # Accumulate the refusal string in the output part
+                state.refusal_content_index_and_output[1].refusal += delta.refusal
+
+            # Handle tool calls
+            # Because we don't know the name of the function until the end of the stream, we'll
+            # save everything and yield events at the end
+            if delta.tool_calls:
+                for tc_delta in delta.tool_calls:
+                    if tc_delta.index not in state.function_calls:
+                        state.function_calls[tc_delta.index] = ResponseFunctionToolCall(
+                            id=FAKE_RESPONSES_ID,
+                            arguments="",
+                            name="",
+                            type="function_call",
+                            call_id="",
+                        )
+                    tc_function = tc_delta.function
+
+                    state.function_calls[tc_delta.index].arguments += (
+                        tc_function.arguments if tc_function else ""
+                    ) or ""
+                    state.function_calls[tc_delta.index].name += (
+                        tc_function.name if tc_function else ""
+                    ) or ""
+                    state.function_calls[tc_delta.index].call_id += tc_delta.id or ""
+
+        function_call_starting_index = 0
+        if state.text_content_index_and_output:
+            function_call_starting_index += 1
+            # Send end event for this content part
+            yield ResponseContentPartDoneEvent(
+                content_index=state.text_content_index_and_output[0],
+                item_id=FAKE_RESPONSES_ID,
+                output_index=0,
+                part=state.text_content_index_and_output[1],
+                type="response.content_part.done",
+            )
+
+        if state.refusal_content_index_and_output:
+            function_call_starting_index += 1
+            # Send end event for this content part
+            yield ResponseContentPartDoneEvent(
+                content_index=state.refusal_content_index_and_output[0],
+                item_id=FAKE_RESPONSES_ID,
+                output_index=0,
+                part=state.refusal_content_index_and_output[1],
+                type="response.content_part.done",
+            )
+
+        # Actually send events for the function calls
+        for function_call in state.function_calls.values():
+            # First, a ResponseOutputItemAdded for the function call
+            yield ResponseOutputItemAddedEvent(
+                item=ResponseFunctionToolCall(
+                    id=FAKE_RESPONSES_ID,
+                    call_id=function_call.call_id,
+                    arguments=function_call.arguments,
+                    name=function_call.name,
+                    type="function_call",
+                ),
+                output_index=function_call_starting_index,
+                type="response.output_item.added",
+            )
+            # Then, yield the args
+            yield ResponseFunctionCallArgumentsDeltaEvent(
+                delta=function_call.arguments,
+                item_id=FAKE_RESPONSES_ID,
+                output_index=function_call_starting_index,
+                type="response.function_call_arguments.delta",
+            )
+            # Finally, the ResponseOutputItemDone
+            yield ResponseOutputItemDoneEvent(
+                item=ResponseFunctionToolCall(
+                    id=FAKE_RESPONSES_ID,
+                    call_id=function_call.call_id,
+                    arguments=function_call.arguments,
+                    name=function_call.name,
+                    type="function_call",
+                ),
+                output_index=function_call_starting_index,
+                type="response.output_item.done",
+            )
+
+        # Finally, send the Response completed event
+        outputs: list[ResponseOutputItem] = []
+        if state.text_content_index_and_output or state.refusal_content_index_and_output:
+            assistant_msg = ResponseOutputMessage(
+                id=FAKE_RESPONSES_ID,
+                content=[],
+                role="assistant",
+                type="message",
+                status="completed",
+            )
+            if state.text_content_index_and_output:
+                assistant_msg.content.append(state.text_content_index_and_output[1])
+            if state.refusal_content_index_and_output:
+                assistant_msg.content.append(state.refusal_content_index_and_output[1])
+            outputs.append(assistant_msg)
+
+            # send a ResponseOutputItemDone for the assistant message
+            yield ResponseOutputItemDoneEvent(
+                item=assistant_msg,
+                output_index=0,
+                type="response.output_item.done",
+            )
+
+        for function_call in state.function_calls.values():
+            outputs.append(function_call)
+
+        final_response = response.model_copy()
+        final_response.output = outputs
+        final_response.usage = (
+            ResponseUsage(
+                input_tokens=usage.prompt_tokens,
+                output_tokens=usage.completion_tokens,
+                total_tokens=usage.total_tokens,
+                output_tokens_details=OutputTokensDetails(
+                    reasoning_tokens=usage.completion_tokens_details.reasoning_tokens
+                    if usage.completion_tokens_details
+                    and usage.completion_tokens_details.reasoning_tokens
+                    else 0
+                ),
+                input_tokens_details=InputTokensDetails(
+                    cached_tokens=usage.prompt_tokens_details.cached_tokens
+                    if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens
+                    else 0
+                ),
+            )
+            if usage
+            else None
+        )
+
+        yield ResponseCompletedEvent(
+            response=final_response,
+            type="response.completed",
+        )
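
The hunk above adds a stream handler that replays a raw Chat Completions stream as Responses-style events: a response.created first, text and refusal deltas as they arrive, buffered function-call events once the stream ends, and a final response.completed carrying the aggregated output and usage. For orientation only (not part of the diff), a minimal consumer could look like the sketch below; the module path agents.models.chatcmpl_stream_handler is assumed, the Response shell mirrors the minimal construction the SDK performs internally before delegating to the handler, and the model name and prompt are placeholders.

import asyncio
import time

from openai import AsyncOpenAI
from openai.types.responses import Response

from agents.models.chatcmpl_stream_handler import ChatCmplStreamHandler  # assumed module path
from agents.models.fake_id import FAKE_RESPONSES_ID


async def main() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},  # so the final chunk carries usage
    )
    # Minimal Response "shell"; the handler fills in output and usage at the end.
    shell = Response(
        id=FAKE_RESPONSES_ID,
        created_at=time.time(),
        model="gpt-4o-mini",
        object="response",
        output=[],
        tool_choice="none",
        tools=[],
        top_p=None,
        temperature=None,
        parallel_tool_calls=False,
    )
    async for event in ChatCmplStreamHandler.handle_stream(shell, stream):
        if event.type == "response.output_text.delta":
            print(event.delta, end="", flush=True)
        elif event.type == "response.completed":
            print("\nusage:", event.response.usage)


if __name__ == "__main__":
    asyncio.run(main())

In normal use the SDK builds this shell and drives the handler itself whenever a chat-completions model is run in streaming mode.
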
@@ -5,7 +5,7 @@ import enum
 from collections.abc import AsyncIterator
 from typing import TYPE_CHECKING
 
-from ..agent_output import AgentOutputSchema
+from ..agent_output import AgentOutputSchemaBase
 from ..handoffs import Handoff
 from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
 from ..tool import Tool
@@ -41,7 +41,7 @@ class Model(abc.ABC):
         input: str | list[TResponseInputItem],
         model_settings: ModelSettings,
         tools: list[Tool],
-        output_schema: AgentOutputSchema | None,
+        output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
         *,
@@ -72,7 +72,7 @@ class Model(abc.ABC):
         input: str | list[TResponseInputItem],
         model_settings: ModelSettings,
         tools: list[Tool],
-        output_schema: AgentOutputSchema | None,
+        output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
         *,
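
The interface hunks above only widen Model.get_response and Model.stream_response to accept any AgentOutputSchemaBase instead of the concrete AgentOutputSchema, so custom structured-output schemas can be passed through. As a rough, hedged illustration (not part of the diff; the abstract method set below is assumed from the 0.0.13-era agents.agent_output module and is not visible in these hunks), a hand-rolled schema might look like this:

import json
from typing import Any

from agents.agent_output import AgentOutputSchemaBase  # import path shown in the hunk above


class WeatherAnswerSchema(AgentOutputSchemaBase):
    """Hypothetical custom schema passed as output_schema to a Model."""

    def is_plain_text(self) -> bool:
        return False  # we expect structured JSON, not free text

    def name(self) -> str:
        return "WeatherAnswer"

    def json_schema(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}},
            "required": ["city", "temp_c"],
            "additionalProperties": False,
        }

    def is_strict_json_schema(self) -> bool:
        return True

    def validate_json(self, json_str: str) -> Any:
        return json.loads(json_str)  # minimal validation for the sketch
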
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+from openai import AsyncOpenAI
+
+from ..exceptions import UserError
+from .interface import Model, ModelProvider
+from .openai_provider import OpenAIProvider
+
+
+class MultiProviderMap:
+    """A map of model name prefixes to ModelProviders."""
+
+    def __init__(self):
+        self._mapping: dict[str, ModelProvider] = {}
+
+    def has_prefix(self, prefix: str) -> bool:
+        """Returns True if the given prefix is in the mapping."""
+        return prefix in self._mapping
+
+    def get_mapping(self) -> dict[str, ModelProvider]:
+        """Returns a copy of the current prefix -> ModelProvider mapping."""
+        return self._mapping.copy()
+
+    def set_mapping(self, mapping: dict[str, ModelProvider]):
+        """Overwrites the current mapping with a new one."""
+        self._mapping = mapping
+
+    def get_provider(self, prefix: str) -> ModelProvider | None:
+        """Returns the ModelProvider for the given prefix.
+
+        Args:
+            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
+        """
+        return self._mapping.get(prefix)
+
+    def add_provider(self, prefix: str, provider: ModelProvider):
+        """Adds a new prefix -> ModelProvider mapping.
+
+        Args:
+            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
+            provider: The ModelProvider to use for the given prefix.
+        """
+        self._mapping[prefix] = provider
+
+    def remove_provider(self, prefix: str):
+        """Removes the mapping for the given prefix.
+
+        Args:
+            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
+        """
+        del self._mapping[prefix]
+
+
+class MultiProvider(ModelProvider):
+    """This ModelProvider maps to a Model based on the prefix of the model name. By default, the
+    mapping is:
+    - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
+    - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"
+
+    You can override or customize this mapping.
+    """
+
+    def __init__(
+        self,
+        *,
+        provider_map: MultiProviderMap | None = None,
+        openai_api_key: str | None = None,
+        openai_base_url: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        openai_organization: str | None = None,
+        openai_project: str | None = None,
+        openai_use_responses: bool | None = None,
+    ) -> None:
+        """Create a new OpenAI provider.
+
+        Args:
+            provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
+                we will use a default mapping. See the documentation for this class to see the
+                default mapping.
+            openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
+                the default API key.
+            openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
+                use the default base URL.
+            openai_client: An optional OpenAI client to use. If not provided, we will create a new
+                OpenAI client using the api_key and base_url.
+            openai_organization: The organization to use for the OpenAI provider.
+            openai_project: The project to use for the OpenAI provider.
+            openai_use_responses: Whether to use the OpenAI responses API.
+        """
+        self.provider_map = provider_map
+        self.openai_provider = OpenAIProvider(
+            api_key=openai_api_key,
+            base_url=openai_base_url,
+            openai_client=openai_client,
+            organization=openai_organization,
+            project=openai_project,
+            use_responses=openai_use_responses,
+        )
+
+        self._fallback_providers: dict[str, ModelProvider] = {}
+
+    def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
+        if model_name is None:
+            return None, None
+        elif "/" in model_name:
+            prefix, model_name = model_name.split("/", 1)
+            return prefix, model_name
+        else:
+            return None, model_name
+
+    def _create_fallback_provider(self, prefix: str) -> ModelProvider:
+        if prefix == "litellm":
+            from ..extensions.models.litellm_provider import LitellmProvider
+
+            return LitellmProvider()
+        else:
+            raise UserError(f"Unknown prefix: {prefix}")
+
+    def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
+        if prefix is None or prefix == "openai":
+            return self.openai_provider
+        elif prefix in self._fallback_providers:
+            return self._fallback_providers[prefix]
+        else:
+            self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
+            return self._fallback_providers[prefix]
+
+    def get_model(self, model_name: str | None) -> Model:
+        """Returns a Model based on the model name. The model name can have a prefix, ending with
+        a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
+        the OpenAI provider.
+
+        Args:
+            model_name: The name of the model to get.
+
+        Returns:
+            A Model.
+        """
+        prefix, model_name = self._get_prefix_and_model_name(model_name)
+
+        if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
+            return provider.get_model(model_name)
+        else:
+            return self._get_fallback_provider(prefix).get_model(model_name)
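
Finally, a short usage sketch for the new MultiProvider (again not part of the diff): the class and method names come from the hunk above, while the module path agents.models.multi_provider, the "mirror" prefix, and the model names are assumptions or placeholders. The OpenAI path expects OPENAI_API_KEY to be set, and the "litellm/" path expects the optional litellm extra to be installed.

from agents.models.multi_provider import MultiProvider, MultiProviderMap  # assumed module path

provider = MultiProvider()

# A bare name (or an "openai/" prefix) falls back to the built-in OpenAIProvider.
gpt = provider.get_model("gpt-4.1")

# A "litellm/" prefix lazily constructs a LitellmProvider the first time it is seen.
claude = provider.get_model("litellm/anthropic/claude-3-5-sonnet-20240620")

# Custom prefixes are resolved through an explicit MultiProviderMap.
custom_map = MultiProviderMap()
routed = MultiProvider(provider_map=custom_map)
# Reuse the built-in OpenAI provider under a made-up "mirror/" prefix.
custom_map.add_provider("mirror", routed.openai_provider)
mirrored = routed.get_model("mirror/gpt-4.1")

# Any other unmapped prefix raises UserError from _create_fallback_provider.

Importing LitellmProvider lazily inside _create_fallback_provider keeps litellm an optional dependency: only the prefixes you actually use are ever resolved.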