chuk-tool-processor 0.6.4__py3-none-any.whl → 0.9.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chuk-tool-processor might be problematic.

Files changed (66)
  1. chuk_tool_processor/core/__init__.py +32 -1
  2. chuk_tool_processor/core/exceptions.py +225 -13
  3. chuk_tool_processor/core/processor.py +135 -104
  4. chuk_tool_processor/execution/strategies/__init__.py +6 -0
  5. chuk_tool_processor/execution/strategies/inprocess_strategy.py +142 -150
  6. chuk_tool_processor/execution/strategies/subprocess_strategy.py +202 -206
  7. chuk_tool_processor/execution/tool_executor.py +82 -84
  8. chuk_tool_processor/execution/wrappers/__init__.py +42 -0
  9. chuk_tool_processor/execution/wrappers/caching.py +150 -116
  10. chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
  11. chuk_tool_processor/execution/wrappers/rate_limiting.py +76 -43
  12. chuk_tool_processor/execution/wrappers/retry.py +116 -78
  13. chuk_tool_processor/logging/__init__.py +23 -17
  14. chuk_tool_processor/logging/context.py +40 -45
  15. chuk_tool_processor/logging/formatter.py +22 -21
  16. chuk_tool_processor/logging/helpers.py +28 -42
  17. chuk_tool_processor/logging/metrics.py +13 -15
  18. chuk_tool_processor/mcp/__init__.py +8 -12
  19. chuk_tool_processor/mcp/mcp_tool.py +158 -114
  20. chuk_tool_processor/mcp/register_mcp_tools.py +22 -22
  21. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +57 -17
  22. chuk_tool_processor/mcp/setup_mcp_sse.py +57 -17
  23. chuk_tool_processor/mcp/setup_mcp_stdio.py +11 -11
  24. chuk_tool_processor/mcp/stream_manager.py +333 -276
  25. chuk_tool_processor/mcp/transport/__init__.py +22 -29
  26. chuk_tool_processor/mcp/transport/base_transport.py +180 -44
  27. chuk_tool_processor/mcp/transport/http_streamable_transport.py +505 -325
  28. chuk_tool_processor/mcp/transport/models.py +100 -0
  29. chuk_tool_processor/mcp/transport/sse_transport.py +607 -276
  30. chuk_tool_processor/mcp/transport/stdio_transport.py +597 -116
  31. chuk_tool_processor/models/__init__.py +21 -1
  32. chuk_tool_processor/models/execution_strategy.py +16 -21
  33. chuk_tool_processor/models/streaming_tool.py +28 -25
  34. chuk_tool_processor/models/tool_call.py +49 -31
  35. chuk_tool_processor/models/tool_export_mixin.py +22 -8
  36. chuk_tool_processor/models/tool_result.py +40 -77
  37. chuk_tool_processor/models/tool_spec.py +350 -0
  38. chuk_tool_processor/models/validated_tool.py +36 -18
  39. chuk_tool_processor/observability/__init__.py +30 -0
  40. chuk_tool_processor/observability/metrics.py +312 -0
  41. chuk_tool_processor/observability/setup.py +105 -0
  42. chuk_tool_processor/observability/tracing.py +345 -0
  43. chuk_tool_processor/plugins/__init__.py +1 -1
  44. chuk_tool_processor/plugins/discovery.py +11 -11
  45. chuk_tool_processor/plugins/parsers/__init__.py +1 -1
  46. chuk_tool_processor/plugins/parsers/base.py +1 -2
  47. chuk_tool_processor/plugins/parsers/function_call_tool.py +13 -8
  48. chuk_tool_processor/plugins/parsers/json_tool.py +4 -3
  49. chuk_tool_processor/plugins/parsers/openai_tool.py +12 -7
  50. chuk_tool_processor/plugins/parsers/xml_tool.py +4 -4
  51. chuk_tool_processor/registry/__init__.py +12 -12
  52. chuk_tool_processor/registry/auto_register.py +22 -30
  53. chuk_tool_processor/registry/decorators.py +127 -129
  54. chuk_tool_processor/registry/interface.py +26 -23
  55. chuk_tool_processor/registry/metadata.py +27 -22
  56. chuk_tool_processor/registry/provider.py +17 -18
  57. chuk_tool_processor/registry/providers/__init__.py +16 -19
  58. chuk_tool_processor/registry/providers/memory.py +18 -25
  59. chuk_tool_processor/registry/tool_export.py +42 -51
  60. chuk_tool_processor/utils/validation.py +15 -16
  61. chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
  62. chuk_tool_processor-0.9.7.dist-info/RECORD +67 -0
  63. chuk_tool_processor-0.6.4.dist-info/METADATA +0 -697
  64. chuk_tool_processor-0.6.4.dist-info/RECORD +0 -60
  65. {chuk_tool_processor-0.6.4.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
  66. {chuk_tool_processor-0.6.4.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
@@ -1,439 +1,770 @@
 # chuk_tool_processor/mcp/transport/sse_transport.py
 """
-Fixed SSE transport that matches your server's actual behavior.
-Based on your working debug script.
+SSE transport for MCP communication.
+
+FIXED: Improved health monitoring to avoid false unhealthy states.
+The SSE endpoint works perfectly, so we need more lenient health checks.
 """
+
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import json
-import uuid
-from typing import Dict, Any, List, Optional, Tuple
 import logging
+import time
+import uuid
+from typing import Any
 
 import httpx
 
 from .base_transport import MCPBaseTransport
+from .models import TimeoutConfig, TransportMetrics
 
 logger = logging.getLogger(__name__)
 
 
 class SSETransport(MCPBaseTransport):
     """
-    SSE transport that works with your server's two-step async pattern:
-    1. POST messages to /messages endpoint
-    2. Receive responses via SSE stream
+    SSE transport implementing the MCP protocol over Server-Sent Events.
+
+    FIXED: More lenient health monitoring to avoid false unhealthy states.
     """
 
-    def __init__(self, url: str, api_key: Optional[str] = None,
-                 connection_timeout: float = 30.0, default_timeout: float = 30.0):
-        """Initialize SSE transport."""
-        self.url = url.rstrip('/')
+    def __init__(
+        self,
+        url: str,
+        api_key: str | None = None,
+        headers: dict[str, str] | None = None,
+        connection_timeout: float = 30.0,
+        default_timeout: float = 60.0,
+        enable_metrics: bool = True,
+        oauth_refresh_callback: Any | None = None,
+        timeout_config: TimeoutConfig | None = None,
+    ):
+        """
+        Initialize SSE transport.
+        """
+        self.url = url.rstrip("/")
         self.api_key = api_key
-        self.connection_timeout = connection_timeout
-        self.default_timeout = default_timeout
-
-        # State
+        self.configured_headers = headers or {}
+        self.enable_metrics = enable_metrics
+        self.oauth_refresh_callback = oauth_refresh_callback
+
+        # Use timeout config or create from individual parameters
+        if timeout_config is None:
+            timeout_config = TimeoutConfig(connect=connection_timeout, operation=default_timeout)
+
+        self.timeout_config = timeout_config
+        self.connection_timeout = timeout_config.connect
+        self.default_timeout = timeout_config.operation
+
+        logger.debug("SSE Transport initialized with URL: %s", self.url)
+
+        # Connection state
         self.session_id = None
         self.message_url = None
-        self.pending_requests: Dict[str, asyncio.Future] = {}
+        self.pending_requests: dict[str, asyncio.Future] = {}
         self._initialized = False
-
+
         # HTTP clients
         self.stream_client = None
         self.send_client = None
-
-        # SSE stream
+
+        # SSE stream management
         self.sse_task = None
        self.sse_response = None
         self.sse_stream_context = None
 
-    def _get_headers(self) -> Dict[str, str]:
-        """Get headers with auth if available."""
-        headers = {}
-        if self.api_key:
-            headers['Authorization'] = f'Bearer {self.api_key}'
+        # FIXED: More lenient health monitoring
+        self._last_successful_ping = None
+        self._consecutive_failures = 0
+        self._max_consecutive_failures = 5  # INCREASED: was 3, now 5
+        self._connection_grace_period = 30.0  # NEW: Grace period after initialization
+        self._initialization_time = None  # NEW: Track when we initialized
+
+        # Performance metrics - use Pydantic model
+        self._metrics = TransportMetrics() if enable_metrics else None
+
+    def _construct_sse_url(self, base_url: str) -> str:
+        """Construct the SSE endpoint URL from the base URL."""
+        base_url = base_url.rstrip("/")
+
+        if base_url.endswith("/sse"):
+            logger.debug("URL already contains /sse endpoint: %s", base_url)
+            return base_url
+
+        sse_url = f"{base_url}/sse"
+        logger.debug("Constructed SSE URL: %s -> %s", base_url, sse_url)
+        return sse_url
+
+    def _get_headers(self) -> dict[str, str]:
+        """Get headers with authentication and custom headers."""
+        headers = {
+            "User-Agent": "chuk-tool-processor/1.0.0",
+            "Accept": "text/event-stream",
+            "Cache-Control": "no-cache",
+        }
+
+        # Add configured headers first
+        if self.configured_headers:
+            headers.update(self.configured_headers)
+
+        # Add API key as Bearer token if provided and no Authorization header exists
+        # This prevents clobbering OAuth tokens from configured_headers
+        if self.api_key and "Authorization" not in headers:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+
         return headers
 
+    async def _test_gateway_connectivity(self) -> bool:
+        """
+        Skip connectivity test - we know the SSE endpoint works.
+
+        FIXED: The diagnostic proves SSE endpoint works perfectly.
+        No need to test base URL that causes 401 errors.
+        """
+        logger.debug("Skipping gateway connectivity test - using direct SSE connection")
+        return True
+
     async def initialize(self) -> bool:
-        """Initialize SSE connection and MCP handshake."""
+        """Initialize SSE connection with improved health tracking."""
         if self._initialized:
             logger.warning("Transport already initialized")
             return True
-
+
+        start_time = time.time()
+
         try:
-            logger.info("Initializing SSE transport...")
-
+            logger.debug("Initializing SSE transport...")
+
+            # FIXED: Skip problematic connectivity test
+            if not await self._test_gateway_connectivity():
+                logger.error("Gateway connectivity test failed")
+                return False
+
             # Create HTTP clients
-            self.stream_client = httpx.AsyncClient(timeout=self.connection_timeout)
-            self.send_client = httpx.AsyncClient(timeout=self.default_timeout)
-
-            # Connect to SSE stream
-            sse_url = f"{self.url}/sse"
-            logger.debug(f"Connecting to SSE: {sse_url}")
-
-            self.sse_stream_context = self.stream_client.stream(
-                'GET', sse_url, headers=self._get_headers()
+            self.stream_client = httpx.AsyncClient(
+                timeout=httpx.Timeout(self.connection_timeout),
+                follow_redirects=True,
+                limits=httpx.Limits(max_connections=10, max_keepalive_connections=5),
             )
+            self.send_client = httpx.AsyncClient(
+                timeout=httpx.Timeout(self.default_timeout),
+                follow_redirects=True,
+                limits=httpx.Limits(max_connections=10, max_keepalive_connections=5),
+            )
+
+            # Connect to SSE stream
+            sse_url = self._construct_sse_url(self.url)
+            logger.debug("Connecting to SSE endpoint: %s", sse_url)
+
+            self.sse_stream_context = self.stream_client.stream("GET", sse_url, headers=self._get_headers())
             self.sse_response = await self.sse_stream_context.__aenter__()
-
+
             if self.sse_response.status_code != 200:
-                logger.error(f"SSE connection failed: {self.sse_response.status_code}")
+                logger.error("SSE connection failed with status: %s", self.sse_response.status_code)
+                await self._cleanup()
                 return False
-
-            logger.info("SSE streaming connection established")
-
+
+            logger.debug("SSE streaming connection established")
+
             # Start SSE processing task
-            self.sse_task = asyncio.create_task(self._process_sse_stream())
-
+            self.sse_task = asyncio.create_task(self._process_sse_stream(), name="sse_stream_processor")
+
             # Wait for session discovery
             logger.debug("Waiting for session discovery...")
-            for i in range(50):  # 5 seconds max
-                if self.message_url:
-                    break
+            session_timeout = self.timeout_config.connect
+            session_start = time.time()
+
+            while not self.message_url and (time.time() - session_start) < session_timeout:
                 await asyncio.sleep(0.1)
-
+
+                # Check if SSE task died
+                if self.sse_task.done():
+                    exception = self.sse_task.exception()
+                    if exception:
+                        logger.debug(f"SSE task died during session discovery: {exception}")
+                    await self._cleanup()
+                    return False
+
             if not self.message_url:
-                logger.error("Failed to get session info from SSE")
+                logger.warning("Failed to discover session endpoint within %.1fs", session_timeout)
+                await self._cleanup()
                 return False
-
-            logger.info(f"Session ready: {self.session_id}")
-
-            # Now do MCP initialization
+
+            if self.enable_metrics and self._metrics:
+                self._metrics.session_discoveries += 1
+
+            logger.debug("Session endpoint discovered: %s", self.message_url)
+
+            # Perform MCP initialization handshake
             try:
-                init_response = await self._send_request("initialize", {
-                    "protocolVersion": "2024-11-05",
-                    "capabilities": {},
-                    "clientInfo": {
-                        "name": "chuk-tool-processor",
-                        "version": "1.0.0"
-                    }
-                })
-
-                if 'error' in init_response:
-                    logger.error(f"Initialize failed: {init_response['error']}")
+                init_response = await self._send_request(
+                    "initialize",
+                    {
+                        "protocolVersion": "2024-11-05",
+                        "capabilities": {},
+                        "clientInfo": {"name": "chuk-tool-processor", "version": "1.0.0"},
+                    },
+                    timeout=self.default_timeout,
+                )
+
+                if "error" in init_response:
+                    logger.warning("MCP initialize failed: %s", init_response["error"])
+                    await self._cleanup()
                     return False
-
+
                 # Send initialized notification
                 await self._send_notification("notifications/initialized")
-
+
+                # FIXED: Set health tracking state
                 self._initialized = True
-                logger.info("SSE transport initialized successfully")
+                self._initialization_time = time.time()
+                self._last_successful_ping = time.time()
+                self._consecutive_failures = 0  # Reset failure count
+
+                if self.enable_metrics and self._metrics:
+                    init_time = time.time() - start_time
+                    self._metrics.initialization_time = init_time
+
+                logger.debug("SSE transport initialized successfully in %.3fs", time.time() - start_time)
                 return True
-
+
             except Exception as e:
-                logger.error(f"MCP initialization failed: {e}")
+                logger.error("MCP handshake failed: %s", e)
+                await self._cleanup()
                 return False
-
+
         except Exception as e:
-            logger.error(f"Error initializing SSE transport: {e}", exc_info=True)
+            logger.error("Error initializing SSE transport: %s", e, exc_info=True)
             await self._cleanup()
             return False
 
     async def _process_sse_stream(self):
-        """Process the persistent SSE stream."""
+        """Process the SSE stream for responses and session discovery."""
         try:
             logger.debug("Starting SSE stream processing...")
-
+
+            current_event = None
+
             async for line in self.sse_response.aiter_lines():
                 line = line.strip()
                 if not line:
                     continue
-
-                # Handle session endpoint discovery
-                if not self.message_url and line.startswith('data:') and '/messages/' in line:
-                    endpoint_path = line.split(':', 1)[1].strip()
-                    self.message_url = f"{self.url}{endpoint_path}"
-
-                    if 'session_id=' in endpoint_path:
-                        self.session_id = endpoint_path.split('session_id=')[1].split('&')[0]
-
-                    logger.debug(f"Got session info: {self.session_id}")
+
+                # Handle event type declarations
+                if line.startswith("event:"):
+                    current_event = line.split(":", 1)[1].strip()
+                    logger.debug("SSE event type: %s", current_event)
                     continue
-
+
+                # Handle session endpoint discovery
+                if not self.message_url and line.startswith("data:"):
+                    data_part = line.split(":", 1)[1].strip()
+
+                    # NEW FORMAT: event: endpoint + data: https://...
+                    if current_event == "endpoint" and data_part.startswith("http"):
+                        self.message_url = data_part
+
+                        # Extract session ID from URL if present
+                        if "session_id=" in data_part:
+                            self.session_id = data_part.split("session_id=")[1].split("&")[0]
+                        elif "sessionId=" in data_part:
+                            self.session_id = data_part.split("sessionId=")[1].split("&")[0]
+                        else:
+                            self.session_id = str(uuid.uuid4())
+
+                        logger.debug("Session endpoint discovered via event format: %s", self.message_url)
+                        continue
+
+                    # RELATIVE PATH FORMAT: event: endpoint + data: /sse/message?sessionId=...
+                    elif current_event == "endpoint" and data_part.startswith("/"):
+                        endpoint_path = data_part
+                        self.message_url = f"{self.url}{endpoint_path}"
+
+                        # Extract session ID if present
+                        if "session_id=" in endpoint_path:
+                            self.session_id = endpoint_path.split("session_id=")[1].split("&")[0]
+                        elif "sessionId=" in endpoint_path:
+                            self.session_id = endpoint_path.split("sessionId=")[1].split("&")[0]
+                        else:
+                            self.session_id = str(uuid.uuid4())
+
+                        logger.debug("Session endpoint discovered via relative path: %s", self.message_url)
+                        continue
+
+                    # OLD FORMAT: data: /messages/... (backwards compatibility)
+                    elif "/messages/" in data_part:
+                        endpoint_path = data_part
+                        self.message_url = f"{self.url}{endpoint_path}"
+
+                        # Extract session ID if present
+                        if "session_id=" in endpoint_path:
+                            self.session_id = endpoint_path.split("session_id=")[1].split("&")[0]
+                        else:
+                            self.session_id = str(uuid.uuid4())
+
+                        logger.debug("Session endpoint discovered via old format: %s", self.message_url)
+                        continue
+
                 # Handle JSON-RPC responses
-                if line.startswith('data:'):
-                    data_part = line.split(':', 1)[1].strip()
-
-                    # Skip pings and empty data
-                    if not data_part or data_part.startswith('ping'):
+                if line.startswith("data:"):
+                    data_part = line.split(":", 1)[1].strip()
+
+                    # Skip keepalive pings and empty data
+                    if not data_part or data_part.startswith("ping") or data_part in ("{}", "[]"):
                         continue
-
+
                     try:
                         response_data = json.loads(data_part)
-
-                        if 'jsonrpc' in response_data and 'id' in response_data:
-                            request_id = str(response_data['id'])
-
-                            # Resolve pending request
+
+                        # Handle JSON-RPC responses with request IDs
+                        if "jsonrpc" in response_data and "id" in response_data:
+                            request_id = str(response_data["id"])
+
+                            # Resolve pending request if found
                             if request_id in self.pending_requests:
                                 future = self.pending_requests.pop(request_id)
                                 if not future.done():
                                     future.set_result(response_data)
-                                    logger.debug(f"Resolved request: {request_id}")
-
-                    except json.JSONDecodeError:
-                        pass  # Not JSON, ignore
-
+                                    logger.debug("Resolved request ID: %s", request_id)
+
+                    except json.JSONDecodeError as e:
+                        logger.debug("Non-JSON data in SSE stream (ignoring): %s", e)
+
         except Exception as e:
-            logger.error(f"SSE stream error: {e}")
+            if self.enable_metrics and self._metrics:
+                self._metrics.stream_errors += 1
+            logger.error("SSE stream processing error: %s", e)
+            # FIXED: Don't increment consecutive failures for stream processing errors
+            # These are often temporary and don't indicate connection health
 
-    async def _send_request(self, method: str, params: Dict[str, Any] = None,
-                            timeout: Optional[float] = None) -> Dict[str, Any]:
-        """Send request and wait for async response."""
+    async def _send_request(
+        self, method: str, params: dict[str, Any] = None, timeout: float | None = None
+    ) -> dict[str, Any]:
+        """Send JSON-RPC request and wait for async response via SSE."""
         if not self.message_url:
-            raise RuntimeError("Not connected")
-
+            raise RuntimeError("SSE transport not connected - no message URL")
+
         request_id = str(uuid.uuid4())
-        message = {
-            "jsonrpc": "2.0",
-            "id": request_id,
-            "method": method,
-            "params": params or {}
-        }
-
-        # Create future for response
+        message = {"jsonrpc": "2.0", "id": request_id, "method": method, "params": params or {}}
+
+        # Create future for async response
         future = asyncio.Future()
         self.pending_requests[request_id] = future
-
+
         try:
-            # Send message
-            headers = {
-                'Content-Type': 'application/json',
-                **self._get_headers()
-            }
-
-            response = await self.send_client.post(
-                self.message_url,
-                headers=headers,
-                json=message
-            )
-
+            # Send HTTP POST request
+            headers = {"Content-Type": "application/json", **self._get_headers()}
+
+            response = await self.send_client.post(self.message_url, headers=headers, json=message)
+
             if response.status_code == 202:
-                # Wait for async response
-                timeout = timeout or self.default_timeout
-                result = await asyncio.wait_for(future, timeout=timeout)
+                # Async response - wait for result via SSE
+                request_timeout = timeout or self.default_timeout
+                result = await asyncio.wait_for(future, timeout=request_timeout)
+                # FIXED: Only reset failures on successful tool calls, not all requests
+                if method.startswith("tools/"):
+                    self._consecutive_failures = 0
+                    self._last_successful_ping = time.time()
                 return result
             elif response.status_code == 200:
                 # Immediate response
                 self.pending_requests.pop(request_id, None)
+                # FIXED: Only reset failures on successful tool calls
+                if method.startswith("tools/"):
+                    self._consecutive_failures = 0
+                    self._last_successful_ping = time.time()
                 return response.json()
             else:
                 self.pending_requests.pop(request_id, None)
-                raise RuntimeError(f"Request failed: {response.status_code}")
-
-        except asyncio.TimeoutError:
+                # FIXED: Only increment failures for tool calls, not initialization
+                if method.startswith("tools/"):
+                    self._consecutive_failures += 1
+                raise RuntimeError(f"HTTP request failed with status: {response.status_code}")
+
+        except TimeoutError:
             self.pending_requests.pop(request_id, None)
+            # FIXED: Only increment failures for tool calls
+            if method.startswith("tools/"):
+                self._consecutive_failures += 1
             raise
         except Exception:
             self.pending_requests.pop(request_id, None)
+            # FIXED: Only increment failures for tool calls
+            if method.startswith("tools/"):
+                self._consecutive_failures += 1
             raise
 
-    async def _send_notification(self, method: str, params: Dict[str, Any] = None):
-        """Send notification (no response expected)."""
+    async def _send_notification(self, method: str, params: dict[str, Any] = None):
+        """Send JSON-RPC notification (no response expected)."""
         if not self.message_url:
-            raise RuntimeError("Not connected")
-
-        message = {
-            "jsonrpc": "2.0",
-            "method": method,
-            "params": params or {}
-        }
-
-        headers = {
-            'Content-Type': 'application/json',
-            **self._get_headers()
-        }
-
-        await self.send_client.post(
-            self.message_url,
-            headers=headers,
-            json=message
-        )
+            raise RuntimeError("SSE transport not connected - no message URL")
+
+        message = {"jsonrpc": "2.0", "method": method, "params": params or {}}
+
+        headers = {"Content-Type": "application/json", **self._get_headers()}
+
+        response = await self.send_client.post(self.message_url, headers=headers, json=message)
+
+        if response.status_code not in (200, 202):
+            logger.warning("Notification failed with status: %s", response.status_code)
 
     async def send_ping(self) -> bool:
-        """Send ping to check connection."""
+        """Send ping to check connection health with improved logic."""
         if not self._initialized:
             return False
-
+
+        start_time = time.time()
         try:
-            # Your server might not support ping, so we'll just check if we can list tools
-            response = await self._send_request("tools/list", {}, timeout=5.0)
-            return 'error' not in response
-        except Exception:
+            # Use tools/list as a lightweight ping since not all servers support ping
+            response = await self._send_request("tools/list", {}, timeout=self.timeout_config.quick)
+
+            success = "error" not in response
+
+            if success:
+                self._last_successful_ping = time.time()
+                # FIXED: Don't reset consecutive failures here - let tool calls do that
+
+            if self.enable_metrics and self._metrics:
+                ping_time = time.time() - start_time
+                self._metrics.last_ping_time = ping_time
+                logger.debug("SSE ping completed in %.3fs: %s", ping_time, success)
+
+            return success
+        except Exception as e:
+            logger.debug("SSE ping failed: %s", e)
+            # FIXED: Don't increment consecutive failures for ping failures
+            return False
+
+    def is_connected(self) -> bool:
+        """
+        FIXED: More lenient connection health check.
+
+        The diagnostic shows the connection works fine, so we need to be less aggressive
+        about marking it as unhealthy.
+        """
+        if not self._initialized or not self.session_id:
             return False
 
-    async def get_tools(self) -> List[Dict[str, Any]]:
-        """Get tools list."""
+        # FIXED: Grace period after initialization - always return True for a while
+        if self._initialization_time and time.time() - self._initialization_time < self._connection_grace_period:
+            logger.debug("Within grace period - connection considered healthy")
+            return True
+
+        # FIXED: More lenient failure threshold
+        if self._consecutive_failures >= self._max_consecutive_failures:
+            logger.warning(f"Connection marked unhealthy after {self._consecutive_failures} consecutive failures")
+            return False
+
+        # Check if SSE task is still running
+        if self.sse_task and self.sse_task.done():
+            exception = self.sse_task.exception()
+            if exception:
+                logger.warning(f"SSE task died: {exception}")
+                return False
+
+        # FIXED: If we have a recent successful ping/tool call, we're healthy
+        if self._last_successful_ping and time.time() - self._last_successful_ping < 60.0:  # Success within last minute
+            return True
+
+        # FIXED: Default to healthy if no clear indicators of problems
+        logger.debug("No clear health indicators - defaulting to healthy")
+        return True
+
+    async def get_tools(self) -> list[dict[str, Any]]:
+        """Get list of available tools from the server."""
         if not self._initialized:
-            logger.error("Cannot get tools: transport not initialized")
+            logger.debug("Cannot get tools: transport not initialized")
             return []
-
+
+        start_time = time.time()
         try:
             response = await self._send_request("tools/list", {})
-
-            if 'error' in response:
-                logger.error(f"Error getting tools: {response['error']}")
+
+            if "error" in response:
+                logger.warning("Error getting tools: %s", response["error"])
                 return []
-
-            tools = response.get('result', {}).get('tools', [])
-            logger.debug(f"Retrieved {len(tools)} tools")
+
+            tools = response.get("result", {}).get("tools", [])
+
+            if self.enable_metrics:
+                response_time = time.time() - start_time
+                logger.debug("Retrieved %d tools in %.3fs", len(tools), response_time)
+
             return tools
-
+
         except Exception as e:
-            logger.error(f"Error getting tools: {e}")
+            logger.error("Error getting tools: %s", e)
             return []
 
-    async def call_tool(self, tool_name: str, arguments: Dict[str, Any],
-                        timeout: Optional[float] = None) -> Dict[str, Any]:
-        """Call a tool."""
+    async def call_tool(
+        self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
+    ) -> dict[str, Any]:
+        """Execute a tool with the given arguments."""
         if not self._initialized:
-            return {
-                "isError": True,
-                "error": "Transport not initialized"
-            }
+            return {"isError": True, "error": "Transport not initialized"}
+
+        start_time = time.time()
+        if self.enable_metrics and self._metrics:
+            self._metrics.total_calls += 1
 
         try:
-            logger.debug(f"Calling tool {tool_name} with args: {arguments}")
-
+            logger.debug("Calling tool '%s' with arguments: %s", tool_name, arguments)
+
             response = await self._send_request(
-                "tools/call",
-                {
-                    "name": tool_name,
-                    "arguments": arguments
-                },
-                timeout=timeout
+                "tools/call", {"name": tool_name, "arguments": arguments}, timeout=timeout
             )
-
-            if 'error' in response:
-                return {
-                    "isError": True,
-                    "error": response['error'].get('message', 'Unknown error')
-                }
-
-            # Extract result
-            result = response.get('result', {})
-
-            # Handle content format
-            if 'content' in result:
-                content = result['content']
-                if isinstance(content, list) and len(content) == 1:
-                    content_item = content[0]
-                    if isinstance(content_item, dict) and content_item.get('type') == 'text':
-                        text_content = content_item.get('text', '')
+
+            # Check for errors
+            if "error" in response:
+                error_msg = response["error"].get("message", "Unknown error")
+
+                # NEW: Check for OAuth errors and attempt refresh if callback is available
+                if self._is_oauth_error(error_msg):
+                    logger.warning("OAuth error detected: %s", error_msg)
+
+                    if self.oauth_refresh_callback:
+                        logger.debug("Attempting OAuth token refresh...")
                         try:
-                            # Try to parse as JSON
-                            parsed_content = json.loads(text_content)
-                            return {
-                                "isError": False,
-                                "content": parsed_content
-                            }
-                        except json.JSONDecodeError:
-                            return {
-                                "isError": False,
-                                "content": text_content
-                            }
-
-                return {
-                    "isError": False,
-                    "content": content
-                }
-
-            return {
-                "isError": False,
-                "content": result
-            }
-
-        except asyncio.TimeoutError:
-            return {
-                "isError": True,
-                "error": f"Tool execution timed out"
-            }
+                            # Call the refresh callback
+                            new_headers = await self.oauth_refresh_callback()
+
+                            if new_headers and "Authorization" in new_headers:
+                                # Update configured headers with new token
+                                self.configured_headers.update(new_headers)
+                                logger.debug("OAuth token refreshed, retrying tool call...")
+
+                                # Retry the tool call once with new token
+                                response = await self._send_request(
+                                    "tools/call", {"name": tool_name, "arguments": arguments}, timeout=timeout
+                                )
+
+                                # Check if retry succeeded
+                                if "error" not in response:
+                                    logger.debug("Tool call succeeded after token refresh")
+                                    result = response.get("result", {})
+                                    normalized_result = self._normalize_mcp_response({"result": result})
+
+                                    if self.enable_metrics:
+                                        self._update_metrics(time.time() - start_time, True)
+
+                                    return normalized_result
+                                else:
+                                    error_msg = response["error"].get("message", "Unknown error")
+                                    logger.error("Tool call failed after token refresh: %s", error_msg)
+                            else:
+                                logger.warning("Token refresh did not return valid Authorization header")
+                        except Exception as refresh_error:
+                            logger.error("OAuth token refresh failed: %s", refresh_error)
+                    else:
+                        logger.warning("OAuth error detected but no refresh callback configured")
+
+                # Return error (original or from failed retry)
+                if self.enable_metrics:
+                    self._update_metrics(time.time() - start_time, False)
+
+                return {"isError": True, "error": error_msg}
+
+            # Extract and normalize result using base class method
+            result = response.get("result", {})
+            normalized_result = self._normalize_mcp_response({"result": result})
+
+            if self.enable_metrics:
+                self._update_metrics(time.time() - start_time, True)
+
+            return normalized_result
+
+        except TimeoutError:
+            if self.enable_metrics:
+                self._update_metrics(time.time() - start_time, False)
+
+            return {"isError": True, "error": "Tool execution timed out"}
         except Exception as e:
-            logger.error(f"Error calling tool {tool_name}: {e}")
-            return {
-                "isError": True,
-                "error": str(e)
-            }
+            if self.enable_metrics:
+                self._update_metrics(time.time() - start_time, False)
+
+            logger.error("Error calling tool '%s': %s", tool_name, e)
+            return {"isError": True, "error": str(e)}
+
+    def _update_metrics(self, response_time: float, success: bool) -> None:
+        """Update performance metrics."""
+        if not self._metrics:
+            return
+
+        self._metrics.update_call_metrics(response_time, success)
 
-    async def list_resources(self) -> Dict[str, Any]:
-        """List resources."""
+    def _is_oauth_error(self, error_msg: str) -> bool:
+        """
+        Detect if error is OAuth-related per RFC 6750 and MCP OAuth spec.
+
+        Checks for:
+        - RFC 6750 Section 3.1 Bearer token errors (invalid_token, insufficient_scope)
+        - OAuth 2.1 token refresh errors (invalid_grant)
+        - MCP spec OAuth validation failures (401/403 responses)
+        """
+        if not error_msg:
+            return False
+
+        error_lower = error_msg.lower()
+        oauth_indicators = [
+            # RFC 6750 Section 3.1 - Standard Bearer token errors
+            "invalid_token",  # Token expired, revoked, malformed, or invalid
+            "insufficient_scope",  # Request requires higher privileges (403 Forbidden)
+            # OAuth 2.1 token refresh errors
+            "invalid_grant",  # Refresh token errors
+            # MCP spec - OAuth validation failures (401 Unauthorized)
+            "oauth validation",
+            "unauthorized",
+            # Common OAuth error descriptions
+            "expired token",
+            "token expired",
+            "authentication failed",
+            "invalid access token",
+        ]
+
+        return any(indicator in error_lower for indicator in oauth_indicators)
+
+    async def list_resources(self) -> dict[str, Any]:
+        """List available resources from the server."""
         if not self._initialized:
             return {}
-
+
         try:
-            response = await self._send_request("resources/list", {}, timeout=10.0)
-            if 'error' in response:
-                logger.debug(f"Resources not supported: {response['error']}")
+            response = await self._send_request("resources/list", {}, timeout=self.timeout_config.operation)
+            if "error" in response:
+                logger.debug("Resources not supported: %s", response["error"])
                 return {}
-            return response.get('result', {})
-        except Exception:
+            return response.get("result", {})
+        except Exception as e:
+            logger.debug("Error listing resources: %s", e)
             return {}
 
-    async def list_prompts(self) -> Dict[str, Any]:
-        """List prompts."""
+    async def list_prompts(self) -> dict[str, Any]:
+        """List available prompts from the server."""
         if not self._initialized:
             return {}
-
+
         try:
-            response = await self._send_request("prompts/list", {}, timeout=10.0)
-            if 'error' in response:
-                logger.debug(f"Prompts not supported: {response['error']}")
+            response = await self._send_request("prompts/list", {}, timeout=self.timeout_config.operation)
+            if "error" in response:
+                logger.debug("Prompts not supported: %s", response["error"])
                 return {}
-            return response.get('result', {})
-        except Exception:
+            return response.get("result", {})
+        except Exception as e:
+            logger.debug("Error listing prompts: %s", e)
             return {}
 
     async def close(self) -> None:
-        """Close the transport."""
+        """Close the transport and clean up resources."""
+        if not self._initialized:
+            return
+
+        # Log final metrics
+        if self.enable_metrics and self._metrics and self._metrics.total_calls > 0:
+            logger.debug(
+                "SSE transport closing - Total calls: %d, Success rate: %.1f%%, Avg response time: %.3fs",
+                self._metrics.total_calls,
+                (self._metrics.successful_calls / self._metrics.total_calls * 100),
+                self._metrics.avg_response_time,
+            )
+
         await self._cleanup()
 
     async def _cleanup(self) -> None:
-        """Clean up resources."""
-        if self.sse_task:
+        """Clean up all resources and reset state."""
+        # Cancel SSE processing task
+        if self.sse_task and not self.sse_task.done():
             self.sse_task.cancel()
-            try:
+            with contextlib.suppress(asyncio.CancelledError):
                 await self.sse_task
-            except asyncio.CancelledError:
-                pass
-
+
+        # Close SSE stream context
         if self.sse_stream_context:
             try:
                 await self.sse_stream_context.__aexit__(None, None, None)
-            except Exception:
-                pass
-
+            except Exception as e:
+                logger.debug("Error closing SSE stream: %s", e)
+
+        # Close HTTP clients
         if self.stream_client:
             await self.stream_client.aclose()
-
+
         if self.send_client:
             await self.send_client.aclose()
-
+
+        # Cancel any pending requests
+        for _request_id, future in self.pending_requests.items():
+            if not future.done():
+                future.cancel()
+
+        # Reset state
         self._initialized = False
         self.session_id = None
         self.message_url = None
         self.pending_requests.clear()
+        self.sse_task = None
+        self.sse_response = None
+        self.sse_stream_context = None
+        self.stream_client = None
+        self.send_client = None
+        # FIXED: Reset health tracking
+        self._consecutive_failures = 0
+        self._last_successful_ping = None
+        self._initialization_time = None
 
-    def get_streams(self) -> List[tuple]:
-        """Not applicable for this transport."""
-        return []
+    def get_metrics(self) -> dict[str, Any]:
+        """Get performance and connection metrics with health info."""
+        if not self._metrics:
+            return {}
 
-    def is_connected(self) -> bool:
-        """Check if connected."""
-        return self._initialized and self.session_id is not None
+        metrics = self._metrics.to_dict()
+        metrics.update(
+            {
+                "is_connected": self.is_connected(),
+                "consecutive_failures": self._consecutive_failures,
+                "max_consecutive_failures": self._max_consecutive_failures,
+                "last_successful_ping": self._last_successful_ping,
+                "initialization_time_timestamp": self._initialization_time,
+                "grace_period_active": (
+                    self._initialization_time
+                    and time.time() - self._initialization_time < self._connection_grace_period
+                )
+                if self._initialization_time
+                else False,
+            }
+        )
+        return metrics
+
+    def reset_metrics(self) -> None:
+        """Reset performance metrics."""
+        if not self._metrics:
+            return
+
+        # Preserve important historical values
+        preserved_last_ping = self._metrics.last_ping_time
+        preserved_init_time = self._metrics.initialization_time
+        preserved_discoveries = self._metrics.session_discoveries
+
+        # Create new metrics instance with preserved values
+        self._metrics = TransportMetrics(
+            last_ping_time=preserved_last_ping,
+            initialization_time=preserved_init_time,
+            session_discoveries=preserved_discoveries,
+        )
+
+    def get_streams(self) -> list[tuple]:
+        """SSE transport doesn't expose raw streams."""
+        return []
 
     async def __aenter__(self):
-        """Context manager support."""
+        """Context manager entry."""
         success = await self.initialize()
         if not success:
-            raise RuntimeError("Failed to initialize SSE transport")
+            raise RuntimeError("Failed to initialize SSETransport")
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         """Context manager cleanup."""
         await self.close()
-
-    def __repr__(self) -> str:
-        """String representation."""
-        status = "initialized" if self._initialized else "not initialized"
-        return f"SSETransport(status={status}, url={self.url}, session={self.session_id})"