openai-agents 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic.
- agents/__init__.py +2 -1
- agents/_run_impl.py +3 -3
- agents/agent.py +9 -2
- agents/agent_output.py +58 -8
- agents/extensions/models/__init__.py +0 -0
- agents/extensions/models/litellm_model.py +381 -0
- agents/extensions/models/litellm_provider.py +21 -0
- agents/extensions/visualization.py +1 -1
- agents/mcp/server.py +24 -5
- agents/model_settings.py +21 -2
- agents/models/chatcmpl_converter.py +466 -0
- agents/models/chatcmpl_helpers.py +37 -0
- agents/models/chatcmpl_stream_handler.py +292 -0
- agents/models/interface.py +3 -3
- agents/models/multi_provider.py +144 -0
- agents/models/openai_chatcompletions.py +35 -781
- agents/models/openai_responses.py +9 -9
- agents/result.py +24 -9
- agents/run.py +16 -13
- agents/tracing/processors.py +6 -0
- agents/tracing/span_data.py +1 -1
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.13.dist-info}/METADATA +5 -5
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.13.dist-info}/RECORD +25 -18
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.13.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.13.dist-info}/licenses/LICENSE +0 -0
agents/models/chatcmpl_stream_handler.py
ADDED
@@ -0,0 +1,292 @@
+from __future__ import annotations
+
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+
+from openai import AsyncStream
+from openai.types.chat import ChatCompletionChunk
+from openai.types.completion_usage import CompletionUsage
+from openai.types.responses import (
+    Response,
+    ResponseCompletedEvent,
+    ResponseContentPartAddedEvent,
+    ResponseContentPartDoneEvent,
+    ResponseCreatedEvent,
+    ResponseFunctionCallArgumentsDeltaEvent,
+    ResponseFunctionToolCall,
+    ResponseOutputItem,
+    ResponseOutputItemAddedEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseOutputMessage,
+    ResponseOutputRefusal,
+    ResponseOutputText,
+    ResponseRefusalDeltaEvent,
+    ResponseTextDeltaEvent,
+    ResponseUsage,
+)
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from ..items import TResponseStreamEvent
+from .fake_id import FAKE_RESPONSES_ID
+
+
+@dataclass
+class StreamingState:
+    started: bool = False
+    text_content_index_and_output: tuple[int, ResponseOutputText] | None = None
+    refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
+    function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
+
+
+class ChatCmplStreamHandler:
+    @classmethod
+    async def handle_stream(
+        cls,
+        response: Response,
+        stream: AsyncStream[ChatCompletionChunk],
+    ) -> AsyncIterator[TResponseStreamEvent]:
+        usage: CompletionUsage | None = None
+        state = StreamingState()
+
+        async for chunk in stream:
+            if not state.started:
+                state.started = True
+                yield ResponseCreatedEvent(
+                    response=response,
+                    type="response.created",
+                )
+
+            # This is always set by the OpenAI API, but not by others e.g. LiteLLM
+            usage = chunk.usage if hasattr(chunk, "usage") else None
+
+            if not chunk.choices or not chunk.choices[0].delta:
+                continue
+
+            delta = chunk.choices[0].delta
+
+            # Handle text
+            if delta.content:
+                if not state.text_content_index_and_output:
+                    # Initialize a content tracker for streaming text
+                    state.text_content_index_and_output = (
+                        0 if not state.refusal_content_index_and_output else 1,
+                        ResponseOutputText(
+                            text="",
+                            type="output_text",
+                            annotations=[],
+                        ),
+                    )
+                    # Start a new assistant message stream
+                    assistant_item = ResponseOutputMessage(
+                        id=FAKE_RESPONSES_ID,
+                        content=[],
+                        role="assistant",
+                        type="message",
+                        status="in_progress",
+                    )
+                    # Notify consumers of the start of a new output message + first content part
+                    yield ResponseOutputItemAddedEvent(
+                        item=assistant_item,
+                        output_index=0,
+                        type="response.output_item.added",
+                    )
+                    yield ResponseContentPartAddedEvent(
+                        content_index=state.text_content_index_and_output[0],
+                        item_id=FAKE_RESPONSES_ID,
+                        output_index=0,
+                        part=ResponseOutputText(
+                            text="",
+                            type="output_text",
+                            annotations=[],
+                        ),
+                        type="response.content_part.added",
+                    )
+                # Emit the delta for this segment of content
+                yield ResponseTextDeltaEvent(
+                    content_index=state.text_content_index_and_output[0],
+                    delta=delta.content,
+                    item_id=FAKE_RESPONSES_ID,
+                    output_index=0,
+                    type="response.output_text.delta",
+                )
+                # Accumulate the text into the response part
+                state.text_content_index_and_output[1].text += delta.content
+
+            # Handle refusals (model declines to answer)
+            # This is always set by the OpenAI API, but not by others e.g. LiteLLM
+            if hasattr(delta, "refusal") and delta.refusal:
+                if not state.refusal_content_index_and_output:
+                    # Initialize a content tracker for streaming refusal text
+                    state.refusal_content_index_and_output = (
+                        0 if not state.text_content_index_and_output else 1,
+                        ResponseOutputRefusal(refusal="", type="refusal"),
+                    )
+                    # Start a new assistant message if one doesn't exist yet (in-progress)
+                    assistant_item = ResponseOutputMessage(
+                        id=FAKE_RESPONSES_ID,
+                        content=[],
+                        role="assistant",
+                        type="message",
+                        status="in_progress",
+                    )
+                    # Notify downstream that assistant message + first content part are starting
+                    yield ResponseOutputItemAddedEvent(
+                        item=assistant_item,
+                        output_index=0,
+                        type="response.output_item.added",
+                    )
+                    yield ResponseContentPartAddedEvent(
+                        content_index=state.refusal_content_index_and_output[0],
+                        item_id=FAKE_RESPONSES_ID,
+                        output_index=0,
+                        part=ResponseOutputText(
+                            text="",
+                            type="output_text",
+                            annotations=[],
+                        ),
+                        type="response.content_part.added",
+                    )
+                # Emit the delta for this segment of refusal
+                yield ResponseRefusalDeltaEvent(
+                    content_index=state.refusal_content_index_and_output[0],
+                    delta=delta.refusal,
+                    item_id=FAKE_RESPONSES_ID,
+                    output_index=0,
+                    type="response.refusal.delta",
+                )
+                # Accumulate the refusal string in the output part
+                state.refusal_content_index_and_output[1].refusal += delta.refusal
+
+            # Handle tool calls
+            # Because we don't know the name of the function until the end of the stream, we'll
+            # save everything and yield events at the end
+            if delta.tool_calls:
+                for tc_delta in delta.tool_calls:
+                    if tc_delta.index not in state.function_calls:
+                        state.function_calls[tc_delta.index] = ResponseFunctionToolCall(
+                            id=FAKE_RESPONSES_ID,
+                            arguments="",
+                            name="",
+                            type="function_call",
+                            call_id="",
+                        )
+                    tc_function = tc_delta.function
+
+                    state.function_calls[tc_delta.index].arguments += (
+                        tc_function.arguments if tc_function else ""
+                    ) or ""
+                    state.function_calls[tc_delta.index].name += (
+                        tc_function.name if tc_function else ""
+                    ) or ""
+                    state.function_calls[tc_delta.index].call_id += tc_delta.id or ""
+
+        function_call_starting_index = 0
+        if state.text_content_index_and_output:
+            function_call_starting_index += 1
+            # Send end event for this content part
+            yield ResponseContentPartDoneEvent(
+                content_index=state.text_content_index_and_output[0],
+                item_id=FAKE_RESPONSES_ID,
+                output_index=0,
+                part=state.text_content_index_and_output[1],
+                type="response.content_part.done",
+            )
+
+        if state.refusal_content_index_and_output:
+            function_call_starting_index += 1
+            # Send end event for this content part
+            yield ResponseContentPartDoneEvent(
+                content_index=state.refusal_content_index_and_output[0],
+                item_id=FAKE_RESPONSES_ID,
+                output_index=0,
+                part=state.refusal_content_index_and_output[1],
+                type="response.content_part.done",
+            )
+
+        # Actually send events for the function calls
+        for function_call in state.function_calls.values():
+            # First, a ResponseOutputItemAdded for the function call
+            yield ResponseOutputItemAddedEvent(
+                item=ResponseFunctionToolCall(
+                    id=FAKE_RESPONSES_ID,
+                    call_id=function_call.call_id,
+                    arguments=function_call.arguments,
+                    name=function_call.name,
+                    type="function_call",
+                ),
+                output_index=function_call_starting_index,
+                type="response.output_item.added",
+            )
+            # Then, yield the args
+            yield ResponseFunctionCallArgumentsDeltaEvent(
+                delta=function_call.arguments,
+                item_id=FAKE_RESPONSES_ID,
+                output_index=function_call_starting_index,
+                type="response.function_call_arguments.delta",
+            )
+            # Finally, the ResponseOutputItemDone
+            yield ResponseOutputItemDoneEvent(
+                item=ResponseFunctionToolCall(
+                    id=FAKE_RESPONSES_ID,
+                    call_id=function_call.call_id,
+                    arguments=function_call.arguments,
+                    name=function_call.name,
+                    type="function_call",
+                ),
+                output_index=function_call_starting_index,
+                type="response.output_item.done",
+            )
+
+        # Finally, send the Response completed event
+        outputs: list[ResponseOutputItem] = []
+        if state.text_content_index_and_output or state.refusal_content_index_and_output:
+            assistant_msg = ResponseOutputMessage(
+                id=FAKE_RESPONSES_ID,
+                content=[],
+                role="assistant",
+                type="message",
+                status="completed",
+            )
+            if state.text_content_index_and_output:
+                assistant_msg.content.append(state.text_content_index_and_output[1])
+            if state.refusal_content_index_and_output:
+                assistant_msg.content.append(state.refusal_content_index_and_output[1])
+            outputs.append(assistant_msg)
+
+            # send a ResponseOutputItemDone for the assistant message
+            yield ResponseOutputItemDoneEvent(
+                item=assistant_msg,
+                output_index=0,
+                type="response.output_item.done",
+            )
+
+        for function_call in state.function_calls.values():
+            outputs.append(function_call)
+
+        final_response = response.model_copy()
+        final_response.output = outputs
+        final_response.usage = (
+            ResponseUsage(
+                input_tokens=usage.prompt_tokens,
+                output_tokens=usage.completion_tokens,
+                total_tokens=usage.total_tokens,
+                output_tokens_details=OutputTokensDetails(
+                    reasoning_tokens=usage.completion_tokens_details.reasoning_tokens
+                    if usage.completion_tokens_details
+                    and usage.completion_tokens_details.reasoning_tokens
+                    else 0
+                ),
+                input_tokens_details=InputTokensDetails(
+                    cached_tokens=usage.prompt_tokens_details.cached_tokens
+                    if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens
+                    else 0
+                ),
+            )
+            if usage
+            else None
+        )
+
+        yield ResponseCompletedEvent(
+            response=final_response,
+            type="response.completed",
+        )
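
For orientation, a minimal consumption sketch (not part of the diff): it assumes you already hold a Response shell and a streaming Chat Completions call, and simply drains the converted event stream. The function name print_streamed_text and its parameters are hypothetical; only ChatCmplStreamHandler.handle_stream and the event types come from the file above.

# Illustrative sketch only; `response` (an openai.types.responses.Response shell) and
# `chat_stream` (an AsyncStream[ChatCompletionChunk] from a stream=True Chat Completions call)
# are assumed to exist already.
async def print_streamed_text(response, chat_stream):
    async for event in ChatCmplStreamHandler.handle_stream(response, chat_stream):
        if event.type == "response.output_text.delta":
            print(event.delta, end="", flush=True)   # incremental text as it arrives
        elif event.type == "response.completed":
            print()  # event.response now holds the synthesized Response (with usage, if available)
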
agents/models/interface.py
CHANGED
@@ -5,7 +5,7 @@ import enum
 from collections.abc import AsyncIterator
 from typing import TYPE_CHECKING
 
-from ..agent_output import AgentOutputSchema
+from ..agent_output import AgentOutputSchemaBase
 from ..handoffs import Handoff
 from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
 from ..tool import Tool
@@ -41,7 +41,7 @@ class Model(abc.ABC):
         input: str | list[TResponseInputItem],
         model_settings: ModelSettings,
         tools: list[Tool],
-        output_schema: AgentOutputSchema | None,
+        output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
         *,
@@ -72,7 +72,7 @@ class Model(abc.ABC):
         input: str | list[TResponseInputItem],
         model_settings: ModelSettings,
         tools: list[Tool],
-        output_schema: AgentOutputSchema | None,
+        output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
         tracing: ModelTracing,
         *,
agents/models/multi_provider.py
ADDED
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+from openai import AsyncOpenAI
+
+from ..exceptions import UserError
+from .interface import Model, ModelProvider
+from .openai_provider import OpenAIProvider
+
+
+class MultiProviderMap:
+    """A map of model name prefixes to ModelProviders."""
+
+    def __init__(self):
+        self._mapping: dict[str, ModelProvider] = {}
+
+    def has_prefix(self, prefix: str) -> bool:
+        """Returns True if the given prefix is in the mapping."""
+        return prefix in self._mapping
+
+    def get_mapping(self) -> dict[str, ModelProvider]:
+        """Returns a copy of the current prefix -> ModelProvider mapping."""
+        return self._mapping.copy()
+
+    def set_mapping(self, mapping: dict[str, ModelProvider]):
+        """Overwrites the current mapping with a new one."""
+        self._mapping = mapping
+
+    def get_provider(self, prefix: str) -> ModelProvider | None:
+        """Returns the ModelProvider for the given prefix.
+
+        Args:
+            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
+        """
+        return self._mapping.get(prefix)
+
+    def add_provider(self, prefix: str, provider: ModelProvider):
+        """Adds a new prefix -> ModelProvider mapping.
+
+        Args:
+            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
+            provider: The ModelProvider to use for the given prefix.
+        """
+        self._mapping[prefix] = provider
+
+    def remove_provider(self, prefix: str):
+        """Removes the mapping for the given prefix.
+
+        Args:
+            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
+        """
+        del self._mapping[prefix]
+
+
+class MultiProvider(ModelProvider):
+    """This ModelProvider maps to a Model based on the prefix of the model name. By default, the
+    mapping is:
+    - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
+    - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"
+
+    You can override or customize this mapping.
+    """
+
+    def __init__(
+        self,
+        *,
+        provider_map: MultiProviderMap | None = None,
+        openai_api_key: str | None = None,
+        openai_base_url: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        openai_organization: str | None = None,
+        openai_project: str | None = None,
+        openai_use_responses: bool | None = None,
+    ) -> None:
+        """Create a new OpenAI provider.
+
+        Args:
+            provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided,
+                we will use a default mapping. See the documentation for this class to see the
+                default mapping.
+            openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use
+                the default API key.
+            openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will
+                use the default base URL.
+            openai_client: An optional OpenAI client to use. If not provided, we will create a new
+                OpenAI client using the api_key and base_url.
+            openai_organization: The organization to use for the OpenAI provider.
+            openai_project: The project to use for the OpenAI provider.
+            openai_use_responses: Whether to use the OpenAI responses API.
+        """
+        self.provider_map = provider_map
+        self.openai_provider = OpenAIProvider(
+            api_key=openai_api_key,
+            base_url=openai_base_url,
+            openai_client=openai_client,
+            organization=openai_organization,
+            project=openai_project,
+            use_responses=openai_use_responses,
+        )
+
+        self._fallback_providers: dict[str, ModelProvider] = {}
+
+    def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
+        if model_name is None:
+            return None, None
+        elif "/" in model_name:
+            prefix, model_name = model_name.split("/", 1)
+            return prefix, model_name
+        else:
+            return None, model_name
+
+    def _create_fallback_provider(self, prefix: str) -> ModelProvider:
+        if prefix == "litellm":
+            from ..extensions.models.litellm_provider import LitellmProvider
+
+            return LitellmProvider()
+        else:
+            raise UserError(f"Unknown prefix: {prefix}")
+
+    def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
+        if prefix is None or prefix == "openai":
+            return self.openai_provider
+        elif prefix in self._fallback_providers:
+            return self._fallback_providers[prefix]
+        else:
+            self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
+            return self._fallback_providers[prefix]
+
+    def get_model(self, model_name: str | None) -> Model:
+        """Returns a Model based on the model name. The model name can have a prefix, ending with
+        a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
+        the OpenAI provider.
+
+        Args:
+            model_name: The name of the model to get.
+
+        Returns:
+            A Model.
+        """
+        prefix, model_name = self._get_prefix_and_model_name(model_name)
+
+        if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
+            return provider.get_model(model_name)
+        else:
+            return self._get_fallback_provider(prefix).get_model(model_name)
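
A hedged usage sketch of the new prefix routing (not part of the diff). The import paths mirror the file layout above, the model names come from the class docstring, and the example assumes the optional LiteLLM dependency is installed; treat it as an illustration, not the package's documented API.

# Illustrative sketch only, assuming the package is importable as `agents`.
from agents.extensions.models.litellm_provider import LitellmProvider
from agents.models.multi_provider import MultiProvider, MultiProviderMap

provider = MultiProvider()
openai_model = provider.get_model("gpt-4.1")                  # no prefix -> OpenAIProvider
litellm_model = provider.get_model("litellm/openai/gpt-4.1")  # "litellm" prefix -> lazily built LitellmProvider

# Custom routing: register a provider for a prefix up front instead of relying on the fallback.
custom_map = MultiProviderMap()
custom_map.add_provider("litellm", LitellmProvider())
provider = MultiProvider(provider_map=custom_map)
model = provider.get_model("litellm/openai/gpt-4.1")  # prefix stripped; "openai/gpt-4.1" passed to the provider
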