openai-agents 0.2.8__py3-none-any.whl → 0.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. agents/__init__.py +105 -4
  2. agents/_debug.py +15 -4
  3. agents/_run_impl.py +1203 -96
  4. agents/agent.py +164 -19
  5. agents/apply_diff.py +329 -0
  6. agents/editor.py +47 -0
  7. agents/exceptions.py +35 -0
  8. agents/extensions/experimental/__init__.py +6 -0
  9. agents/extensions/experimental/codex/__init__.py +92 -0
  10. agents/extensions/experimental/codex/codex.py +89 -0
  11. agents/extensions/experimental/codex/codex_options.py +35 -0
  12. agents/extensions/experimental/codex/codex_tool.py +1142 -0
  13. agents/extensions/experimental/codex/events.py +162 -0
  14. agents/extensions/experimental/codex/exec.py +263 -0
  15. agents/extensions/experimental/codex/items.py +245 -0
  16. agents/extensions/experimental/codex/output_schema_file.py +50 -0
  17. agents/extensions/experimental/codex/payloads.py +31 -0
  18. agents/extensions/experimental/codex/thread.py +214 -0
  19. agents/extensions/experimental/codex/thread_options.py +54 -0
  20. agents/extensions/experimental/codex/turn_options.py +36 -0
  21. agents/extensions/handoff_filters.py +13 -1
  22. agents/extensions/memory/__init__.py +120 -0
  23. agents/extensions/memory/advanced_sqlite_session.py +1285 -0
  24. agents/extensions/memory/async_sqlite_session.py +239 -0
  25. agents/extensions/memory/dapr_session.py +423 -0
  26. agents/extensions/memory/encrypt_session.py +185 -0
  27. agents/extensions/memory/redis_session.py +261 -0
  28. agents/extensions/memory/sqlalchemy_session.py +334 -0
  29. agents/extensions/models/litellm_model.py +449 -36
  30. agents/extensions/models/litellm_provider.py +3 -1
  31. agents/function_schema.py +47 -5
  32. agents/guardrail.py +16 -2
  33. agents/{handoffs.py → handoffs/__init__.py} +89 -47
  34. agents/handoffs/history.py +268 -0
  35. agents/items.py +237 -11
  36. agents/lifecycle.py +75 -14
  37. agents/mcp/server.py +280 -37
  38. agents/mcp/util.py +24 -3
  39. agents/memory/__init__.py +22 -2
  40. agents/memory/openai_conversations_session.py +91 -0
  41. agents/memory/openai_responses_compaction_session.py +249 -0
  42. agents/memory/session.py +19 -261
  43. agents/memory/sqlite_session.py +275 -0
  44. agents/memory/util.py +20 -0
  45. agents/model_settings.py +14 -3
  46. agents/models/__init__.py +13 -0
  47. agents/models/chatcmpl_converter.py +303 -50
  48. agents/models/chatcmpl_helpers.py +63 -0
  49. agents/models/chatcmpl_stream_handler.py +290 -68
  50. agents/models/default_models.py +58 -0
  51. agents/models/interface.py +4 -0
  52. agents/models/openai_chatcompletions.py +103 -49
  53. agents/models/openai_provider.py +10 -4
  54. agents/models/openai_responses.py +162 -46
  55. agents/realtime/__init__.py +4 -0
  56. agents/realtime/_util.py +14 -3
  57. agents/realtime/agent.py +7 -0
  58. agents/realtime/audio_formats.py +53 -0
  59. agents/realtime/config.py +78 -10
  60. agents/realtime/events.py +18 -0
  61. agents/realtime/handoffs.py +2 -2
  62. agents/realtime/items.py +17 -1
  63. agents/realtime/model.py +13 -0
  64. agents/realtime/model_events.py +12 -0
  65. agents/realtime/model_inputs.py +18 -1
  66. agents/realtime/openai_realtime.py +696 -150
  67. agents/realtime/session.py +243 -23
  68. agents/repl.py +7 -3
  69. agents/result.py +197 -38
  70. agents/run.py +949 -168
  71. agents/run_context.py +13 -2
  72. agents/stream_events.py +1 -0
  73. agents/strict_schema.py +14 -0
  74. agents/tool.py +413 -15
  75. agents/tool_context.py +22 -1
  76. agents/tool_guardrails.py +279 -0
  77. agents/tracing/__init__.py +2 -0
  78. agents/tracing/config.py +9 -0
  79. agents/tracing/create.py +4 -0
  80. agents/tracing/processor_interface.py +84 -11
  81. agents/tracing/processors.py +65 -54
  82. agents/tracing/provider.py +64 -7
  83. agents/tracing/spans.py +105 -0
  84. agents/tracing/traces.py +116 -16
  85. agents/usage.py +134 -12
  86. agents/util/_json.py +19 -1
  87. agents/util/_transforms.py +12 -2
  88. agents/voice/input.py +5 -4
  89. agents/voice/models/openai_stt.py +17 -9
  90. agents/voice/pipeline.py +2 -0
  91. agents/voice/pipeline_config.py +4 -0
  92. {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/METADATA +44 -19
  93. openai_agents-0.6.8.dist-info/RECORD +134 -0
  94. {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/WHEEL +1 -1
  95. openai_agents-0.2.8.dist-info/RECORD +0 -103
  96. {openai_agents-0.2.8.dist-info → openai_agents-0.6.8.dist-info}/licenses/LICENSE +0 -0
agents/lifecycle.py CHANGED
@@ -1,9 +1,10 @@
-from typing import Any, Generic
+from typing import Any, Generic, Optional

 from typing_extensions import TypeVar

 from .agent import Agent, AgentBase
-from .run_context import RunContextWrapper, TContext
+from .items import ModelResponse, TResponseInputItem
+from .run_context import AgentHookContext, RunContextWrapper, TContext
 from .tool import Tool

 TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
@@ -14,17 +15,47 @@ class RunHooksBase(Generic[TContext, TAgent]):
     override the methods you need.
     """

-    async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
-        """Called before the agent is invoked. Called each time the current agent changes."""
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        """Called just before invoking the LLM for this agent."""
         pass

-    async def on_agent_end(
+    async def on_llm_end(
         self,
         context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        """Called immediately after the LLM call returns for this agent."""
+        pass
+
+    async def on_agent_start(self, context: AgentHookContext[TContext], agent: TAgent) -> None:
+        """Called before the agent is invoked. Called each time the current agent changes.
+
+        Args:
+            context: The agent hook context.
+            agent: The agent that is about to be invoked.
+        """
+        pass
+
+    async def on_agent_end(
+        self,
+        context: AgentHookContext[TContext],
         agent: TAgent,
         output: Any,
     ) -> None:
-        """Called when the agent produces a final output."""
+        """Called when the agent produces a final output.
+
+        Args:
+            context: The agent hook context.
+            agent: The agent that produced the output.
+            output: The final output produced by the agent.
+        """
         pass

     async def on_handoff(
@@ -42,7 +73,7 @@ class RunHooksBase(Generic[TContext, TAgent]):
         agent: TAgent,
         tool: Tool,
     ) -> None:
-        """Called concurrently with tool invocation."""
+        """Called immediately before a local tool is invoked."""
         pass

     async def on_tool_end(
@@ -52,7 +83,7 @@ class RunHooksBase(Generic[TContext, TAgent]):
         tool: Tool,
         result: str,
     ) -> None:
-        """Called after a tool is invoked."""
+        """Called immediately after a local tool is invoked."""
         pass


@@ -63,18 +94,29 @@ class AgentHooksBase(Generic[TContext, TAgent]):
     Subclass and override the methods you need.
     """

-    async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
+    async def on_start(self, context: AgentHookContext[TContext], agent: TAgent) -> None:
         """Called before the agent is invoked. Called each time the running agent is changed to this
-        agent."""
+        agent.
+
+        Args:
+            context: The agent hook context.
+            agent: This agent instance.
+        """
         pass

     async def on_end(
         self,
-        context: RunContextWrapper[TContext],
+        context: AgentHookContext[TContext],
         agent: TAgent,
         output: Any,
     ) -> None:
-        """Called when the agent produces a final output."""
+        """Called when the agent produces a final output.
+
+        Args:
+            context: The agent hook context.
+            agent: This agent instance.
+            output: The final output produced by the agent.
+        """
         pass

     async def on_handoff(
@@ -93,7 +135,7 @@ class AgentHooksBase(Generic[TContext, TAgent]):
         agent: TAgent,
         tool: Tool,
     ) -> None:
-        """Called concurrently with tool invocation."""
+        """Called immediately before a local tool is invoked."""
         pass

     async def on_tool_end(
@@ -103,7 +145,26 @@ class AgentHooksBase(Generic[TContext, TAgent]):
         tool: Tool,
         result: str,
     ) -> None:
-        """Called after a tool is invoked."""
+        """Called immediately after a local tool is invoked."""
+        pass
+
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        """Called immediately before the agent issues an LLM call."""
+        pass
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        """Called immediately after the agent receives the LLM response."""
         pass

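Note on the hunks above: on_llm_start and on_llm_end are new callbacks that bracket each model call, and on_agent_start/on_agent_end (like AgentHooksBase.on_start/on_end) now receive an AgentHookContext instead of a plain RunContextWrapper. A minimal sketch of a hooks subclass using the new signatures follows; the module paths mirror the relative imports in the diff, while the class name and print bodies are illustrative only:

from typing import Any, Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import RunHooksBase
from agents.run_context import RunContextWrapper


class LoggingRunHooks(RunHooksBase):
    """Logs one line on each side of every model call made during a run."""

    async def on_llm_start(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Fires just before the LLM is invoked for this agent (per the docstring above).
        print(f"[{agent.name}] calling the model with {len(input_items)} input item(s)")

    async def on_llm_end(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        response: ModelResponse,
    ) -> None:
        # Fires immediately after the LLM call returns for this agent.
        print(f"[{agent.name}] model call finished")

Hooks that are not overridden fall back to the base-class no-ops, so only the callbacks of interest need implementing.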
agents/mcp/server.py CHANGED
@@ -3,13 +3,20 @@ from __future__ import annotations
 import abc
 import asyncio
 import inspect
+import sys
+from collections.abc import Awaitable
 from contextlib import AbstractAsyncContextManager, AsyncExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

+import httpx
+
+if sys.version_info < (3, 11):
+    from exceptiongroup import BaseExceptionGroup  # pyright: ignore[reportMissingImports]
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
+from mcp.client.session import MessageHandlerFnT
 from mcp.client.sse import sse_client
 from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
 from mcp.shared.message import SessionMessage
@@ -19,7 +26,9 @@ from typing_extensions import NotRequired, TypedDict
 from ..exceptions import UserError
 from ..logger import logger
 from ..run_context import RunContextWrapper
-from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
+from .util import HttpClientFactory, ToolFilter, ToolFilterContext, ToolFilterStatic
+
+T = TypeVar("T")

 if TYPE_CHECKING:
     from ..agent import AgentBase
@@ -98,6 +107,9 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         client_session_timeout_seconds: float | None,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
+        message_handler: MessageHandlerFnT | None = None,
     ):
         """
         Args:
@@ -115,6 +127,12 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, used for exponential
+                backoff between retries.
+            message_handler: Optional handler invoked for session messages as delivered by the
+                ClientSession.
         """
         super().__init__(use_structured_content=use_structured_content)
         self.session: ClientSession | None = None
@@ -124,6 +142,9 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         self.server_initialize_result: InitializeResult | None = None

         self.client_session_timeout_seconds = client_session_timeout_seconds
+        self.max_retry_attempts = max_retry_attempts
+        self.retry_backoff_seconds_base = retry_backoff_seconds_base
+        self.message_handler = message_handler

         # The cache is always dirty at startup, so that we fetch tools at least once
         self._cache_dirty = True
@@ -134,8 +155,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
     async def _apply_tool_filter(
         self,
         tools: list[MCPTool],
-        run_context: RunContextWrapper[Any],
-        agent: AgentBase,
+        run_context: RunContextWrapper[Any] | None = None,
+        agent: AgentBase | None = None,
     ) -> list[MCPTool]:
         """Apply the tool filter to the list of tools."""
         if self.tool_filter is None:
@@ -147,6 +168,8 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):

         # Handle callable tool filter (dynamic filter)
         else:
+            if run_context is None or agent is None:
+                raise UserError("run_context and agent are required for dynamic tool filtering")
             return await self._apply_dynamic_tool_filter(tools, run_context, agent)

     def _apply_static_tool_filter(
@@ -175,10 +198,10 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
     ) -> list[MCPTool]:
         """Apply dynamic tool filtering using a callable filter function."""

-        # Ensure we have a callable filter and cast to help mypy
+        # Ensure we have a callable filter
         if not callable(self.tool_filter):
             raise ValueError("Tool filter must be callable for dynamic filtering")
-        tool_filter_func = cast(ToolFilterCallable, self.tool_filter)
+        tool_filter_func = self.tool_filter

         # Create filter context
         filter_context = ToolFilterContext(
@@ -233,8 +256,50 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         """Invalidate the tools cache."""
         self._cache_dirty = True

+    def _extract_http_error_from_exception(self, e: Exception) -> Exception | None:
+        """Extract HTTP error from exception or ExceptionGroup."""
+        if isinstance(e, (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)):
+            return e
+
+        # Check if it's an ExceptionGroup containing HTTP errors
+        if isinstance(e, BaseExceptionGroup):
+            for exc in e.exceptions:
+                if isinstance(
+                    exc, (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)
+                ):
+                    return exc
+
+        return None
+
+    def _raise_user_error_for_http_error(self, http_error: Exception) -> None:
+        """Raise appropriate UserError for HTTP error."""
+        error_message = f"Failed to connect to MCP server '{self.name}': "
+        if isinstance(http_error, httpx.HTTPStatusError):
+            error_message += f"HTTP error {http_error.response.status_code} ({http_error.response.reason_phrase})"  # noqa: E501
+
+        elif isinstance(http_error, httpx.ConnectError):
+            error_message += "Could not reach the server."
+
+        elif isinstance(http_error, httpx.TimeoutException):
+            error_message += "Connection timeout."
+
+        raise UserError(error_message) from http_error
+
+    async def _run_with_retries(self, func: Callable[[], Awaitable[T]]) -> T:
+        attempts = 0
+        while True:
+            try:
+                return await func()
+            except Exception:
+                attempts += 1
+                if self.max_retry_attempts != -1 and attempts > self.max_retry_attempts:
+                    raise
+                backoff = self.retry_backoff_seconds_base * (2 ** (attempts - 1))
+                await asyncio.sleep(backoff)
+
     async def connect(self):
         """Connect to the server."""
+        connection_succeeded = False
         try:
             transport = await self.exit_stack.enter_async_context(self.create_streams())
             # streamablehttp_client returns (read, write, get_session_id)
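Note on _run_with_retries above: a failed call is retried until max_retry_attempts is exhausted (a value of -1 retries indefinitely), sleeping retry_backoff_seconds_base * 2 ** (attempts - 1) seconds between attempts. A small illustration of the resulting delays, using hypothetical values for the two settings:

# Mirrors the backoff formula in _run_with_retries (values are illustrative).
retry_backoff_seconds_base = 0.5
max_retry_attempts = 3

delays = [
    retry_backoff_seconds_base * (2 ** (attempt - 1))
    for attempt in range(1, max_retry_attempts + 1)
]
print(delays)  # [0.5, 1.0, 2.0] -> sleeps after the 1st, 2nd and 3rd failures; a 4th failure re-raises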
@@ -249,15 +314,55 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
                     timedelta(seconds=self.client_session_timeout_seconds)
                     if self.client_session_timeout_seconds
                     else None,
+                    message_handler=self.message_handler,
                 )
             )
             server_result = await session.initialize()
             self.server_initialize_result = server_result
             self.session = session
+            connection_succeeded = True
         except Exception as e:
-            logger.error(f"Error initializing MCP server: {e}")
-            await self.cleanup()
+            # Try to extract HTTP error from exception or ExceptionGroup
+            http_error = self._extract_http_error_from_exception(e)
+            if http_error:
+                self._raise_user_error_for_http_error(http_error)
+
+            # For CancelledError, preserve cancellation semantics - don't wrap it.
+            # If it's masking an HTTP error, cleanup() will extract and raise UserError.
+            if isinstance(e, asyncio.CancelledError):
+                raise
+
+            # For HTTP-related errors, wrap them
+            if isinstance(e, (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)):
+                self._raise_user_error_for_http_error(e)
+
+            # For other errors, re-raise as-is (don't wrap non-HTTP errors)
             raise
+        finally:
+            # Always attempt cleanup on error, but suppress cleanup errors that mask the original
+            if not connection_succeeded:
+                try:
+                    await self.cleanup()
+                except UserError:
+                    # Re-raise UserError from cleanup (contains the real HTTP error)
+                    raise
+                except Exception as cleanup_error:
+                    # Suppress RuntimeError about cancel scopes during cleanup - this is a known
+                    # issue with the MCP library's async generator cleanup and shouldn't mask the
+                    # original error
+                    if isinstance(cleanup_error, RuntimeError) and "cancel scope" in str(
+                        cleanup_error
+                    ):
+                        logger.debug(
+                            f"Ignoring cancel scope error during cleanup of MCP server "
+                            f"'{self.name}': {cleanup_error}"
+                        )
+                    else:
+                        # Log other cleanup errors but don't raise - original error is more
+                        # important
+                        logger.warning(
+                            f"Error during cleanup of MCP server '{self.name}': {cleanup_error}"
+                        )

     async def list_tools(
         self,
@@ -267,31 +372,56 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
         """List the tools available on the server."""
         if not self.session:
             raise UserError("Server not initialized. Make sure you call `connect()` first.")
+        session = self.session
+        assert session is not None

-        # Return from cache if caching is enabled, we have tools, and the cache is not dirty
-        if self.cache_tools_list and not self._cache_dirty and self._tools_list:
-            tools = self._tools_list
-        else:
-            # Reset the cache dirty to False
-            self._cache_dirty = False
-            # Fetch the tools from the server
-            self._tools_list = (await self.session.list_tools()).tools
-            tools = self._tools_list
-
-        # Filter tools based on tool_filter
-        filtered_tools = tools
-        if self.tool_filter is not None:
-            if run_context is None or agent is None:
-                raise UserError("run_context and agent are required for dynamic tool filtering")
-            filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
-        return filtered_tools
+        try:
+            # Return from cache if caching is enabled, we have tools, and the cache is not dirty
+            if self.cache_tools_list and not self._cache_dirty and self._tools_list:
+                tools = self._tools_list
+            else:
+                # Fetch the tools from the server
+                result = await self._run_with_retries(lambda: session.list_tools())
+                self._tools_list = result.tools
+                self._cache_dirty = False
+                tools = self._tools_list
+
+            # Filter tools based on tool_filter
+            filtered_tools = tools
+            if self.tool_filter is not None:
+                filtered_tools = await self._apply_tool_filter(filtered_tools, run_context, agent)
+            return filtered_tools
+        except httpx.HTTPStatusError as e:
+            status_code = e.response.status_code
+            raise UserError(
+                f"Failed to list tools from MCP server '{self.name}': HTTP error {status_code}"
+            ) from e
+        except httpx.ConnectError as e:
+            raise UserError(
+                f"Failed to list tools from MCP server '{self.name}': Connection lost. "
+                f"The server may have disconnected."
+            ) from e

     async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
         """Invoke a tool on the server."""
         if not self.session:
             raise UserError("Server not initialized. Make sure you call `connect()` first.")
+        session = self.session
+        assert session is not None

-        return await self.session.call_tool(tool_name, arguments)
+        try:
+            return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
+        except httpx.HTTPStatusError as e:
+            status_code = e.response.status_code
+            raise UserError(
+                f"Failed to call tool '{tool_name}' on MCP server '{self.name}': "
+                f"HTTP error {status_code}"
+            ) from e
+        except httpx.ConnectError as e:
+            raise UserError(
+                f"Failed to call tool '{tool_name}' on MCP server '{self.name}': Connection lost. "
+                f"The server may have disconnected."
+            ) from e

     async def list_prompts(
         self,
@@ -314,10 +444,73 @@ class _MCPServerWithClientSession(MCPServer, abc.ABC):
     async def cleanup(self):
         """Cleanup the server."""
         async with self._cleanup_lock:
+            # Only raise HTTP errors if we're cleaning up after a failed connection.
+            # During normal teardown (via __aexit__), log but don't raise to avoid
+            # masking the original exception.
+            is_failed_connection_cleanup = self.session is None
+
             try:
                 await self.exit_stack.aclose()
+            except BaseExceptionGroup as eg:
+                # Extract HTTP errors from ExceptionGroup raised during cleanup
+                # This happens when background tasks fail (e.g., HTTP errors)
+                http_error = None
+                connect_error = None
+                timeout_error = None
+                error_message = f"Failed to connect to MCP server '{self.name}': "
+
+                for exc in eg.exceptions:
+                    if isinstance(exc, httpx.HTTPStatusError):
+                        http_error = exc
+                    elif isinstance(exc, httpx.ConnectError):
+                        connect_error = exc
+                    elif isinstance(exc, httpx.TimeoutException):
+                        timeout_error = exc
+
+                # Only raise HTTP errors if we're cleaning up after a failed connection.
+                # During normal teardown, log them instead.
+                if http_error:
+                    if is_failed_connection_cleanup:
+                        error_message += f"HTTP error {http_error.response.status_code} ({http_error.response.reason_phrase})"  # noqa: E501
+                        raise UserError(error_message) from http_error
+                    else:
+                        # Normal teardown - log but don't raise
+                        logger.warning(
+                            f"HTTP error during cleanup of MCP server '{self.name}': {http_error}"
+                        )
+                elif connect_error:
+                    if is_failed_connection_cleanup:
+                        error_message += "Could not reach the server."
+                        raise UserError(error_message) from connect_error
+                    else:
+                        logger.warning(
+                            f"Connection error during cleanup of MCP server '{self.name}': {connect_error}"  # noqa: E501
+                        )
+                elif timeout_error:
+                    if is_failed_connection_cleanup:
+                        error_message += "Connection timeout."
+                        raise UserError(error_message) from timeout_error
+                    else:
+                        logger.warning(
+                            f"Timeout error during cleanup of MCP server '{self.name}': {timeout_error}"  # noqa: E501
+                        )
+                else:
+                    # No HTTP error found, suppress RuntimeError about cancel scopes
+                    has_cancel_scope_error = any(
+                        isinstance(exc, RuntimeError) and "cancel scope" in str(exc)
+                        for exc in eg.exceptions
+                    )
+                    if has_cancel_scope_error:
+                        logger.debug(f"Ignoring cancel scope error during cleanup: {eg}")
+                    else:
+                        logger.error(f"Error cleaning up server: {eg}")
             except Exception as e:
-                logger.error(f"Error cleaning up server: {e}")
+                # Suppress RuntimeError about cancel scopes - this is a known issue with the MCP
+                # library when background tasks fail during async generator cleanup
+                if isinstance(e, RuntimeError) and "cancel scope" in str(e):
+                    logger.debug(f"Ignoring cancel scope error during cleanup: {e}")
+                else:
+                    logger.error(f"Error cleaning up server: {e}")
             finally:
                 self.session = None

@@ -365,6 +558,9 @@ class MCPServerStdio(_MCPServerWithClientSession):
         client_session_timeout_seconds: float | None = 5,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
+        message_handler: MessageHandlerFnT | None = None,
     ):
         """Create a new MCP server based on the stdio transport.

@@ -388,12 +584,21 @@ class MCPServerStdio(_MCPServerWithClientSession):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, for exponential
+                backoff between retries.
+            message_handler: Optional handler invoked for session messages as delivered by the
+                ClientSession.
         """
         super().__init__(
             cache_tools_list,
             client_session_timeout_seconds,
             tool_filter,
             use_structured_content,
+            max_retry_attempts,
+            retry_backoff_seconds_base,
+            message_handler=message_handler,
         )

         self.params = StdioServerParameters(
@@ -455,6 +660,9 @@ class MCPServerSse(_MCPServerWithClientSession):
         client_session_timeout_seconds: float | None = 5,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
+        message_handler: MessageHandlerFnT | None = None,
     ):
         """Create a new MCP server based on the HTTP with SSE transport.

@@ -480,12 +688,21 @@ class MCPServerSse(_MCPServerWithClientSession):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, for exponential
+                backoff between retries.
+            message_handler: Optional handler invoked for session messages as delivered by the
+                ClientSession.
         """
         super().__init__(
             cache_tools_list,
             client_session_timeout_seconds,
             tool_filter,
             use_structured_content,
+            max_retry_attempts,
+            retry_backoff_seconds_base,
+            message_handler=message_handler,
        )

         self.params = params
@@ -532,6 +749,9 @@ class MCPServerStreamableHttpParams(TypedDict):
     terminate_on_close: NotRequired[bool]
     """Terminate on close"""

+    httpx_client_factory: NotRequired[HttpClientFactory]
+    """Custom HTTP client factory for configuring httpx.AsyncClient behavior."""
+

 class MCPServerStreamableHttp(_MCPServerWithClientSession):
     """MCP server implementation that uses the Streamable HTTP transport. See the [spec]
@@ -547,14 +767,17 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
         client_session_timeout_seconds: float | None = 5,
         tool_filter: ToolFilter = None,
         use_structured_content: bool = False,
+        max_retry_attempts: int = 0,
+        retry_backoff_seconds_base: float = 1.0,
+        message_handler: MessageHandlerFnT | None = None,
     ):
         """Create a new MCP server based on the Streamable HTTP transport.

         Args:
             params: The params that configure the server. This includes the URL of the server,
-                the headers to send to the server, the timeout for the HTTP request, and the
-                timeout for the Streamable HTTP connection and whether we need to
-                terminate on close.
+                the headers to send to the server, the timeout for the HTTP request, the
+                timeout for the Streamable HTTP connection, whether we need to
+                terminate on close, and an optional custom HTTP client factory.

             cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
                 cached and only fetched from the server once. If `False`, the tools list will be
@@ -573,12 +796,21 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
                 include the structured content in the `tool_result.content`, and using it by
                 default will cause duplicate content. You can set this to True if you know the
                 server will not duplicate the structured content in the `tool_result.content`.
+            max_retry_attempts: Number of times to retry failed list_tools/call_tool calls.
+                Defaults to no retries.
+            retry_backoff_seconds_base: The base delay, in seconds, for exponential
+                backoff between retries.
+            message_handler: Optional handler invoked for session messages as delivered by the
+                ClientSession.
         """
         super().__init__(
             cache_tools_list,
             client_session_timeout_seconds,
             tool_filter,
             use_structured_content,
+            max_retry_attempts,
+            retry_backoff_seconds_base,
+            message_handler=message_handler,
         )

         self.params = params
@@ -594,13 +826,24 @@ class MCPServerStreamableHttp(_MCPServerWithClientSession):
         ]
     ]:
         """Create the streams for the server."""
-        return streamablehttp_client(
-            url=self.params["url"],
-            headers=self.params.get("headers", None),
-            timeout=self.params.get("timeout", 5),
-            sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
-            terminate_on_close=self.params.get("terminate_on_close", True),
-        )
+        # Only pass httpx_client_factory if it's provided
+        if "httpx_client_factory" in self.params:
+            return streamablehttp_client(
+                url=self.params["url"],
+                headers=self.params.get("headers", None),
+                timeout=self.params.get("timeout", 5),
+                sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
+                terminate_on_close=self.params.get("terminate_on_close", True),
+                httpx_client_factory=self.params["httpx_client_factory"],
+            )
+        else:
+            return streamablehttp_client(
+                url=self.params["url"],
+                headers=self.params.get("headers", None),
+                timeout=self.params.get("timeout", 5),
+                sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
+                terminate_on_close=self.params.get("terminate_on_close", True),
+            )

     @property
     def name(self) -> str:
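Putting the new MCP server options together, a rough usage sketch based on the signatures in the hunks above; the URL, client factory, and retry values are placeholders, and connect()/cleanup() are used as before:

import httpx

from agents.mcp.server import MCPServerStreamableHttp, MCPServerStreamableHttpParams

# httpx_client_factory is the new optional entry in MCPServerStreamableHttpParams.
params: MCPServerStreamableHttpParams = {
    "url": "http://localhost:8000/mcp",
    "httpx_client_factory": lambda *args, **kwargs: httpx.AsyncClient(*args, **kwargs),
}

server = MCPServerStreamableHttp(
    params=params,
    cache_tools_list=True,
    max_retry_attempts=3,  # retry failed list_tools/call_tool calls up to 3 times
    retry_backoff_seconds_base=1.0,  # sleep 1s, 2s, 4s between those retries
)
# Then: await server.connect(), use the server, and await server.cleanup() when done.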