openai-agents 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of openai-agents might be problematic. See this release's page on the package registry for more details.

agents/__init__.py CHANGED
@@ -6,7 +6,7 @@ from openai import AsyncOpenAI
6
6
 
7
7
  from . import _config
8
8
  from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
9
- from .agent_output import AgentOutputSchema
9
+ from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
10
10
  from .computer import AsyncComputer, Button, Computer, Environment
11
11
  from .exceptions import (
12
12
  AgentsException,
@@ -158,6 +158,7 @@ __all__ = [
158
158
  "OpenAIProvider",
159
159
  "OpenAIResponsesModel",
160
160
  "AgentOutputSchema",
161
+ "AgentOutputSchemaBase",
161
162
  "Computer",
162
163
  "AsyncComputer",
163
164
  "Environment",
agents/_run_impl.py CHANGED
@@ -29,7 +29,7 @@ from openai.types.responses.response_input_param import ComputerCallOutput
29
29
  from openai.types.responses.response_reasoning_item import ResponseReasoningItem
30
30
 
31
31
  from .agent import Agent, ToolsToFinalOutputResult
32
- from .agent_output import AgentOutputSchema
32
+ from .agent_output import AgentOutputSchemaBase
33
33
  from .computer import AsyncComputer, Computer
34
34
  from .exceptions import AgentsException, ModelBehaviorError, UserError
35
35
  from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
@@ -195,7 +195,7 @@ class RunImpl:
195
195
  pre_step_items: list[RunItem],
196
196
  new_response: ModelResponse,
197
197
  processed_response: ProcessedResponse,
198
- output_schema: AgentOutputSchema | None,
198
+ output_schema: AgentOutputSchemaBase | None,
199
199
  hooks: RunHooks[TContext],
200
200
  context_wrapper: RunContextWrapper[TContext],
201
201
  run_config: RunConfig,
@@ -335,7 +335,7 @@ class RunImpl:
335
335
  agent: Agent[Any],
336
336
  all_tools: list[Tool],
337
337
  response: ModelResponse,
338
- output_schema: AgentOutputSchema | None,
338
+ output_schema: AgentOutputSchemaBase | None,
339
339
  handoffs: list[Handoff],
340
340
  ) -> ProcessedResponse:
341
341
  items: list[RunItem] = []
agents/agent.py CHANGED
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
8
8
 
9
9
  from typing_extensions import NotRequired, TypeAlias, TypedDict
10
10
 
11
+ from .agent_output import AgentOutputSchemaBase
11
12
  from .guardrail import InputGuardrail, OutputGuardrail
12
13
  from .handoffs import Handoff
13
14
  from .items import ItemHelpers
@@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
141
142
  Runs only if the agent produces a final output.
142
143
  """
143
144
 
144
- output_type: type[Any] | None = None
145
- """The type of the output object. If not provided, the output will be `str`."""
145
+ output_type: type[Any] | AgentOutputSchemaBase | None = None
146
+ """The type of the output object. If not provided, the output will be `str`. In most cases,
147
+ you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
148
+ You can customize this in two ways:
149
+ 1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
150
+ 2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
151
+ creation, subclass and pass an `AgentOutputSchemaBase` subclass.
152
+ """
146
153
 
147
154
  hooks: AgentHooks[TContext] | None = None
148
155
  """A class that receives callbacks on various lifecycle events for this agent.
agents/agent_output.py CHANGED
@@ -1,3 +1,4 @@
1
+ import abc
1
2
  from dataclasses import dataclass
2
3
  from typing import Any
3
4
 
@@ -12,8 +13,46 @@ from .util import _error_tracing, _json
12
13
  _WRAPPER_DICT_KEY = "response"
13
14
 
14
15
 
16
+ class AgentOutputSchemaBase(abc.ABC):
17
+ """An object that captures the JSON schema of the output, as well as validating/parsing JSON
18
+ produced by the LLM into the output type.
19
+ """
20
+
21
+ @abc.abstractmethod
22
+ def is_plain_text(self) -> bool:
23
+ """Whether the output type is plain text (versus a JSON object)."""
24
+ pass
25
+
26
+ @abc.abstractmethod
27
+ def name(self) -> str:
28
+ """The name of the output type."""
29
+ pass
30
+
31
+ @abc.abstractmethod
32
+ def json_schema(self) -> dict[str, Any]:
33
+ """Returns the JSON schema of the output. Will only be called if the output type is not
34
+ plain text.
35
+ """
36
+ pass
37
+
38
+ @abc.abstractmethod
39
+ def is_strict_json_schema(self) -> bool:
40
+ """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
41
+ features, but guarantees valis JSON. See here for details:
42
+ https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
43
+ """
44
+ pass
45
+
46
+ @abc.abstractmethod
47
+ def validate_json(self, json_str: str) -> Any:
48
+ """Validate a JSON string against the output type. You must return the validated object,
49
+ or raise a `ModelBehaviorError` if the JSON is invalid.
50
+ """
51
+ pass
52
+
53
+
15
54
  @dataclass(init=False)
16
- class AgentOutputSchema:
55
+ class AgentOutputSchema(AgentOutputSchemaBase):
17
56
  """An object that captures the JSON schema of the output, as well as validating/parsing JSON
18
57
  produced by the LLM into the output type.
19
58
  """
@@ -32,7 +71,7 @@ class AgentOutputSchema:
32
71
  _output_schema: dict[str, Any]
33
72
  """The JSON schema of the output."""
34
73
 
35
- strict_json_schema: bool
74
+ _strict_json_schema: bool
36
75
  """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
37
76
  as it increases the likelihood of correct JSON input.
38
77
  """
@@ -45,7 +84,7 @@ class AgentOutputSchema:
45
84
  setting this to True, as it increases the likelihood of correct JSON input.
46
85
  """
47
86
  self.output_type = output_type
48
- self.strict_json_schema = strict_json_schema
87
+ self._strict_json_schema = strict_json_schema
49
88
 
50
89
  if output_type is None or output_type is str:
51
90
  self._is_wrapped = False
@@ -70,24 +109,35 @@ class AgentOutputSchema:
70
109
  self._type_adapter = TypeAdapter(output_type)
71
110
  self._output_schema = self._type_adapter.json_schema()
72
111
 
73
- if self.strict_json_schema:
74
- self._output_schema = ensure_strict_json_schema(self._output_schema)
112
+ if self._strict_json_schema:
113
+ try:
114
+ self._output_schema = ensure_strict_json_schema(self._output_schema)
115
+ except UserError as e:
116
+ raise UserError(
117
+ "Strict JSON schema is enabled, but the output type is not valid. "
118
+ "Either make the output type strict, or pass output_schema_strict=False to "
119
+ "your Agent()"
120
+ ) from e
75
121
 
76
122
  def is_plain_text(self) -> bool:
77
123
  """Whether the output type is plain text (versus a JSON object)."""
78
124
  return self.output_type is None or self.output_type is str
79
125
 
126
+ def is_strict_json_schema(self) -> bool:
127
+ """Whether the JSON schema is in strict mode."""
128
+ return self._strict_json_schema
129
+
80
130
  def json_schema(self) -> dict[str, Any]:
81
131
  """The JSON schema of the output type."""
82
132
  if self.is_plain_text():
83
133
  raise UserError("Output type is plain text, so no JSON schema is available")
84
134
  return self._output_schema
85
135
 
86
- def validate_json(self, json_str: str, partial: bool = False) -> Any:
136
+ def validate_json(self, json_str: str) -> Any:
87
137
  """Validate a JSON string against the output type. Returns the validated object, or raises
88
138
  a `ModelBehaviorError` if the JSON is invalid.
89
139
  """
90
- validated = _json.validate_json(json_str, self._type_adapter, partial)
140
+ validated = _json.validate_json(json_str, self._type_adapter, partial=False)
91
141
  if self._is_wrapped:
92
142
  if not isinstance(validated, dict):
93
143
  _error_tracing.attach_error_to_current_span(
@@ -113,7 +163,7 @@ class AgentOutputSchema:
113
163
  return validated[_WRAPPER_DICT_KEY]
114
164
  return validated
115
165
 
116
- def output_type_name(self) -> str:
166
+ def name(self) -> str:
117
167
  """The name of the output type."""
118
168
  return _type_to_str(self.output_type)
119
169
 
File without changes
@@ -0,0 +1,382 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import json
5
+ import time
6
+ from collections.abc import AsyncIterator
7
+ from typing import Any, Literal, cast, overload
8
+
9
+ import litellm.types
10
+
11
+ from agents.exceptions import ModelBehaviorError
12
+
13
+ try:
14
+ import litellm
15
+ except ImportError as _e:
16
+ raise ImportError(
17
+ "`litellm` is required to use the LitellmModel. You can install it via the optional "
18
+ "dependency group: `pip install 'openai-agents[litellm]'`."
19
+ ) from _e
20
+
21
+ from openai import NOT_GIVEN, AsyncStream, NotGiven
22
+ from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageToolCall
23
+ from openai.types.chat.chat_completion_message import (
24
+ Annotation,
25
+ AnnotationURLCitation,
26
+ ChatCompletionMessage,
27
+ )
28
+ from openai.types.chat.chat_completion_message_tool_call import Function
29
+ from openai.types.responses import Response
30
+
31
+ from ... import _debug
32
+ from ...agent_output import AgentOutputSchemaBase
33
+ from ...handoffs import Handoff
34
+ from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
35
+ from ...logger import logger
36
+ from ...model_settings import ModelSettings
37
+ from ...models.chatcmpl_converter import Converter
38
+ from ...models.chatcmpl_helpers import HEADERS
39
+ from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
40
+ from ...models.fake_id import FAKE_RESPONSES_ID
41
+ from ...models.interface import Model, ModelTracing
42
+ from ...tool import Tool
43
+ from ...tracing import generation_span
44
+ from ...tracing.span_data import GenerationSpanData
45
+ from ...tracing.spans import Span
46
+ from ...usage import Usage
47
+
48
+
49
+ class LitellmModel(Model):
50
+ """This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
51
+ Anthropic, Gemini, Mistral, and many other models.
52
+ See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ model: str,
58
+ base_url: str | None = None,
59
+ api_key: str | None = None,
60
+ ):
61
+ self.model = model
62
+ self.base_url = base_url
63
+ self.api_key = api_key
64
+
65
+ async def get_response(
66
+ self,
67
+ system_instructions: str | None,
68
+ input: str | list[TResponseInputItem],
69
+ model_settings: ModelSettings,
70
+ tools: list[Tool],
71
+ output_schema: AgentOutputSchemaBase | None,
72
+ handoffs: list[Handoff],
73
+ tracing: ModelTracing,
74
+ previous_response_id: str | None,
75
+ ) -> ModelResponse:
76
+ with generation_span(
77
+ model=str(self.model),
78
+ model_config=dataclasses.asdict(model_settings)
79
+ | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
80
+ disabled=tracing.is_disabled(),
81
+ ) as span_generation:
82
+ response = await self._fetch_response(
83
+ system_instructions,
84
+ input,
85
+ model_settings,
86
+ tools,
87
+ output_schema,
88
+ handoffs,
89
+ span_generation,
90
+ tracing,
91
+ stream=False,
92
+ )
93
+
94
+ assert isinstance(response.choices[0], litellm.types.utils.Choices)
95
+
96
+ if _debug.DONT_LOG_MODEL_DATA:
97
+ logger.debug("Received model response")
98
+ else:
99
+ logger.debug(
100
+ f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
101
+ )
102
+
103
+ if hasattr(response, "usage"):
104
+ response_usage = response.usage
105
+ usage = (
106
+ Usage(
107
+ requests=1,
108
+ input_tokens=response_usage.prompt_tokens,
109
+ output_tokens=response_usage.completion_tokens,
110
+ total_tokens=response_usage.total_tokens,
111
+ )
112
+ if response.usage
113
+ else Usage()
114
+ )
115
+ else:
116
+ usage = Usage()
117
+ logger.warning("No usage information returned from Litellm")
118
+
119
+ if tracing.include_data():
120
+ span_generation.span_data.output = [response.choices[0].message.model_dump()]
121
+ span_generation.span_data.usage = {
122
+ "input_tokens": usage.input_tokens,
123
+ "output_tokens": usage.output_tokens,
124
+ }
125
+
126
+ items = Converter.message_to_output_items(
127
+ LitellmConverter.convert_message_to_openai(response.choices[0].message)
128
+ )
129
+
130
+ return ModelResponse(
131
+ output=items,
132
+ usage=usage,
133
+ response_id=None,
134
+ )
135
+
136
+ async def stream_response(
137
+ self,
138
+ system_instructions: str | None,
139
+ input: str | list[TResponseInputItem],
140
+ model_settings: ModelSettings,
141
+ tools: list[Tool],
142
+ output_schema: AgentOutputSchemaBase | None,
143
+ handoffs: list[Handoff],
144
+ tracing: ModelTracing,
145
+ *,
146
+ previous_response_id: str | None,
147
+ ) -> AsyncIterator[TResponseStreamEvent]:
148
+ with generation_span(
149
+ model=str(self.model),
150
+ model_config=dataclasses.asdict(model_settings)
151
+ | {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
152
+ disabled=tracing.is_disabled(),
153
+ ) as span_generation:
154
+ response, stream = await self._fetch_response(
155
+ system_instructions,
156
+ input,
157
+ model_settings,
158
+ tools,
159
+ output_schema,
160
+ handoffs,
161
+ span_generation,
162
+ tracing,
163
+ stream=True,
164
+ )
165
+
166
+ final_response: Response | None = None
167
+ async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
168
+ yield chunk
169
+
170
+ if chunk.type == "response.completed":
171
+ final_response = chunk.response
172
+
173
+ if tracing.include_data() and final_response:
174
+ span_generation.span_data.output = [final_response.model_dump()]
175
+
176
+ if final_response and final_response.usage:
177
+ span_generation.span_data.usage = {
178
+ "input_tokens": final_response.usage.input_tokens,
179
+ "output_tokens": final_response.usage.output_tokens,
180
+ }
181
+
182
+ @overload
183
+ async def _fetch_response(
184
+ self,
185
+ system_instructions: str | None,
186
+ input: str | list[TResponseInputItem],
187
+ model_settings: ModelSettings,
188
+ tools: list[Tool],
189
+ output_schema: AgentOutputSchemaBase | None,
190
+ handoffs: list[Handoff],
191
+ span: Span[GenerationSpanData],
192
+ tracing: ModelTracing,
193
+ stream: Literal[True],
194
+ ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
195
+
196
+ @overload
197
+ async def _fetch_response(
198
+ self,
199
+ system_instructions: str | None,
200
+ input: str | list[TResponseInputItem],
201
+ model_settings: ModelSettings,
202
+ tools: list[Tool],
203
+ output_schema: AgentOutputSchemaBase | None,
204
+ handoffs: list[Handoff],
205
+ span: Span[GenerationSpanData],
206
+ tracing: ModelTracing,
207
+ stream: Literal[False],
208
+ ) -> litellm.types.utils.ModelResponse: ...
209
+
210
+ async def _fetch_response(
211
+ self,
212
+ system_instructions: str | None,
213
+ input: str | list[TResponseInputItem],
214
+ model_settings: ModelSettings,
215
+ tools: list[Tool],
216
+ output_schema: AgentOutputSchemaBase | None,
217
+ handoffs: list[Handoff],
218
+ span: Span[GenerationSpanData],
219
+ tracing: ModelTracing,
220
+ stream: bool = False,
221
+ ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
222
+ converted_messages = Converter.items_to_messages(input)
223
+
224
+ if system_instructions:
225
+ converted_messages.insert(
226
+ 0,
227
+ {
228
+ "content": system_instructions,
229
+ "role": "system",
230
+ },
231
+ )
232
+ if tracing.include_data():
233
+ span.span_data.input = converted_messages
234
+
235
+ parallel_tool_calls = (
236
+ True
237
+ if model_settings.parallel_tool_calls and tools and len(tools) > 0
238
+ else False
239
+ if model_settings.parallel_tool_calls is False
240
+ else None
241
+ )
242
+ tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
243
+ response_format = Converter.convert_response_format(output_schema)
244
+
245
+ converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else []
246
+
247
+ for handoff in handoffs:
248
+ converted_tools.append(Converter.convert_handoff_tool(handoff))
249
+
250
+ if _debug.DONT_LOG_MODEL_DATA:
251
+ logger.debug("Calling LLM")
252
+ else:
253
+ logger.debug(
254
+ f"Calling Litellm model: {self.model}\n"
255
+ f"{json.dumps(converted_messages, indent=2)}\n"
256
+ f"Tools:\n{json.dumps(converted_tools, indent=2)}\n"
257
+ f"Stream: {stream}\n"
258
+ f"Tool choice: {tool_choice}\n"
259
+ f"Response format: {response_format}\n"
260
+ )
261
+
262
+ reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
263
+
264
+ stream_options = None
265
+ if stream and model_settings.include_usage is not None:
266
+ stream_options = {"include_usage": model_settings.include_usage}
267
+
268
+ extra_kwargs = {}
269
+ if model_settings.extra_query:
270
+ extra_kwargs["extra_query"] = model_settings.extra_query
271
+ if model_settings.metadata:
272
+ extra_kwargs["metadata"] = model_settings.metadata
273
+
274
+ ret = await litellm.acompletion(
275
+ model=self.model,
276
+ messages=converted_messages,
277
+ tools=converted_tools or None,
278
+ temperature=model_settings.temperature,
279
+ top_p=model_settings.top_p,
280
+ frequency_penalty=model_settings.frequency_penalty,
281
+ presence_penalty=model_settings.presence_penalty,
282
+ max_tokens=model_settings.max_tokens,
283
+ tool_choice=self._remove_not_given(tool_choice),
284
+ response_format=self._remove_not_given(response_format),
285
+ parallel_tool_calls=parallel_tool_calls,
286
+ stream=stream,
287
+ stream_options=stream_options,
288
+ reasoning_effort=reasoning_effort,
289
+ extra_headers=HEADERS,
290
+ api_key=self.api_key,
291
+ base_url=self.base_url,
292
+ **extra_kwargs,
293
+ )
294
+
295
+ if isinstance(ret, litellm.types.utils.ModelResponse):
296
+ return ret
297
+
298
+ response = Response(
299
+ id=FAKE_RESPONSES_ID,
300
+ created_at=time.time(),
301
+ model=self.model,
302
+ object="response",
303
+ output=[],
304
+ tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
305
+ if tool_choice != NOT_GIVEN
306
+ else "auto",
307
+ top_p=model_settings.top_p,
308
+ temperature=model_settings.temperature,
309
+ tools=[],
310
+ parallel_tool_calls=parallel_tool_calls or False,
311
+ reasoning=model_settings.reasoning,
312
+ )
313
+ return response, ret
314
+
315
+ def _remove_not_given(self, value: Any) -> Any:
316
+ if isinstance(value, NotGiven):
317
+ return None
318
+ return value
319
+
320
+
321
+ class LitellmConverter:
322
+ @classmethod
323
+ def convert_message_to_openai(
324
+ cls, message: litellm.types.utils.Message
325
+ ) -> ChatCompletionMessage:
326
+ if message.role != "assistant":
327
+ raise ModelBehaviorError(f"Unsupported role: {message.role}")
328
+
329
+ tool_calls = (
330
+ [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
331
+ if message.tool_calls
332
+ else None
333
+ )
334
+
335
+ provider_specific_fields = message.get("provider_specific_fields", None)
336
+ refusal = (
337
+ provider_specific_fields.get("refusal", None) if provider_specific_fields else None
338
+ )
339
+
340
+ return ChatCompletionMessage(
341
+ content=message.content,
342
+ refusal=refusal,
343
+ role="assistant",
344
+ annotations=cls.convert_annotations_to_openai(message),
345
+ audio=message.get("audio", None), # litellm deletes audio if not present
346
+ tool_calls=tool_calls,
347
+ )
348
+
349
+ @classmethod
350
+ def convert_annotations_to_openai(
351
+ cls, message: litellm.types.utils.Message
352
+ ) -> list[Annotation] | None:
353
+ annotations: list[litellm.types.llms.openai.ChatCompletionAnnotation] | None = message.get(
354
+ "annotations", None
355
+ )
356
+ if not annotations:
357
+ return None
358
+
359
+ return [
360
+ Annotation(
361
+ type="url_citation",
362
+ url_citation=AnnotationURLCitation(
363
+ start_index=annotation["url_citation"]["start_index"],
364
+ end_index=annotation["url_citation"]["end_index"],
365
+ url=annotation["url_citation"]["url"],
366
+ title=annotation["url_citation"]["title"],
367
+ ),
368
+ )
369
+ for annotation in annotations
370
+ ]
371
+
372
+ @classmethod
373
+ def convert_tool_call_to_openai(
374
+ cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
375
+ ) -> ChatCompletionMessageToolCall:
376
+ return ChatCompletionMessageToolCall(
377
+ id=tool_call.id,
378
+ type="function",
379
+ function=Function(
380
+ name=tool_call.function.name or "", arguments=tool_call.function.arguments
381
+ ),
382
+ )
@@ -0,0 +1,21 @@
1
+ from ...models.interface import Model, ModelProvider
2
+ from .litellm_model import LitellmModel
3
+
4
+ DEFAULT_MODEL: str = "gpt-4.1"
5
+
6
+
7
+ class LitellmProvider(ModelProvider):
8
+ """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
9
+ ```python
10
+ Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
11
+ ```
12
+ See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).
13
+
14
+ NOTE: API keys must be set via environment variables. If you're using models that require
15
+ additional configuration (e.g. Azure API base or version), those must also be set via the
16
+ environment variables that LiteLLM expects. If you have more advanced needs, we recommend
17
+ copy-pasting this class and making any modifications you need.
18
+ """
19
+
20
+ def get_model(self, model_name: str | None) -> Model:
21
+ return LitellmModel(model_name or DEFAULT_MODEL)
@@ -132,6 +132,6 @@ def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
132
132
  graph = graphviz.Source(dot_code)
133
133
 
134
134
  if filename:
135
- graph.render(filename, format="png")
135
+ graph.render(filename, format="png", cleanup=True)
136
136
 
137
137
  return graph
agents/mcp/server.py CHANGED
@@ -137,9 +137,10 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
137
137
  async with self._cleanup_lock:
138
138
  try:
139
139
  await self.exit_stack.aclose()
140
- self.session = None
141
140
  except Exception as e:
142
141
  logger.error(f"Error cleaning up server: {e}")
142
+ finally:
143
+ self.session = None
143
144
 
144
145
 
145
146
  class MCPServerStdioParams(TypedDict):