openai-agents 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents has been flagged as potentially problematic; consult the registry's advisory page for details.

@@ -0,0 +1,290 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import AsyncIterator
4
+ from dataclasses import dataclass, field
5
+
6
+ from openai import AsyncStream
7
+ from openai.types.chat import ChatCompletionChunk
8
+ from openai.types.completion_usage import CompletionUsage
9
+ from openai.types.responses import (
10
+ Response,
11
+ ResponseCompletedEvent,
12
+ ResponseContentPartAddedEvent,
13
+ ResponseContentPartDoneEvent,
14
+ ResponseCreatedEvent,
15
+ ResponseFunctionCallArgumentsDeltaEvent,
16
+ ResponseFunctionToolCall,
17
+ ResponseOutputItem,
18
+ ResponseOutputItemAddedEvent,
19
+ ResponseOutputItemDoneEvent,
20
+ ResponseOutputMessage,
21
+ ResponseOutputRefusal,
22
+ ResponseOutputText,
23
+ ResponseRefusalDeltaEvent,
24
+ ResponseTextDeltaEvent,
25
+ ResponseUsage,
26
+ )
27
+ from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
28
+
29
+ from ..items import TResponseStreamEvent
30
+ from .fake_id import FAKE_RESPONSES_ID
31
+
32
+
33
@dataclass
class StreamingState:
    """Mutable accumulator used while translating a ChatCompletion chunk stream
    into Responses-API stream events."""

    # True once the first chunk has been seen (i.e. response.created was emitted).
    started: bool = False
    # (content_index, accumulated ResponseOutputText) for streamed assistant text, if any.
    text_content_index_and_output: tuple[int, ResponseOutputText] | None = None
    # (content_index, accumulated ResponseOutputRefusal) for a streamed refusal, if any.
    refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
    # Tool calls keyed by their chunk `index`; name/arguments/call_id arrive incrementally.
    function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
39
+
40
+
41
class ChatCmplStreamHandler:
    """Translates a streamed ChatCompletion into the equivalent Responses-API
    event stream (response.created ... response.completed)."""

    @classmethod
    async def handle_stream(
        cls,
        response: Response,
        stream: AsyncStream[ChatCompletionChunk],
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Yield Responses-API stream events translated from `stream`.

        Args:
            response: Base Response; echoed in response.created and used (via
                model_copy) as the skeleton of the final response.completed event.
            stream: The raw ChatCompletionChunk stream from the chat completions API.

        Yields:
            TResponseStreamEvent items mirroring what the Responses API would emit.
        """
        usage: CompletionUsage | None = None
        state = StreamingState()

        async for chunk in stream:
            if not state.started:
                state.started = True
                yield ResponseCreatedEvent(
                    response=response,
                    type="response.created",
                )

            # Usage (when requested) arrives on the last chunk; keep the latest value.
            usage = chunk.usage

            if not chunk.choices or not chunk.choices[0].delta:
                continue

            delta = chunk.choices[0].delta

            # Handle text
            if delta.content:
                if not state.text_content_index_and_output:
                    # Initialize a content tracker for streaming text. Index is 1
                    # if a refusal part already claimed index 0.
                    state.text_content_index_and_output = (
                        0 if not state.refusal_content_index_and_output else 1,
                        ResponseOutputText(
                            text="",
                            type="output_text",
                            annotations=[],
                        ),
                    )
                    # Start a new assistant message stream
                    assistant_item = ResponseOutputMessage(
                        id=FAKE_RESPONSES_ID,
                        content=[],
                        role="assistant",
                        type="message",
                        status="in_progress",
                    )
                    # Notify consumers of the start of a new output message + first content part
                    yield ResponseOutputItemAddedEvent(
                        item=assistant_item,
                        output_index=0,
                        type="response.output_item.added",
                    )
                    yield ResponseContentPartAddedEvent(
                        content_index=state.text_content_index_and_output[0],
                        item_id=FAKE_RESPONSES_ID,
                        output_index=0,
                        part=ResponseOutputText(
                            text="",
                            type="output_text",
                            annotations=[],
                        ),
                        type="response.content_part.added",
                    )
                # Emit the delta for this segment of content
                yield ResponseTextDeltaEvent(
                    content_index=state.text_content_index_and_output[0],
                    delta=delta.content,
                    item_id=FAKE_RESPONSES_ID,
                    output_index=0,
                    type="response.output_text.delta",
                )
                # Accumulate the text into the response part
                state.text_content_index_and_output[1].text += delta.content

            # Handle refusals (model declines to answer)
            if delta.refusal:
                if not state.refusal_content_index_and_output:
                    # Initialize a content tracker for streaming refusal text
                    state.refusal_content_index_and_output = (
                        0 if not state.text_content_index_and_output else 1,
                        ResponseOutputRefusal(refusal="", type="refusal"),
                    )
                    # Start a new assistant message if one doesn't exist yet (in-progress)
                    assistant_item = ResponseOutputMessage(
                        id=FAKE_RESPONSES_ID,
                        content=[],
                        role="assistant",
                        type="message",
                        status="in_progress",
                    )
                    # Notify downstream that assistant message + first content part are starting
                    yield ResponseOutputItemAddedEvent(
                        item=assistant_item,
                        output_index=0,
                        type="response.output_item.added",
                    )
                    # BUGFIX: announce a refusal part here, matching the
                    # ResponseOutputRefusal accumulated in state and later sent in
                    # the response.content_part.done event. Previously this event
                    # incorrectly advertised an empty ResponseOutputText part.
                    yield ResponseContentPartAddedEvent(
                        content_index=state.refusal_content_index_and_output[0],
                        item_id=FAKE_RESPONSES_ID,
                        output_index=0,
                        part=ResponseOutputRefusal(refusal="", type="refusal"),
                        type="response.content_part.added",
                    )
                # Emit the delta for this segment of refusal
                yield ResponseRefusalDeltaEvent(
                    content_index=state.refusal_content_index_and_output[0],
                    delta=delta.refusal,
                    item_id=FAKE_RESPONSES_ID,
                    output_index=0,
                    type="response.refusal.delta",
                )
                # Accumulate the refusal string in the output part
                state.refusal_content_index_and_output[1].refusal += delta.refusal

            # Handle tool calls
            # Because we don't know the name of the function until the end of the stream, we'll
            # save everything and yield events at the end
            if delta.tool_calls:
                for tc_delta in delta.tool_calls:
                    if tc_delta.index not in state.function_calls:
                        state.function_calls[tc_delta.index] = ResponseFunctionToolCall(
                            id=FAKE_RESPONSES_ID,
                            arguments="",
                            name="",
                            type="function_call",
                            call_id="",
                        )
                    tc_function = tc_delta.function

                    # Name/arguments/call_id arrive in fragments; append whatever is present.
                    state.function_calls[tc_delta.index].arguments += (
                        tc_function.arguments if tc_function else ""
                    ) or ""
                    state.function_calls[tc_delta.index].name += (
                        tc_function.name if tc_function else ""
                    ) or ""
                    state.function_calls[tc_delta.index].call_id += tc_delta.id or ""

        # Stream exhausted: close any open content parts, then emit the buffered
        # function-call events and the final response.completed event.
        function_call_starting_index = 0
        if state.text_content_index_and_output:
            function_call_starting_index += 1
            # Send end event for this content part
            yield ResponseContentPartDoneEvent(
                content_index=state.text_content_index_and_output[0],
                item_id=FAKE_RESPONSES_ID,
                output_index=0,
                part=state.text_content_index_and_output[1],
                type="response.content_part.done",
            )

        if state.refusal_content_index_and_output:
            function_call_starting_index += 1
            # Send end event for this content part
            yield ResponseContentPartDoneEvent(
                content_index=state.refusal_content_index_and_output[0],
                item_id=FAKE_RESPONSES_ID,
                output_index=0,
                part=state.refusal_content_index_and_output[1],
                type="response.content_part.done",
            )

        # Actually send events for the function calls.
        # NOTE(review): all function calls below share the same output_index
        # (function_call_starting_index is never incremented per call) — confirm
        # whether consumers require distinct output indices for multiple tool calls.
        for function_call in state.function_calls.values():
            # First, a ResponseOutputItemAdded for the function call
            yield ResponseOutputItemAddedEvent(
                item=ResponseFunctionToolCall(
                    id=FAKE_RESPONSES_ID,
                    call_id=function_call.call_id,
                    arguments=function_call.arguments,
                    name=function_call.name,
                    type="function_call",
                ),
                output_index=function_call_starting_index,
                type="response.output_item.added",
            )
            # Then, yield the args
            yield ResponseFunctionCallArgumentsDeltaEvent(
                delta=function_call.arguments,
                item_id=FAKE_RESPONSES_ID,
                output_index=function_call_starting_index,
                type="response.function_call_arguments.delta",
            )
            # Finally, the ResponseOutputItemDone
            yield ResponseOutputItemDoneEvent(
                item=ResponseFunctionToolCall(
                    id=FAKE_RESPONSES_ID,
                    call_id=function_call.call_id,
                    arguments=function_call.arguments,
                    name=function_call.name,
                    type="function_call",
                ),
                output_index=function_call_starting_index,
                type="response.output_item.done",
            )

        # Finally, send the Response completed event
        outputs: list[ResponseOutputItem] = []
        if state.text_content_index_and_output or state.refusal_content_index_and_output:
            assistant_msg = ResponseOutputMessage(
                id=FAKE_RESPONSES_ID,
                content=[],
                role="assistant",
                type="message",
                status="completed",
            )
            if state.text_content_index_and_output:
                assistant_msg.content.append(state.text_content_index_and_output[1])
            if state.refusal_content_index_and_output:
                assistant_msg.content.append(state.refusal_content_index_and_output[1])
            outputs.append(assistant_msg)

            # send a ResponseOutputItemDone for the assistant message
            yield ResponseOutputItemDoneEvent(
                item=assistant_msg,
                output_index=0,
                type="response.output_item.done",
            )

        for function_call in state.function_calls.values():
            outputs.append(function_call)

        final_response = response.model_copy()
        final_response.output = outputs
        # Map chat-completions usage onto Responses usage; missing detail fields fall back to 0.
        final_response.usage = (
            ResponseUsage(
                input_tokens=usage.prompt_tokens,
                output_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                output_tokens_details=OutputTokensDetails(
                    reasoning_tokens=usage.completion_tokens_details.reasoning_tokens
                    if usage.completion_tokens_details
                    and usage.completion_tokens_details.reasoning_tokens
                    else 0
                ),
                input_tokens_details=InputTokensDetails(
                    cached_tokens=usage.prompt_tokens_details.cached_tokens
                    if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens
                    else 0
                ),
            )
            if usage
            else None
        )

        yield ResponseCompletedEvent(
            response=final_response,
            type="response.completed",
        )
@@ -5,7 +5,7 @@ import enum
5
5
  from collections.abc import AsyncIterator
6
6
  from typing import TYPE_CHECKING
7
7
 
8
- from ..agent_output import AgentOutputSchema
8
+ from ..agent_output import AgentOutputSchemaBase
9
9
  from ..handoffs import Handoff
10
10
  from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
11
11
  from ..tool import Tool
@@ -41,7 +41,7 @@ class Model(abc.ABC):
41
41
  input: str | list[TResponseInputItem],
42
42
  model_settings: ModelSettings,
43
43
  tools: list[Tool],
44
- output_schema: AgentOutputSchema | None,
44
+ output_schema: AgentOutputSchemaBase | None,
45
45
  handoffs: list[Handoff],
46
46
  tracing: ModelTracing,
47
47
  *,
@@ -72,7 +72,7 @@ class Model(abc.ABC):
72
72
  input: str | list[TResponseInputItem],
73
73
  model_settings: ModelSettings,
74
74
  tools: list[Tool],
75
- output_schema: AgentOutputSchema | None,
75
+ output_schema: AgentOutputSchemaBase | None,
76
76
  handoffs: list[Handoff],
77
77
  tracing: ModelTracing,
78
78
  *,
@@ -0,0 +1,144 @@
1
+ from __future__ import annotations
2
+
3
+ from openai import AsyncOpenAI
4
+
5
+ from ..exceptions import UserError
6
+ from .interface import Model, ModelProvider
7
+ from .openai_provider import OpenAIProvider
8
+
9
+
10
class MultiProviderMap:
    """Registry that associates model-name prefixes with ModelProviders."""

    def __init__(self):
        # Internal prefix -> provider registry.
        self._mapping: dict[str, ModelProvider] = {}

    def has_prefix(self, prefix: str) -> bool:
        """Whether a provider has been registered for the given prefix."""
        return prefix in self._mapping

    def get_mapping(self) -> dict[str, ModelProvider]:
        """Returns a shallow-copied snapshot of the prefix -> ModelProvider mapping."""
        return dict(self._mapping)

    def set_mapping(self, mapping: dict[str, ModelProvider]):
        """Replaces the entire mapping with the one provided."""
        self._mapping = mapping

    def get_provider(self, prefix: str) -> ModelProvider | None:
        """Looks up the ModelProvider registered for a prefix, or None if absent.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
        """
        return self._mapping.get(prefix)

    def add_provider(self, prefix: str, provider: ModelProvider):
        """Registers a provider under a prefix, replacing any existing entry.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
            provider: The ModelProvider to use for the given prefix.
        """
        self._mapping[prefix] = provider

    def remove_provider(self, prefix: str):
        """Deregisters the provider for a prefix (KeyError if not registered).

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
        """
        self._mapping.pop(prefix)
52
+
53
+
54
class MultiProvider(ModelProvider):
    """This ModelProvider maps to a Model based on the prefix of the model name. By default, the
    mapping is:
    - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
    - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"

    You can override or customize this mapping.
    """

    def __init__(
        self,
        *,
        provider_map: MultiProviderMap | None = None,
        openai_api_key: str | None = None,
        openai_base_url: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        openai_organization: str | None = None,
        openai_project: str | None = None,
        openai_use_responses: bool | None = None,
    ) -> None:
        """Creates a new MultiProvider.

        Args:
            provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
                we will use a default mapping. See the documentation for this class to see the
                default mapping.
            openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
                the default API key.
            openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
                use the default base URL.
            openai_client: An optional OpenAI client to use. If not provided, we will create a new
                OpenAI client using the api_key and base_url.
            openai_organization: The organization to use for the OpenAI provider.
            openai_project: The project to use for the OpenAI provider.
            openai_use_responses: Whether to use the OpenAI responses API.
        """
        self.provider_map = provider_map
        self.openai_provider = OpenAIProvider(
            api_key=openai_api_key,
            base_url=openai_base_url,
            openai_client=openai_client,
            organization=openai_organization,
            project=openai_project,
            use_responses=openai_use_responses,
        )

        # Built-in providers for known prefixes, created lazily on first use.
        self._fallback_providers: dict[str, ModelProvider] = {}

    def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
        """Splits "prefix/rest" at the first slash; no slash means no prefix."""
        if model_name is None:
            return None, None
        head, sep, tail = model_name.partition("/")
        if sep:
            return head, tail
        return None, model_name

    def _create_fallback_provider(self, prefix: str) -> ModelProvider:
        """Instantiates the built-in provider for a known prefix; UserError otherwise."""
        if prefix != "litellm":
            raise UserError(f"Unknown prefix: {prefix}")
        # Imported lazily so the litellm extension is only required when actually used.
        from ..extensions.models.litellm_provider import LitellmProvider

        return LitellmProvider()

    def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
        """Returns the built-in provider for a prefix, creating and caching it if needed."""
        if prefix is None or prefix == "openai":
            return self.openai_provider
        if prefix not in self._fallback_providers:
            self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
        return self._fallback_providers[prefix]

    def get_model(self, model_name: str | None) -> Model:
        """Returns a Model based on the model name. The model name can have a prefix, ending with
        a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
        the OpenAI provider.

        Args:
            model_name: The name of the model to get.

        Returns:
            A Model.
        """
        prefix, model_name = self._get_prefix_and_model_name(model_name)

        # A user-supplied provider_map takes precedence over the built-in fallbacks.
        if prefix and self.provider_map:
            custom_provider = self.provider_map.get_provider(prefix)
            if custom_provider:
                return custom_provider.get_model(model_name)
        return self._get_fallback_provider(prefix).get_model(model_name)
+ return self._get_fallback_provider(prefix).get_model(model_name)