openai-agents 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- agents/__init__.py +2 -1
- agents/_run_impl.py +3 -3
- agents/agent.py +9 -2
- agents/agent_output.py +58 -8
- agents/extensions/models/__init__.py +0 -0
- agents/extensions/models/litellm_model.py +382 -0
- agents/extensions/models/litellm_provider.py +21 -0
- agents/extensions/visualization.py +1 -1
- agents/mcp/server.py +2 -1
- agents/models/chatcmpl_converter.py +466 -0
- agents/models/chatcmpl_helpers.py +37 -0
- agents/models/chatcmpl_stream_handler.py +290 -0
- agents/models/interface.py +3 -3
- agents/models/multi_provider.py +144 -0
- agents/models/openai_chatcompletions.py +33 -776
- agents/models/openai_responses.py +8 -8
- agents/result.py +3 -6
- agents/run.py +16 -13
- agents/tracing/processors.py +6 -0
- agents/tracing/span_data.py +1 -1
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.12.dist-info}/METADATA +4 -4
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.12.dist-info}/RECORD +24 -17
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.12.dist-info}/WHEEL +0 -0
- {openai_agents-0.0.11.dist-info → openai_agents-0.0.12.dist-info}/licenses/LICENSE +0 -0
agents/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
 
 from . import _config
 from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
-from .agent_output import AgentOutputSchema
+from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
 from .computer import AsyncComputer, Button, Computer, Environment
 from .exceptions import (
     AgentsException,
@@ -158,6 +158,7 @@ __all__ = [
     "OpenAIProvider",
     "OpenAIResponsesModel",
     "AgentOutputSchema",
+    "AgentOutputSchemaBase",
     "Computer",
     "AsyncComputer",
     "Environment",
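With this change, `AgentOutputSchemaBase` joins the existing `AgentOutputSchema` in the package's public API. A minimal sketch of what the new export makes possible (assuming the 0.0.12 layout shown in this diff):

```python
# Sketch: the abstract base is now importable from the package root.
from agents import AgentOutputSchema, AgentOutputSchemaBase

# AgentOutputSchema remains the concrete, Pydantic-backed implementation;
# per the agent_output.py change below, it now subclasses AgentOutputSchemaBase.
assert issubclass(AgentOutputSchema, AgentOutputSchemaBase)
```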
agents/_run_impl.py
CHANGED
@@ -29,7 +29,7 @@ from openai.types.responses.response_input_param import ComputerCallOutput
 from openai.types.responses.response_reasoning_item import ResponseReasoningItem
 
 from .agent import Agent, ToolsToFinalOutputResult
-from .agent_output import AgentOutputSchema
+from .agent_output import AgentOutputSchemaBase
 from .computer import AsyncComputer, Computer
 from .exceptions import AgentsException, ModelBehaviorError, UserError
 from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
@@ -195,7 +195,7 @@ class RunImpl:
         pre_step_items: list[RunItem],
         new_response: ModelResponse,
         processed_response: ProcessedResponse,
-        output_schema: AgentOutputSchema | None,
+        output_schema: AgentOutputSchemaBase | None,
         hooks: RunHooks[TContext],
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
@@ -335,7 +335,7 @@ class RunImpl:
         agent: Agent[Any],
         all_tools: list[Tool],
         response: ModelResponse,
-        output_schema: AgentOutputSchema | None,
+        output_schema: AgentOutputSchemaBase | None,
         handoffs: list[Handoff],
     ) -> ProcessedResponse:
         items: list[RunItem] = []
agents/agent.py
CHANGED
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
 
 from typing_extensions import NotRequired, TypeAlias, TypedDict
 
+from .agent_output import AgentOutputSchemaBase
 from .guardrail import InputGuardrail, OutputGuardrail
 from .handoffs import Handoff
 from .items import ItemHelpers
@@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
     Runs only if the agent produces a final output.
     """
 
-    output_type: type[Any] | None = None
-    """The type of the output object. If not provided, the output will be `str`."""
+    output_type: type[Any] | AgentOutputSchemaBase | None = None
+    """The type of the output object. If not provided, the output will be `str`. In most cases,
+    you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
+    You can customize this in two ways:
+    1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
+    2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
+       creation, subclass and pass an `AgentOutputSchemaBase` subclass.
+    """
 
     hooks: AgentHooks[TContext] | None = None
     """A class that receives callbacks on various lifecycle events for this agent.
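The widened `output_type` annotation means an agent can now be given either a plain Python type or a pre-built schema object. A hedged sketch of the two documented options (the `Answer` model and agent names below are illustrative, not part of the package):

```python
from pydantic import BaseModel

from agents import Agent, AgentOutputSchema


class Answer(BaseModel):
    # Illustrative structured output type.
    text: str
    confidence: float


# Option 1 (default): pass a regular Python type; the SDK derives a strict JSON schema.
strict_agent = Agent(name="strict_example", output_type=Answer)

# Option 2: wrap the type yourself to opt out of strict schema mode.
lenient_agent = Agent(
    name="lenient_example",
    output_type=AgentOutputSchema(Answer, strict_json_schema=False),
)
```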
agents/agent_output.py
CHANGED
@@ -1,3 +1,4 @@
+import abc
 from dataclasses import dataclass
 from typing import Any
 
@@ -12,8 +13,46 @@ from .util import _error_tracing, _json
 _WRAPPER_DICT_KEY = "response"
 
 
+class AgentOutputSchemaBase(abc.ABC):
+    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
+    produced by the LLM into the output type.
+    """
+
+    @abc.abstractmethod
+    def is_plain_text(self) -> bool:
+        """Whether the output type is plain text (versus a JSON object)."""
+        pass
+
+    @abc.abstractmethod
+    def name(self) -> str:
+        """The name of the output type."""
+        pass
+
+    @abc.abstractmethod
+    def json_schema(self) -> dict[str, Any]:
+        """Returns the JSON schema of the output. Will only be called if the output type is not
+        plain text.
+        """
+        pass
+
+    @abc.abstractmethod
+    def is_strict_json_schema(self) -> bool:
+        """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
+        features, but guarantees valis JSON. See here for details:
+        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
+        """
+        pass
+
+    @abc.abstractmethod
+    def validate_json(self, json_str: str) -> Any:
+        """Validate a JSON string against the output type. You must return the validated object,
+        or raise a `ModelBehaviorError` if the JSON is invalid.
+        """
+        pass
+
+
 @dataclass(init=False)
-class AgentOutputSchema:
+class AgentOutputSchema(AgentOutputSchemaBase):
     """An object that captures the JSON schema of the output, as well as validating/parsing JSON
     produced by the LLM into the output type.
     """
@@ -32,7 +71,7 @@ class AgentOutputSchema:
     _output_schema: dict[str, Any]
     """The JSON schema of the output."""
 
-    strict_json_schema: bool
+    _strict_json_schema: bool
     """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
     as it increases the likelihood of correct JSON input.
     """
@@ -45,7 +84,7 @@ class AgentOutputSchema:
         setting this to True, as it increases the likelihood of correct JSON input.
         """
         self.output_type = output_type
-        self.strict_json_schema = strict_json_schema
+        self._strict_json_schema = strict_json_schema
 
         if output_type is None or output_type is str:
             self._is_wrapped = False
@@ -70,24 +109,35 @@ class AgentOutputSchema:
         self._type_adapter = TypeAdapter(output_type)
         self._output_schema = self._type_adapter.json_schema()
 
-        if self.strict_json_schema:
-            self._output_schema = ensure_strict_json_schema(self._output_schema)
+        if self._strict_json_schema:
+            try:
+                self._output_schema = ensure_strict_json_schema(self._output_schema)
+            except UserError as e:
+                raise UserError(
+                    "Strict JSON schema is enabled, but the output type is not valid. "
+                    "Either make the output type strict, or pass output_schema_strict=False to "
+                    "your Agent()"
+                ) from e
 
     def is_plain_text(self) -> bool:
         """Whether the output type is plain text (versus a JSON object)."""
         return self.output_type is None or self.output_type is str
 
+    def is_strict_json_schema(self) -> bool:
+        """Whether the JSON schema is in strict mode."""
+        return self._strict_json_schema
+
     def json_schema(self) -> dict[str, Any]:
         """The JSON schema of the output type."""
         if self.is_plain_text():
             raise UserError("Output type is plain text, so no JSON schema is available")
         return self._output_schema
 
-    def validate_json(self, json_str: str, partial: bool = False) -> Any:
+    def validate_json(self, json_str: str) -> Any:
         """Validate a JSON string against the output type. Returns the validated object, or raises
         a `ModelBehaviorError` if the JSON is invalid.
         """
-        validated = _json.validate_json(json_str, self._type_adapter, partial)
+        validated = _json.validate_json(json_str, self._type_adapter, partial=False)
         if self._is_wrapped:
             if not isinstance(validated, dict):
                 _error_tracing.attach_error_to_current_span(
@@ -113,7 +163,7 @@ class AgentOutputSchema:
             return validated[_WRAPPER_DICT_KEY]
         return validated
 
-    def output_type_name(self) -> str:
+    def name(self) -> str:
         """The name of the output type."""
         return _type_to_str(self.output_type)
 
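For fully custom JSON handling, the new abstract base can be subclassed directly and passed as an agent's `output_type`. A minimal, illustrative subclass (not part of the package) that supplies a hand-written schema and validates with `json.loads`:

```python
import json
from typing import Any

from agents import AgentOutputSchemaBase
from agents.exceptions import ModelBehaviorError


class WeatherSchema(AgentOutputSchemaBase):
    """Illustrative custom schema: hand-written JSON schema instead of the SDK's automatic one."""

    def is_plain_text(self) -> bool:
        return False

    def name(self) -> str:
        return "Weather"

    def json_schema(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {"city": {"type": "string"}, "temperature_c": {"type": "number"}},
            "required": ["city", "temperature_c"],
            "additionalProperties": False,
        }

    def is_strict_json_schema(self) -> bool:
        return True

    def validate_json(self, json_str: str) -> Any:
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ModelBehaviorError(f"Model returned invalid JSON: {e}") from e
```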
agents/extensions/models/__init__.py
ADDED
File without changes
agents/extensions/models/litellm_model.py
ADDED
@@ -0,0 +1,382 @@
+from __future__ import annotations
+
+import dataclasses
+import json
+import time
+from collections.abc import AsyncIterator
+from typing import Any, Literal, cast, overload
+
+import litellm.types
+
+from agents.exceptions import ModelBehaviorError
+
+try:
+    import litellm
+except ImportError as _e:
+    raise ImportError(
+        "`litellm` is required to use the LitellmModel. You can install it via the optional "
+        "dependency group: `pip install 'openai-agents[litellm]'`."
+    ) from _e
+
+from openai import NOT_GIVEN, AsyncStream, NotGiven
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message import (
+    Annotation,
+    AnnotationURLCitation,
+    ChatCompletionMessage,
+)
+from openai.types.chat.chat_completion_message_tool_call import Function
+from openai.types.responses import Response
+
+from ... import _debug
+from ...agent_output import AgentOutputSchemaBase
+from ...handoffs import Handoff
+from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
+from ...logger import logger
+from ...model_settings import ModelSettings
+from ...models.chatcmpl_converter import Converter
+from ...models.chatcmpl_helpers import HEADERS
+from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
+from ...models.fake_id import FAKE_RESPONSES_ID
+from ...models.interface import Model, ModelTracing
+from ...tool import Tool
+from ...tracing import generation_span
+from ...tracing.span_data import GenerationSpanData
+from ...tracing.spans import Span
+from ...usage import Usage
+
+
+class LitellmModel(Model):
+    """This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
+    Anthropic, Gemini, Mistral, and many other models.
+    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
+    """
+
+    def __init__(
+        self,
+        model: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+    ):
+        self.model = model
+        self.base_url = base_url
+        self.api_key = api_key
+
+    async def get_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None,
+    ) -> ModelResponse:
+        with generation_span(
+            model=str(self.model),
+            model_config=dataclasses.asdict(model_settings)
+            | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
+            disabled=tracing.is_disabled(),
+        ) as span_generation:
+            response = await self._fetch_response(
+                system_instructions,
+                input,
+                model_settings,
+                tools,
+                output_schema,
+                handoffs,
+                span_generation,
+                tracing,
+                stream=False,
+            )
+
+            assert isinstance(response.choices[0], litellm.types.utils.Choices)
+
+            if _debug.DONT_LOG_MODEL_DATA:
+                logger.debug("Received model response")
+            else:
+                logger.debug(
+                    f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
+                )
+
+            if hasattr(response, "usage"):
+                response_usage = response.usage
+                usage = (
+                    Usage(
+                        requests=1,
+                        input_tokens=response_usage.prompt_tokens,
+                        output_tokens=response_usage.completion_tokens,
+                        total_tokens=response_usage.total_tokens,
+                    )
+                    if response.usage
+                    else Usage()
+                )
+            else:
+                usage = Usage()
+                logger.warning("No usage information returned from Litellm")
+
+            if tracing.include_data():
+                span_generation.span_data.output = [response.choices[0].message.model_dump()]
+            span_generation.span_data.usage = {
+                "input_tokens": usage.input_tokens,
+                "output_tokens": usage.output_tokens,
+            }
+
+            items = Converter.message_to_output_items(
+                LitellmConverter.convert_message_to_openai(response.choices[0].message)
+            )
+
+            return ModelResponse(
+                output=items,
+                usage=usage,
+                response_id=None,
+            )
+
+    async def stream_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        *,
+        previous_response_id: str | None,
+    ) -> AsyncIterator[TResponseStreamEvent]:
+        with generation_span(
+            model=str(self.model),
+            model_config=dataclasses.asdict(model_settings)
+            | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
+            disabled=tracing.is_disabled(),
+        ) as span_generation:
+            response, stream = await self._fetch_response(
+                system_instructions,
+                input,
+                model_settings,
+                tools,
+                output_schema,
+                handoffs,
+                span_generation,
+                tracing,
+                stream=True,
+            )
+
+            final_response: Response | None = None
+            async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+                yield chunk
+
+                if chunk.type == "response.completed":
+                    final_response = chunk.response
+
+            if tracing.include_data() and final_response:
+                span_generation.span_data.output = [final_response.model_dump()]
+
+            if final_response and final_response.usage:
+                span_generation.span_data.usage = {
+                    "input_tokens": final_response.usage.input_tokens,
+                    "output_tokens": final_response.usage.output_tokens,
+                }
+
+    @overload
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: Literal[True],
+    ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
+
+    @overload
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: Literal[False],
+    ) -> litellm.types.utils.ModelResponse: ...
+
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: bool = False,
+    ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
+        converted_messages = Converter.items_to_messages(input)
+
+        if system_instructions:
+            converted_messages.insert(
+                0,
+                {
+                    "content": system_instructions,
+                    "role": "system",
+                },
+            )
+        if tracing.include_data():
+            span.span_data.input = converted_messages
+
+        parallel_tool_calls = (
+            True
+            if model_settings.parallel_tool_calls and tools and len(tools) > 0
+            else False
+            if model_settings.parallel_tool_calls is False
+            else None
+        )
+        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
+        response_format = Converter.convert_response_format(output_schema)
+
+        converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else []
+
+        for handoff in handoffs:
+            converted_tools.append(Converter.convert_handoff_tool(handoff))
+
+        if _debug.DONT_LOG_MODEL_DATA:
+            logger.debug("Calling LLM")
+        else:
+            logger.debug(
+                f"Calling Litellm model: {self.model}\n"
+                f"{json.dumps(converted_messages, indent=2)}\n"
+                f"Tools:\n{json.dumps(converted_tools, indent=2)}\n"
+                f"Stream: {stream}\n"
+                f"Tool choice: {tool_choice}\n"
+                f"Response format: {response_format}\n"
+            )
+
+        reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
+
+        stream_options = None
+        if stream and model_settings.include_usage is not None:
+            stream_options = {"include_usage": model_settings.include_usage}
+
+        extra_kwargs = {}
+        if model_settings.extra_query:
+            extra_kwargs["extra_query"] = model_settings.extra_query
+        if model_settings.metadata:
+            extra_kwargs["metadata"] = model_settings.metadata
+
+        ret = await litellm.acompletion(
+            model=self.model,
+            messages=converted_messages,
+            tools=converted_tools or None,
+            temperature=model_settings.temperature,
+            top_p=model_settings.top_p,
+            frequency_penalty=model_settings.frequency_penalty,
+            presence_penalty=model_settings.presence_penalty,
+            max_tokens=model_settings.max_tokens,
+            tool_choice=self._remove_not_given(tool_choice),
+            response_format=self._remove_not_given(response_format),
+            parallel_tool_calls=parallel_tool_calls,
+            stream=stream,
+            stream_options=stream_options,
+            reasoning_effort=reasoning_effort,
+            extra_headers=HEADERS,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            **extra_kwargs,
+        )
+
+        if isinstance(ret, litellm.types.utils.ModelResponse):
+            return ret
+
+        response = Response(
+            id=FAKE_RESPONSES_ID,
+            created_at=time.time(),
+            model=self.model,
+            object="response",
+            output=[],
+            tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
+            if tool_choice != NOT_GIVEN
+            else "auto",
+            top_p=model_settings.top_p,
+            temperature=model_settings.temperature,
+            tools=[],
+            parallel_tool_calls=parallel_tool_calls or False,
+            reasoning=model_settings.reasoning,
+        )
+        return response, ret
+
+    def _remove_not_given(self, value: Any) -> Any:
+        if isinstance(value, NotGiven):
+            return None
+        return value
+
+
+class LitellmConverter:
+    @classmethod
+    def convert_message_to_openai(
+        cls, message: litellm.types.utils.Message
+    ) -> ChatCompletionMessage:
+        if message.role != "assistant":
+            raise ModelBehaviorError(f"Unsupported role: {message.role}")
+
+        tool_calls = (
+            [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
+            if message.tool_calls
+            else None
+        )
+
+        provider_specific_fields = message.get("provider_specific_fields", None)
+        refusal = (
+            provider_specific_fields.get("refusal", None) if provider_specific_fields else None
+        )
+
+        return ChatCompletionMessage(
+            content=message.content,
+            refusal=refusal,
+            role="assistant",
+            annotations=cls.convert_annotations_to_openai(message),
+            audio=message.get("audio", None),  # litellm deletes audio if not present
+            tool_calls=tool_calls,
+        )
+
+    @classmethod
+    def convert_annotations_to_openai(
+        cls, message: litellm.types.utils.Message
+    ) -> list[Annotation] | None:
+        annotations: list[litellm.types.llms.openai.ChatCompletionAnnotation] | None = message.get(
+            "annotations", None
+        )
+        if not annotations:
+            return None
+
+        return [
+            Annotation(
+                type="url_citation",
+                url_citation=AnnotationURLCitation(
+                    start_index=annotation["url_citation"]["start_index"],
+                    end_index=annotation["url_citation"]["end_index"],
+                    url=annotation["url_citation"]["url"],
+                    title=annotation["url_citation"]["title"],
+                ),
+            )
+            for annotation in annotations
+        ]
+
+    @classmethod
+    def convert_tool_call_to_openai(
+        cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
+    ) -> ChatCompletionMessageToolCall:
+        return ChatCompletionMessageToolCall(
+            id=tool_call.id,
+            type="function",
+            function=Function(
+                name=tool_call.function.name or "", arguments=tool_call.function.arguments
+            ),
+        )
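Assuming the optional `litellm` dependency is installed (`pip install 'openai-agents[litellm]'`, as the import error above suggests), the new model class can be handed directly to an agent. A usage sketch; the model id and prompt are examples only, and the relevant provider API key is expected in the environment or passed via `api_key=`:

```python
import asyncio

from agents import Agent, Runner
from agents.extensions.models.litellm_model import LitellmModel


async def main() -> None:
    # Any LiteLLM-routed model id should work here; this Anthropic id is just an example.
    agent = Agent(
        name="assistant",
        instructions="Reply concisely.",
        model=LitellmModel(model="anthropic/claude-3-5-sonnet-20240620"),
    )
    result = await Runner.run(agent, "Say hello.")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```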
agents/extensions/models/litellm_provider.py
ADDED
@@ -0,0 +1,21 @@
+from ...models.interface import Model, ModelProvider
+from .litellm_model import LitellmModel
+
+DEFAULT_MODEL: str = "gpt-4.1"
+
+
+class LitellmProvider(ModelProvider):
+    """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
+    ```python
+    Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
+    ```
+    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
+
+    NOTE: API keys must be set via environment variables. If you're using models that require
+    additional configuration (e.g. Azure API base or version), those must also be set via the
+    environment variables that LiteLLM expects. If you have more advanced needs, we recommend
+    copy-pasting this class and making any modifications you need.
+    """
+
+    def get_model(self, model_name: str | None) -> Model:
+        return LitellmModel(model_name or DEFAULT_MODEL)
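The docstring above already sketches the intended call; expanded into a runnable form (the model id is illustrative, and provider credentials are assumed to be set via environment variables as the NOTE describes):

```python
import asyncio

from agents import Agent, RunConfig, Runner
from agents.extensions.models.litellm_provider import LitellmProvider


async def main() -> None:
    # The provider resolves the agent's model string through LiteLLM;
    # "mistral/mistral-large-latest" is only an example.
    agent = Agent(name="assistant", instructions="Be brief.", model="mistral/mistral-large-latest")
    result = await Runner.run(
        agent,
        "What is the capital of France?",
        run_config=RunConfig(model_provider=LitellmProvider()),
    )
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```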
agents/extensions/visualization.py
CHANGED
@@ -132,6 +132,6 @@ def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
     graph = graphviz.Source(dot_code)
 
     if filename:
-        graph.render(filename, format="png")
+        graph.render(filename, format="png", cleanup=True)
 
     return graph
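With `cleanup=True`, Graphviz deletes the intermediate DOT source file once the PNG is rendered, so only the image is left behind. A small sketch of the call this affects (requires the graphviz-based visualization extra; the agent and filename are placeholders):

```python
from agents import Agent
from agents.extensions.visualization import draw_graph

agent = Agent(name="triage", instructions="Route the request.")

# Writes triage_graph.png; after this change the intermediate "triage_graph"
# DOT source file is cleaned up instead of being left next to the image.
draw_graph(agent, filename="triage_graph")
```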
agents/mcp/server.py
CHANGED
@@ -137,9 +137,10 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         async with self._cleanup_lock:
             try:
                 await self.exit_stack.aclose()
-                self.session = None
             except Exception as e:
                 logger.error(f"Error cleaning up server: {e}")
+            finally:
+                self.session = None
 
 
 class MCPServerStdioParams(TypedDict):