chuk-tool-processor 0.7.0 → 0.10 (chuk_tool_processor-0.7.0-py3-none-any.whl → chuk_tool_processor-0.10-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor has been flagged as potentially problematic; see the registry's advisory page for more details.

Files changed (39):
  1. chuk_tool_processor/__init__.py +114 -0
  2. chuk_tool_processor/core/__init__.py +31 -0
  3. chuk_tool_processor/core/exceptions.py +218 -12
  4. chuk_tool_processor/core/processor.py +391 -43
  5. chuk_tool_processor/execution/wrappers/__init__.py +42 -0
  6. chuk_tool_processor/execution/wrappers/caching.py +43 -10
  7. chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
  8. chuk_tool_processor/execution/wrappers/rate_limiting.py +31 -1
  9. chuk_tool_processor/execution/wrappers/retry.py +93 -53
  10. chuk_tool_processor/logging/__init__.py +5 -8
  11. chuk_tool_processor/logging/context.py +2 -5
  12. chuk_tool_processor/mcp/__init__.py +3 -0
  13. chuk_tool_processor/mcp/mcp_tool.py +8 -3
  14. chuk_tool_processor/mcp/models.py +87 -0
  15. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +38 -2
  16. chuk_tool_processor/mcp/setup_mcp_sse.py +38 -2
  17. chuk_tool_processor/mcp/setup_mcp_stdio.py +92 -12
  18. chuk_tool_processor/mcp/stream_manager.py +109 -6
  19. chuk_tool_processor/mcp/transport/http_streamable_transport.py +18 -5
  20. chuk_tool_processor/mcp/transport/sse_transport.py +16 -3
  21. chuk_tool_processor/models/__init__.py +20 -0
  22. chuk_tool_processor/models/tool_call.py +34 -1
  23. chuk_tool_processor/models/tool_export_mixin.py +4 -4
  24. chuk_tool_processor/models/tool_spec.py +350 -0
  25. chuk_tool_processor/models/validated_tool.py +22 -2
  26. chuk_tool_processor/observability/__init__.py +30 -0
  27. chuk_tool_processor/observability/metrics.py +312 -0
  28. chuk_tool_processor/observability/setup.py +105 -0
  29. chuk_tool_processor/observability/tracing.py +346 -0
  30. chuk_tool_processor/py.typed +0 -0
  31. chuk_tool_processor/registry/interface.py +7 -7
  32. chuk_tool_processor/registry/providers/__init__.py +2 -1
  33. chuk_tool_processor/registry/tool_export.py +1 -6
  34. chuk_tool_processor-0.10.dist-info/METADATA +2326 -0
  35. chuk_tool_processor-0.10.dist-info/RECORD +69 -0
  36. chuk_tool_processor-0.7.0.dist-info/METADATA +0 -1230
  37. chuk_tool_processor-0.7.0.dist-info/RECORD +0 -61
  38. {chuk_tool_processor-0.7.0.dist-info → chuk_tool_processor-0.10.dist-info}/WHEEL +0 -0
  39. {chuk_tool_processor-0.7.0.dist-info → chuk_tool_processor-0.10.dist-info}/top_level.txt +0 -0
@@ -19,14 +19,11 @@ import sys
19
19
 
20
20
 
21
21
  # Auto-initialize shutdown error suppression when logging package is imported
22
- def _initialize_shutdown_fixes():
22
+ def _initialize_shutdown_fixes() -> None:
23
23
  """Initialize shutdown error suppression when the package is imported."""
24
- try:
25
- from .context import _setup_shutdown_error_suppression
26
-
27
- _setup_shutdown_error_suppression()
28
- except ImportError:
29
- pass
24
+ # Note: _setup_shutdown_error_suppression removed as it's no longer needed
25
+ # Keeping this function as a no-op for backward compatibility
26
+ pass
30
27
 
31
28
 
32
29
  # Initialize when package is imported
@@ -64,7 +61,7 @@ __all__ = [
64
61
  async def setup_logging(
65
62
  level: int = logging.INFO,
66
63
  structured: bool = True,
67
- log_file: str = None,
64
+ log_file: str | None = None,
68
65
  ) -> None:
69
66
  """
70
67
  Set up the logging system.
@@ -77,11 +77,8 @@ class LibraryLoggingManager:
77
77
  self._initialized = False
78
78
  self._lock = threading.Lock()
79
79
 
80
- def initialize(self):
80
+ def initialize(self) -> None:
81
81
  """Initialize clean shutdown behavior for the library."""
82
- if self._initialized:
83
- return
84
-
85
82
  with self._lock:
86
83
  if self._initialized:
87
84
  return
@@ -299,7 +296,7 @@ class StructuredAdapter(logging.LoggerAdapter):
299
296
  return msg, kwargs
300
297
 
301
298
  # ----------------------- convenience wrappers ------------------------ #
302
- def _forward(self, method_name: str, msg, *args, **kwargs):
299
+ def _forward(self, method_name: str, msg: str, *args: Any, **kwargs: Any) -> None:
303
300
  """Common helper: process + forward to `self.logger.<method_name>`."""
304
301
  msg, kwargs = self.process(msg, kwargs)
305
302
  getattr(self.logger, method_name)(msg, *args, **kwargs)
@@ -9,6 +9,7 @@ Updated to support the latest MCP transports:
9
9
  """
10
10
 
11
11
  from chuk_tool_processor.mcp.mcp_tool import MCPTool
12
+ from chuk_tool_processor.mcp.models import MCPServerConfig, MCPTransport
12
13
  from chuk_tool_processor.mcp.register_mcp_tools import register_mcp_tools
13
14
  from chuk_tool_processor.mcp.setup_mcp_http_streamable import setup_mcp_http_streamable
14
15
  from chuk_tool_processor.mcp.setup_mcp_sse import setup_mcp_sse
@@ -23,6 +24,8 @@ __all__ = [
23
24
  "HTTPStreamableTransport",
24
25
  "StreamManager",
25
26
  "MCPTool",
27
+ "MCPServerConfig",
28
+ "MCPTransport",
26
29
  "register_mcp_tools",
27
30
  "setup_mcp_stdio",
28
31
  "setup_mcp_sse",
@@ -237,7 +237,12 @@ class MCPTool:
237
237
  await self._record_failure()
238
238
 
239
239
  if attempt == max_attempts - 1:
240
- return {"error": error_msg, "tool_name": self.tool_name, "available": False, "reason": "timeout"}
240
+ return {
241
+ "error": error_msg,
242
+ "tool_name": self.tool_name,
243
+ "available": False,
244
+ "reason": "timeout",
245
+ }
241
246
 
242
247
  except Exception as e:
243
248
  error_str = str(e)
@@ -260,12 +265,12 @@ class MCPTool:
260
265
  await asyncio.sleep(backoff)
261
266
  backoff = min(backoff * self.recovery_config.backoff_multiplier, self.recovery_config.max_backoff)
262
267
 
263
- # Should never reach here
268
+ # Should never reach here, but return error if we do
264
269
  return {
265
270
  "error": f"Tool '{self.tool_name}' failed after all attempts",
266
271
  "tool_name": self.tool_name,
267
272
  "available": False,
268
- "reason": "exhausted_retries",
273
+ "reason": "execution_failed",
269
274
  }
270
275
 
271
276
  async def _execute_with_timeout(self, timeout: float, **kwargs: Any) -> Any:
@@ -0,0 +1,87 @@
1
+ #!/usr/bin/env python
2
+ # chuk_tool_processor/mcp/models.py
3
+ """
4
+ Pydantic models for MCP server configurations.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from enum import Enum
10
+ from typing import Any
11
+
12
+ from pydantic import BaseModel, Field, model_validator
13
+
14
+
15
+ class MCPTransport(str, Enum):
16
+ """Supported MCP transport types."""
17
+
18
+ STDIO = "stdio"
19
+ SSE = "sse"
20
+ HTTP = "http"
21
+
22
+
23
+ class MCPServerConfig(BaseModel):
24
+ """Unified configuration for MCP servers (all transport types)."""
25
+
26
+ name: str = Field(description="Server identifier name")
27
+ transport: MCPTransport = Field(default=MCPTransport.STDIO, description="Transport protocol")
28
+
29
+ # STDIO fields
30
+ command: str | None = Field(default=None, description="Command to execute (stdio only)")
31
+ args: list[str] = Field(default_factory=list, description="Command arguments (stdio only)")
32
+ env: dict[str, str] | None = Field(default=None, description="Environment variables (stdio only)")
33
+
34
+ # SSE/HTTP fields
35
+ url: str | None = Field(default=None, description="Server URL (sse/http)")
36
+ headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers (sse/http)")
37
+ timeout: float = Field(default=10.0, description="Connection timeout in seconds")
38
+ sse_read_timeout: float = Field(default=300.0, description="SSE read timeout in seconds (sse only)")
39
+ api_key: str | None = Field(default=None, description="API key extracted from Authorization header")
40
+ session_id: str | None = Field(default=None, description="Session ID for HTTP transport")
41
+
42
+ @model_validator(mode="after")
43
+ def validate_transport_fields(self) -> MCPServerConfig:
44
+ """Validate required fields based on transport type."""
45
+ if self.transport == MCPTransport.STDIO:
46
+ if not self.command:
47
+ raise ValueError("command is required for stdio transport")
48
+ else:
49
+ # SSE/HTTP
50
+ if not self.url:
51
+ raise ValueError(f"url is required for {self.transport} transport")
52
+ # Extract API key from Authorization header if present
53
+ if not self.api_key and self.headers:
54
+ auth_header = self.headers.get("Authorization", "")
55
+ if "Bearer " in auth_header:
56
+ self.api_key = auth_header.split("Bearer ")[-1]
57
+ return self
58
+
59
+ def to_dict(self) -> dict[str, Any]:
60
+ """Convert to dictionary for internal use."""
61
+ if self.transport == MCPTransport.STDIO:
62
+ result = {
63
+ "name": self.name,
64
+ "command": self.command,
65
+ "args": self.args,
66
+ }
67
+ if self.env:
68
+ result["env"] = self.env
69
+ return result
70
+ else:
71
+ # SSE/HTTP
72
+ result = {
73
+ "name": self.name,
74
+ "url": self.url,
75
+ "headers": self.headers,
76
+ "timeout": self.timeout,
77
+ }
78
+ if self.transport == MCPTransport.SSE:
79
+ result["sse_read_timeout"] = self.sse_read_timeout
80
+ if self.api_key:
81
+ result["api_key"] = self.api_key
82
+ if self.session_id:
83
+ result["session_id"] = self.session_id
84
+ return result
85
+
86
+
87
+ __all__ = ["MCPServerConfig", "MCPTransport"]
@@ -41,8 +41,8 @@ async def setup_mcp_http_streamable(
41
41
  enable_rate_limiting: bool = False,
42
42
  global_rate_limit: int | None = None,
43
43
  tool_rate_limits: dict[str, tuple] | None = None,
44
- enable_retries: bool = False, # CHANGED: Disabled to allow OAuth refresh to work properly
45
- max_retries: int = 0, # CHANGED: 0 retries for HTTP (OAuth refresh happens at transport level)
44
+ enable_retries: bool = True, # CHANGED: Enabled with OAuth errors excluded
45
+ max_retries: int = 2, # Retry non-OAuth errors (OAuth handled at transport level)
46
46
  namespace: str = "http",
47
47
  oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
48
48
  ) -> tuple[ToolProcessor, StreamManager]:
@@ -102,6 +102,41 @@ async def setup_mcp_http_streamable(
102
102
  registered = await register_mcp_tools(stream_manager, namespace=namespace)
103
103
 
104
104
  # 3️⃣ build a processor instance configured to your taste
105
+ # IMPORTANT: Retries are enabled but OAuth errors are excluded
106
+ # OAuth refresh happens at transport level with automatic retry
107
+
108
+ # Import RetryConfig to configure OAuth error exclusion
109
+ from chuk_tool_processor.execution.wrappers.retry import RetryConfig
110
+
111
+ # Define OAuth error patterns that should NOT be retried at this level
112
+ # These will be handled by the transport layer's OAuth refresh mechanism
113
+ # Based on RFC 6750 (Bearer Token Usage) and MCP OAuth spec
114
+ oauth_error_patterns = [
115
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
116
+ "invalid_token", # Token expired, revoked, malformed, or invalid
117
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
118
+ # OAuth 2.1 token refresh errors
119
+ "invalid_grant", # Refresh token errors
120
+ # MCP spec - OAuth validation failures (401 Unauthorized)
121
+ "oauth validation",
122
+ "unauthorized",
123
+ # Common OAuth error descriptions
124
+ "expired token",
125
+ "token expired",
126
+ "authentication failed",
127
+ "invalid access token",
128
+ ]
129
+
130
+ # Create retry config that skips OAuth errors
131
+ retry_config = (
132
+ RetryConfig(
133
+ max_retries=max_retries,
134
+ skip_retry_on_error_substrings=oauth_error_patterns,
135
+ )
136
+ if enable_retries
137
+ else None
138
+ )
139
+
105
140
  processor = ToolProcessor(
106
141
  default_timeout=default_timeout,
107
142
  max_concurrency=max_concurrency,
@@ -112,6 +147,7 @@ async def setup_mcp_http_streamable(
112
147
  tool_rate_limits=tool_rate_limits,
113
148
  enable_retries=enable_retries,
114
149
  max_retries=max_retries,
150
+ retry_config=retry_config, # NEW: Pass OAuth-aware retry config
115
151
  )
116
152
 
117
153
  logger.debug(
@@ -37,8 +37,8 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
37
37
  enable_rate_limiting: bool = False,
38
38
  global_rate_limit: int | None = None,
39
39
  tool_rate_limits: dict[str, tuple] | None = None,
40
- enable_retries: bool = False, # CHANGED: Disabled to allow OAuth refresh to work properly
41
- max_retries: int = 0, # CHANGED: 0 retries for SSE (OAuth refresh happens at transport level)
40
+ enable_retries: bool = True, # CHANGED: Enabled with OAuth errors excluded
41
+ max_retries: int = 2, # Retry non-OAuth errors (OAuth handled at transport level)
42
42
  namespace: str = "sse",
43
43
  oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
44
44
  ) -> tuple[ToolProcessor, StreamManager]:
@@ -81,6 +81,41 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
81
81
  registered = await register_mcp_tools(stream_manager, namespace=namespace)
82
82
 
83
83
  # 3️⃣ build a processor instance configured to your taste
84
+ # IMPORTANT: Retries are enabled but OAuth errors are excluded
85
+ # OAuth refresh happens at transport level with automatic retry
86
+
87
+ # Import RetryConfig to configure OAuth error exclusion
88
+ from chuk_tool_processor.execution.wrappers.retry import RetryConfig
89
+
90
+ # Define OAuth error patterns that should NOT be retried at this level
91
+ # These will be handled by the transport layer's OAuth refresh mechanism
92
+ # Based on RFC 6750 (Bearer Token Usage) and MCP OAuth spec
93
+ oauth_error_patterns = [
94
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
95
+ "invalid_token", # Token expired, revoked, malformed, or invalid
96
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
97
+ # OAuth 2.1 token refresh errors
98
+ "invalid_grant", # Refresh token errors
99
+ # MCP spec - OAuth validation failures (401 Unauthorized)
100
+ "oauth validation",
101
+ "unauthorized",
102
+ # Common OAuth error descriptions
103
+ "expired token",
104
+ "token expired",
105
+ "authentication failed",
106
+ "invalid access token",
107
+ ]
108
+
109
+ # Create retry config that skips OAuth errors
110
+ retry_config = (
111
+ RetryConfig(
112
+ max_retries=max_retries,
113
+ skip_retry_on_error_substrings=oauth_error_patterns,
114
+ )
115
+ if enable_retries
116
+ else None
117
+ )
118
+
84
119
  processor = ToolProcessor(
85
120
  default_timeout=default_timeout,
86
121
  max_concurrency=max_concurrency,
@@ -91,6 +126,7 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
91
126
  tool_rate_limits=tool_rate_limits,
92
127
  enable_retries=enable_retries,
93
128
  max_retries=max_retries,
129
+ retry_config=retry_config, # NEW: Pass OAuth-aware retry config
94
130
  )
95
131
 
96
132
  logger.debug(
@@ -13,11 +13,16 @@ It:
13
13
 
14
14
  from __future__ import annotations
15
15
 
16
+ from typing import TYPE_CHECKING, Any
17
+
16
18
  from chuk_tool_processor.core.processor import ToolProcessor
17
19
  from chuk_tool_processor.logging import get_logger
18
20
  from chuk_tool_processor.mcp.register_mcp_tools import register_mcp_tools
19
21
  from chuk_tool_processor.mcp.stream_manager import StreamManager
20
22
 
23
+ if TYPE_CHECKING:
24
+ from chuk_tool_processor.mcp.models import MCPServerConfig
25
+
21
26
  logger = get_logger("chuk_tool_processor.mcp.setup_stdio")
22
27
 
23
28
 
@@ -26,8 +31,8 @@ logger = get_logger("chuk_tool_processor.mcp.setup_stdio")
26
31
  # --------------------------------------------------------------------------- #
27
32
  async def setup_mcp_stdio( # noqa: C901 - long but just a config facade
28
33
  *,
29
- config_file: str,
30
- servers: list[str],
34
+ config_file: str | None = None, # NOW OPTIONAL - for backward compatibility
35
+ servers: list[str] | list[dict[str, Any]] | list[MCPServerConfig], # Can be server names, dicts, OR Pydantic models
31
36
  server_names: dict[int, str] | None = None,
32
37
  default_timeout: float = 10.0,
33
38
  initialization_timeout: float = 60.0,
@@ -45,17 +50,92 @@ async def setup_mcp_stdio( # noqa: C901 - long but just a config facade
45
50
  Initialise stdio-transport MCP + a :class:`ToolProcessor`.
46
51
 
47
52
  Call with ``await`` from your async context.
53
+
54
+ Args:
55
+ config_file: Optional config file path (legacy mode)
56
+ servers: Can be:
57
+ - List of server names (legacy, requires config_file)
58
+ - List of server config dicts (new DX)
59
+ - List of MCPServerConfig Pydantic models (best DX)
60
+ server_names: Optional server name mapping
61
+ default_timeout: Default timeout for operations
62
+ initialization_timeout: Timeout for initialization
63
+ max_concurrency: Maximum concurrent operations
64
+ enable_caching: Enable result caching
65
+ cache_ttl: Cache time-to-live
66
+ enable_rate_limiting: Enable rate limiting
67
+ global_rate_limit: Global rate limit
68
+ tool_rate_limits: Per-tool rate limits
69
+ enable_retries: Enable retries
70
+ max_retries: Maximum retry attempts
71
+ namespace: Tool namespace
72
+
73
+ Returns:
74
+ Tuple of (ToolProcessor, StreamManager)
75
+
76
+ Examples:
77
+ # Best DX (Pydantic models):
78
+ from chuk_tool_processor.mcp import MCPServerConfig, MCPTransport
79
+
80
+ processor, manager = await setup_mcp_stdio(
81
+ servers=[
82
+ MCPServerConfig(
83
+ name="echo",
84
+ transport=MCPTransport.STDIO,
85
+ command="uvx",
86
+ args=["chuk-mcp-echo", "stdio"],
87
+ ),
88
+ ],
89
+ namespace="tools",
90
+ )
91
+
92
+ # New DX (dicts, no config file):
93
+ processor, manager = await setup_mcp_stdio(
94
+ servers=[
95
+ {"name": "echo", "command": "uvx", "args": ["chuk-mcp-echo", "stdio"]},
96
+ ],
97
+ namespace="tools",
98
+ )
99
+
100
+ # Legacy (with config file):
101
+ processor, manager = await setup_mcp_stdio(
102
+ config_file="mcp_config.json",
103
+ servers=["echo"],
104
+ namespace="tools",
105
+ )
48
106
  """
49
- # 1️⃣ create & connect the stream-manager
50
- # FIXED: Pass the default_timeout parameter to StreamManager.create
51
- stream_manager = await StreamManager.create(
52
- config_file=config_file,
53
- servers=servers,
54
- server_names=server_names,
55
- transport_type="stdio",
56
- default_timeout=default_timeout, # 🔧 ADD THIS LINE
57
- initialization_timeout=initialization_timeout,
58
- )
107
+ # Import here to avoid circular dependency at module level
108
+ from chuk_tool_processor.mcp.models import MCPServerConfig as MCPServerConfigModel
109
+
110
+ # Check what format the servers are in
111
+ if servers and isinstance(servers[0], str):
112
+ # LEGACY: servers are names, config_file is required
113
+ if config_file is None:
114
+ raise ValueError("config_file is required when servers is a list of strings")
115
+
116
+ stream_manager = await StreamManager.create(
117
+ config_file=config_file,
118
+ servers=servers, # type: ignore[arg-type]
119
+ server_names=server_names,
120
+ transport_type="stdio",
121
+ default_timeout=default_timeout,
122
+ initialization_timeout=initialization_timeout,
123
+ )
124
+ else:
125
+ # NEW DX: servers are config dicts or Pydantic models
126
+ # Convert Pydantic models to dicts if needed
127
+ server_dicts: list[dict[str, Any]]
128
+ if servers and isinstance(servers[0], MCPServerConfigModel):
129
+ server_dicts = [s.to_dict() for s in servers] # type: ignore[union-attr]
130
+ else:
131
+ server_dicts = servers # type: ignore[assignment]
132
+
133
+ stream_manager = await StreamManager.create_with_stdio(
134
+ servers=server_dicts,
135
+ server_names=server_names,
136
+ default_timeout=default_timeout,
137
+ initialization_timeout=initialization_timeout,
138
+ )
59
139
 
60
140
  # 2️⃣ pull the remote tool list and register each one locally
61
141
  registered = await register_mcp_tools(stream_manager, namespace=namespace)
@@ -96,6 +96,24 @@ class StreamManager:
96
96
  )
97
97
  return inst
98
98
 
99
+ @classmethod
100
+ async def create_with_stdio(
101
+ cls,
102
+ servers: list[dict[str, Any]],
103
+ server_names: dict[int, str] | None = None,
104
+ default_timeout: float = 30.0,
105
+ initialization_timeout: float = 60.0,
106
+ ) -> StreamManager:
107
+ """Create StreamManager with STDIO transport and timeout protection (no config file needed)."""
108
+ inst = cls()
109
+ await inst.initialize_with_stdio(
110
+ servers,
111
+ server_names,
112
+ default_timeout=default_timeout,
113
+ initialization_timeout=initialization_timeout,
114
+ )
115
+ return inst
116
+
99
117
  @classmethod
100
118
  async def create_with_http_streamable(
101
119
  cls,
@@ -176,17 +194,24 @@ class StreamManager:
176
194
  for idx, server_name in enumerate(servers):
177
195
  try:
178
196
  if transport_type == "stdio":
179
- params = await load_config(config_file, server_name)
197
+ params, server_timeout = await load_config(config_file, server_name)
198
+ # Use per-server timeout if specified, otherwise use global default
199
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
200
+ logger.info(
201
+ f"Server '{server_name}' using timeout: {effective_timeout}s (per-server: {server_timeout}, default: {default_timeout})"
202
+ )
180
203
  # Use initialization_timeout for connection_timeout since subprocess
181
204
  # launch can take time (e.g., uvx downloading packages)
182
205
  transport: MCPBaseTransport = StdioTransport(
183
- params, connection_timeout=initialization_timeout, default_timeout=default_timeout
206
+ params, connection_timeout=initialization_timeout, default_timeout=effective_timeout
184
207
  )
185
208
  elif transport_type == "sse":
186
209
  logger.debug(
187
210
  "Using SSE transport in initialize() - consider using initialize_with_sse() instead"
188
211
  )
189
- params = await load_config(config_file, server_name)
212
+ params, server_timeout = await load_config(config_file, server_name)
213
+ # Use per-server timeout if specified, otherwise use global default
214
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
190
215
 
191
216
  if isinstance(params, dict) and "url" in params:
192
217
  sse_url = params["url"]
@@ -199,7 +224,7 @@ class StreamManager:
199
224
  logger.debug("No URL configured for SSE transport, using default: %s", sse_url)
200
225
 
201
226
  # Build SSE transport with optional headers
202
- transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": default_timeout}
227
+ transport_params = {"url": sse_url, "api_key": api_key, "default_timeout": effective_timeout}
203
228
  if headers:
204
229
  transport_params["headers"] = headers
205
230
 
@@ -209,7 +234,9 @@ class StreamManager:
209
234
  logger.debug(
210
235
  "Using HTTP Streamable transport in initialize() - consider using initialize_with_http_streamable() instead"
211
236
  )
212
- params = await load_config(config_file, server_name)
237
+ params, server_timeout = await load_config(config_file, server_name)
238
+ # Use per-server timeout if specified, otherwise use global default
239
+ effective_timeout = server_timeout if server_timeout is not None else default_timeout
213
240
 
214
241
  if isinstance(params, dict) and "url" in params:
215
242
  http_url = params["url"]
@@ -227,7 +254,7 @@ class StreamManager:
227
254
  transport_params = {
228
255
  "url": http_url,
229
256
  "api_key": api_key,
230
- "default_timeout": default_timeout,
257
+ "default_timeout": effective_timeout,
231
258
  "session_id": session_id,
232
259
  }
233
260
  # Note: headers not added until HTTPStreamableTransport supports them
@@ -364,6 +391,82 @@ class StreamManager:
364
391
  len(self.all_tools),
365
392
  )
366
393
 
394
+ async def initialize_with_stdio(
395
+ self,
396
+ servers: list[dict[str, Any]],
397
+ server_names: dict[int, str] | None = None,
398
+ default_timeout: float = 30.0,
399
+ initialization_timeout: float = 60.0,
400
+ ) -> None:
401
+ """Initialize with STDIO transport directly from server configs (no config file needed)."""
402
+ if self._closed:
403
+ raise RuntimeError("Cannot initialize a closed StreamManager")
404
+
405
+ async with self._lock:
406
+ self.server_names = server_names or {}
407
+
408
+ for idx, cfg in enumerate(servers):
409
+ name = cfg.get("name")
410
+ command = cfg.get("command")
411
+ args = cfg.get("args", [])
412
+ env = cfg.get("env")
413
+
414
+ if not (name and command):
415
+ logger.error("Bad STDIO server config (missing name or command): %s", cfg)
416
+ continue
417
+
418
+ try:
419
+ # Build STDIO transport parameters
420
+ transport_params = {
421
+ "command": command,
422
+ "args": args,
423
+ }
424
+ if env:
425
+ transport_params["env"] = env
426
+
427
+ logger.debug("STDIO %s: command=%s, args=%s", name, command, args)
428
+
429
+ transport = StdioTransport(
430
+ transport_params, connection_timeout=initialization_timeout, default_timeout=default_timeout
431
+ )
432
+
433
+ try:
434
+ if not await asyncio.wait_for(transport.initialize(), timeout=initialization_timeout):
435
+ logger.warning("Failed to init STDIO %s", name)
436
+ continue
437
+ except TimeoutError:
438
+ logger.error("Timeout initialising STDIO %s (timeout=%ss)", name, initialization_timeout)
439
+ continue
440
+
441
+ self.transports[name] = transport
442
+
443
+ # Ping and get tools with timeout protection
444
+ status = (
445
+ "Up"
446
+ if await asyncio.wait_for(transport.send_ping(), timeout=self.timeout_config.operation)
447
+ else "Down"
448
+ )
449
+ tools = await asyncio.wait_for(transport.get_tools(), timeout=self.timeout_config.operation)
450
+
451
+ for t in tools:
452
+ tname = t.get("name")
453
+ if tname:
454
+ self.tool_to_server_map[tname] = name
455
+ self.all_tools.extend(tools)
456
+
457
+ self.server_info.append({"id": idx, "name": name, "tools": len(tools), "status": status})
458
+ logger.debug("Initialised STDIO %s - %d tool(s)", name, len(tools))
459
+ except TimeoutError:
460
+ logger.error("Timeout initialising STDIO %s", name)
461
+ except Exception as exc:
462
+ logger.error("Error initialising STDIO %s: %s", name, exc)
463
+
464
+ logger.debug(
465
+ "StreamManager ready - %d STDIO server(s), %d tool(s)",
466
+ len(self.transports),
467
+ len(self.all_tools),
468
+ )
469
+
367
470
  async def initialize_with_http_streamable(
368
471
  self,
369
472
  servers: list[dict[str, str]],
@@ -239,13 +239,13 @@ class HTTPStreamableTransport(MCPBaseTransport):
239
239
  await self._cleanup()
240
240
  if self.enable_metrics and self._metrics:
241
241
  self._metrics.connection_errors += 1
242
- return False
242
+ raise # Re-raise for OAuth error detection in mcp-cli
243
243
  except Exception as e:
244
244
  logger.error("Error initializing HTTP Streamable transport: %s", e, exc_info=True)
245
245
  await self._cleanup()
246
246
  if self.enable_metrics and self._metrics:
247
247
  self._metrics.connection_errors += 1
248
- return False
248
+ raise # Re-raise for OAuth error detection in mcp-cli
249
249
 
250
250
  async def _attempt_recovery(self) -> bool:
251
251
  """Attempt to recover from connection issues (NEW - like SSE resilience)."""
@@ -519,16 +519,29 @@ class HTTPStreamableTransport(MCPBaseTransport):
519
519
  self._metrics.update_call_metrics(response_time, success)
520
520
 
521
521
  def _is_oauth_error(self, error_msg: str) -> bool:
522
- """Detect if error is OAuth-related (NEW)."""
522
+ """
523
+ Detect if error is OAuth-related per RFC 6750 and MCP OAuth spec.
524
+
525
+ Checks for:
526
+ - RFC 6750 Section 3.1 Bearer token errors (invalid_token, insufficient_scope)
527
+ - OAuth 2.1 token refresh errors (invalid_grant)
528
+ - MCP spec OAuth validation failures (401/403 responses)
529
+ """
523
530
  if not error_msg:
524
531
  return False
525
532
 
526
533
  error_lower = error_msg.lower()
527
534
  oauth_indicators = [
528
- "invalid_token",
529
- "expired token",
535
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
536
+ "invalid_token", # Token expired, revoked, malformed, or invalid
537
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
538
+ # OAuth 2.1 token refresh errors
539
+ "invalid_grant", # Refresh token errors
540
+ # MCP spec - OAuth validation failures (401 Unauthorized)
530
541
  "oauth validation",
531
542
  "unauthorized",
543
+ # Common OAuth error descriptions
544
+ "expired token",
532
545
  "token expired",
533
546
  "authentication failed",
534
547
  "invalid access token",