chuk-tool-processor 0.6.13__py3-none-any.whl → 0.9.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor might be problematic. Click here for more details.
- chuk_tool_processor/core/__init__.py +31 -0
- chuk_tool_processor/core/exceptions.py +218 -12
- chuk_tool_processor/core/processor.py +38 -7
- chuk_tool_processor/execution/strategies/__init__.py +6 -0
- chuk_tool_processor/execution/strategies/subprocess_strategy.py +2 -1
- chuk_tool_processor/execution/wrappers/__init__.py +42 -0
- chuk_tool_processor/execution/wrappers/caching.py +48 -13
- chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
- chuk_tool_processor/execution/wrappers/rate_limiting.py +31 -1
- chuk_tool_processor/execution/wrappers/retry.py +93 -53
- chuk_tool_processor/logging/metrics.py +2 -2
- chuk_tool_processor/mcp/mcp_tool.py +5 -5
- chuk_tool_processor/mcp/setup_mcp_http_streamable.py +44 -2
- chuk_tool_processor/mcp/setup_mcp_sse.py +44 -2
- chuk_tool_processor/mcp/setup_mcp_stdio.py +2 -0
- chuk_tool_processor/mcp/stream_manager.py +130 -75
- chuk_tool_processor/mcp/transport/__init__.py +10 -0
- chuk_tool_processor/mcp/transport/http_streamable_transport.py +193 -108
- chuk_tool_processor/mcp/transport/models.py +100 -0
- chuk_tool_processor/mcp/transport/sse_transport.py +155 -59
- chuk_tool_processor/mcp/transport/stdio_transport.py +58 -10
- chuk_tool_processor/models/__init__.py +20 -0
- chuk_tool_processor/models/tool_call.py +34 -1
- chuk_tool_processor/models/tool_spec.py +350 -0
- chuk_tool_processor/models/validated_tool.py +22 -2
- chuk_tool_processor/observability/__init__.py +30 -0
- chuk_tool_processor/observability/metrics.py +312 -0
- chuk_tool_processor/observability/setup.py +105 -0
- chuk_tool_processor/observability/tracing.py +345 -0
- chuk_tool_processor/plugins/discovery.py +1 -1
- chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
- {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/RECORD +34 -27
- chuk_tool_processor-0.6.13.dist-info/METADATA +0 -698
- {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
- {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
|
@@ -19,6 +19,7 @@ from typing import Any
|
|
|
19
19
|
import httpx
|
|
20
20
|
|
|
21
21
|
from .base_transport import MCPBaseTransport
|
|
22
|
+
from .models import TimeoutConfig, TransportMetrics
|
|
22
23
|
|
|
23
24
|
logger = logging.getLogger(__name__)
|
|
24
25
|
|
|
@@ -38,6 +39,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
38
39
|
connection_timeout: float = 30.0,
|
|
39
40
|
default_timeout: float = 60.0,
|
|
40
41
|
enable_metrics: bool = True,
|
|
42
|
+
oauth_refresh_callback: Any | None = None,
|
|
43
|
+
timeout_config: TimeoutConfig | None = None,
|
|
41
44
|
):
|
|
42
45
|
"""
|
|
43
46
|
Initialize SSE transport.
|
|
@@ -45,9 +48,16 @@ class SSETransport(MCPBaseTransport):
|
|
|
45
48
|
self.url = url.rstrip("/")
|
|
46
49
|
self.api_key = api_key
|
|
47
50
|
self.configured_headers = headers or {}
|
|
48
|
-
self.connection_timeout = connection_timeout
|
|
49
|
-
self.default_timeout = default_timeout
|
|
50
51
|
self.enable_metrics = enable_metrics
|
|
52
|
+
self.oauth_refresh_callback = oauth_refresh_callback
|
|
53
|
+
|
|
54
|
+
# Use timeout config or create from individual parameters
|
|
55
|
+
if timeout_config is None:
|
|
56
|
+
timeout_config = TimeoutConfig(connect=connection_timeout, operation=default_timeout)
|
|
57
|
+
|
|
58
|
+
self.timeout_config = timeout_config
|
|
59
|
+
self.connection_timeout = timeout_config.connect
|
|
60
|
+
self.default_timeout = timeout_config.operation
|
|
51
61
|
|
|
52
62
|
logger.debug("SSE Transport initialized with URL: %s", self.url)
|
|
53
63
|
|
|
@@ -73,18 +83,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
73
83
|
self._connection_grace_period = 30.0 # NEW: Grace period after initialization
|
|
74
84
|
self._initialization_time = None # NEW: Track when we initialized
|
|
75
85
|
|
|
76
|
-
# Performance metrics
|
|
77
|
-
self._metrics =
|
|
78
|
-
"total_calls": 0,
|
|
79
|
-
"successful_calls": 0,
|
|
80
|
-
"failed_calls": 0,
|
|
81
|
-
"total_time": 0.0,
|
|
82
|
-
"avg_response_time": 0.0,
|
|
83
|
-
"last_ping_time": None,
|
|
84
|
-
"initialization_time": None,
|
|
85
|
-
"session_discoveries": 0,
|
|
86
|
-
"stream_errors": 0,
|
|
87
|
-
}
|
|
86
|
+
# Performance metrics - use Pydantic model
|
|
87
|
+
self._metrics = TransportMetrics() if enable_metrics else None
|
|
88
88
|
|
|
89
89
|
def _construct_sse_url(self, base_url: str) -> str:
|
|
90
90
|
"""Construct the SSE endpoint URL from the base URL."""
|
|
@@ -110,8 +110,9 @@ class SSETransport(MCPBaseTransport):
|
|
|
110
110
|
if self.configured_headers:
|
|
111
111
|
headers.update(self.configured_headers)
|
|
112
112
|
|
|
113
|
-
# Add API key as Bearer token if provided
|
|
114
|
-
|
|
113
|
+
# Add API key as Bearer token if provided and no Authorization header exists
|
|
114
|
+
# This prevents clobbering OAuth tokens from configured_headers
|
|
115
|
+
if self.api_key and "Authorization" not in headers:
|
|
115
116
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
116
117
|
|
|
117
118
|
return headers
|
|
@@ -173,7 +174,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
173
174
|
|
|
174
175
|
# Wait for session discovery
|
|
175
176
|
logger.debug("Waiting for session discovery...")
|
|
176
|
-
session_timeout =
|
|
177
|
+
session_timeout = self.timeout_config.connect
|
|
177
178
|
session_start = time.time()
|
|
178
179
|
|
|
179
180
|
while not self.message_url and (time.time() - session_start) < session_timeout:
|
|
@@ -183,17 +184,17 @@ class SSETransport(MCPBaseTransport):
|
|
|
183
184
|
if self.sse_task.done():
|
|
184
185
|
exception = self.sse_task.exception()
|
|
185
186
|
if exception:
|
|
186
|
-
logger.
|
|
187
|
+
logger.debug(f"SSE task died during session discovery: {exception}")
|
|
187
188
|
await self._cleanup()
|
|
188
189
|
return False
|
|
189
190
|
|
|
190
191
|
if not self.message_url:
|
|
191
|
-
logger.
|
|
192
|
+
logger.warning("Failed to discover session endpoint within %.1fs", session_timeout)
|
|
192
193
|
await self._cleanup()
|
|
193
194
|
return False
|
|
194
195
|
|
|
195
|
-
if self.enable_metrics:
|
|
196
|
-
self._metrics
|
|
196
|
+
if self.enable_metrics and self._metrics:
|
|
197
|
+
self._metrics.session_discoveries += 1
|
|
197
198
|
|
|
198
199
|
logger.debug("Session endpoint discovered: %s", self.message_url)
|
|
199
200
|
|
|
@@ -210,7 +211,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
210
211
|
)
|
|
211
212
|
|
|
212
213
|
if "error" in init_response:
|
|
213
|
-
logger.
|
|
214
|
+
logger.warning("MCP initialize failed: %s", init_response["error"])
|
|
214
215
|
await self._cleanup()
|
|
215
216
|
return False
|
|
216
217
|
|
|
@@ -223,9 +224,9 @@ class SSETransport(MCPBaseTransport):
|
|
|
223
224
|
self._last_successful_ping = time.time()
|
|
224
225
|
self._consecutive_failures = 0 # Reset failure count
|
|
225
226
|
|
|
226
|
-
if self.enable_metrics:
|
|
227
|
+
if self.enable_metrics and self._metrics:
|
|
227
228
|
init_time = time.time() - start_time
|
|
228
|
-
self._metrics
|
|
229
|
+
self._metrics.initialization_time = init_time
|
|
229
230
|
|
|
230
231
|
logger.debug("SSE transport initialized successfully in %.3fs", time.time() - start_time)
|
|
231
232
|
return True
|
|
@@ -269,12 +270,30 @@ class SSETransport(MCPBaseTransport):
|
|
|
269
270
|
# Extract session ID from URL if present
|
|
270
271
|
if "session_id=" in data_part:
|
|
271
272
|
self.session_id = data_part.split("session_id=")[1].split("&")[0]
|
|
273
|
+
elif "sessionId=" in data_part:
|
|
274
|
+
self.session_id = data_part.split("sessionId=")[1].split("&")[0]
|
|
272
275
|
else:
|
|
273
276
|
self.session_id = str(uuid.uuid4())
|
|
274
277
|
|
|
275
278
|
logger.debug("Session endpoint discovered via event format: %s", self.message_url)
|
|
276
279
|
continue
|
|
277
280
|
|
|
281
|
+
# RELATIVE PATH FORMAT: event: endpoint + data: /sse/message?sessionId=...
|
|
282
|
+
elif current_event == "endpoint" and data_part.startswith("/"):
|
|
283
|
+
endpoint_path = data_part
|
|
284
|
+
self.message_url = f"{self.url}{endpoint_path}"
|
|
285
|
+
|
|
286
|
+
# Extract session ID if present
|
|
287
|
+
if "session_id=" in endpoint_path:
|
|
288
|
+
self.session_id = endpoint_path.split("session_id=")[1].split("&")[0]
|
|
289
|
+
elif "sessionId=" in endpoint_path:
|
|
290
|
+
self.session_id = endpoint_path.split("sessionId=")[1].split("&")[0]
|
|
291
|
+
else:
|
|
292
|
+
self.session_id = str(uuid.uuid4())
|
|
293
|
+
|
|
294
|
+
logger.debug("Session endpoint discovered via relative path: %s", self.message_url)
|
|
295
|
+
continue
|
|
296
|
+
|
|
278
297
|
# OLD FORMAT: data: /messages/... (backwards compatibility)
|
|
279
298
|
elif "/messages/" in data_part:
|
|
280
299
|
endpoint_path = data_part
|
|
@@ -315,8 +334,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
315
334
|
logger.debug("Non-JSON data in SSE stream (ignoring): %s", e)
|
|
316
335
|
|
|
317
336
|
except Exception as e:
|
|
318
|
-
if self.enable_metrics:
|
|
319
|
-
self._metrics
|
|
337
|
+
if self.enable_metrics and self._metrics:
|
|
338
|
+
self._metrics.stream_errors += 1
|
|
320
339
|
logger.error("SSE stream processing error: %s", e)
|
|
321
340
|
# FIXED: Don't increment consecutive failures for stream processing errors
|
|
322
341
|
# These are often temporary and don't indicate connection health
|
|
@@ -400,7 +419,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
400
419
|
start_time = time.time()
|
|
401
420
|
try:
|
|
402
421
|
# Use tools/list as a lightweight ping since not all servers support ping
|
|
403
|
-
response = await self._send_request("tools/list", {}, timeout=
|
|
422
|
+
response = await self._send_request("tools/list", {}, timeout=self.timeout_config.quick)
|
|
404
423
|
|
|
405
424
|
success = "error" not in response
|
|
406
425
|
|
|
@@ -408,9 +427,9 @@ class SSETransport(MCPBaseTransport):
|
|
|
408
427
|
self._last_successful_ping = time.time()
|
|
409
428
|
# FIXED: Don't reset consecutive failures here - let tool calls do that
|
|
410
429
|
|
|
411
|
-
if self.enable_metrics:
|
|
430
|
+
if self.enable_metrics and self._metrics:
|
|
412
431
|
ping_time = time.time() - start_time
|
|
413
|
-
self._metrics
|
|
432
|
+
self._metrics.last_ping_time = ping_time
|
|
414
433
|
logger.debug("SSE ping completed in %.3fs: %s", ping_time, success)
|
|
415
434
|
|
|
416
435
|
return success
|
|
@@ -457,7 +476,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
457
476
|
async def get_tools(self) -> list[dict[str, Any]]:
|
|
458
477
|
"""Get list of available tools from the server."""
|
|
459
478
|
if not self._initialized:
|
|
460
|
-
logger.
|
|
479
|
+
logger.debug("Cannot get tools: transport not initialized")
|
|
461
480
|
return []
|
|
462
481
|
|
|
463
482
|
start_time = time.time()
|
|
@@ -465,7 +484,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
465
484
|
response = await self._send_request("tools/list", {})
|
|
466
485
|
|
|
467
486
|
if "error" in response:
|
|
468
|
-
logger.
|
|
487
|
+
logger.warning("Error getting tools: %s", response["error"])
|
|
469
488
|
return []
|
|
470
489
|
|
|
471
490
|
tools = response.get("result", {}).get("tools", [])
|
|
@@ -488,8 +507,8 @@ class SSETransport(MCPBaseTransport):
|
|
|
488
507
|
return {"isError": True, "error": "Transport not initialized"}
|
|
489
508
|
|
|
490
509
|
start_time = time.time()
|
|
491
|
-
if self.enable_metrics:
|
|
492
|
-
self._metrics
|
|
510
|
+
if self.enable_metrics and self._metrics:
|
|
511
|
+
self._metrics.total_calls += 1
|
|
493
512
|
|
|
494
513
|
try:
|
|
495
514
|
logger.debug("Calling tool '%s' with arguments: %s", tool_name, arguments)
|
|
@@ -498,11 +517,55 @@ class SSETransport(MCPBaseTransport):
|
|
|
498
517
|
"tools/call", {"name": tool_name, "arguments": arguments}, timeout=timeout
|
|
499
518
|
)
|
|
500
519
|
|
|
520
|
+
# Check for errors
|
|
501
521
|
if "error" in response:
|
|
522
|
+
error_msg = response["error"].get("message", "Unknown error")
|
|
523
|
+
|
|
524
|
+
# NEW: Check for OAuth errors and attempt refresh if callback is available
|
|
525
|
+
if self._is_oauth_error(error_msg):
|
|
526
|
+
logger.warning("OAuth error detected: %s", error_msg)
|
|
527
|
+
|
|
528
|
+
if self.oauth_refresh_callback:
|
|
529
|
+
logger.debug("Attempting OAuth token refresh...")
|
|
530
|
+
try:
|
|
531
|
+
# Call the refresh callback
|
|
532
|
+
new_headers = await self.oauth_refresh_callback()
|
|
533
|
+
|
|
534
|
+
if new_headers and "Authorization" in new_headers:
|
|
535
|
+
# Update configured headers with new token
|
|
536
|
+
self.configured_headers.update(new_headers)
|
|
537
|
+
logger.debug("OAuth token refreshed, retrying tool call...")
|
|
538
|
+
|
|
539
|
+
# Retry the tool call once with new token
|
|
540
|
+
response = await self._send_request(
|
|
541
|
+
"tools/call", {"name": tool_name, "arguments": arguments}, timeout=timeout
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
# Check if retry succeeded
|
|
545
|
+
if "error" not in response:
|
|
546
|
+
logger.debug("Tool call succeeded after token refresh")
|
|
547
|
+
result = response.get("result", {})
|
|
548
|
+
normalized_result = self._normalize_mcp_response({"result": result})
|
|
549
|
+
|
|
550
|
+
if self.enable_metrics:
|
|
551
|
+
self._update_metrics(time.time() - start_time, True)
|
|
552
|
+
|
|
553
|
+
return normalized_result
|
|
554
|
+
else:
|
|
555
|
+
error_msg = response["error"].get("message", "Unknown error")
|
|
556
|
+
logger.error("Tool call failed after token refresh: %s", error_msg)
|
|
557
|
+
else:
|
|
558
|
+
logger.warning("Token refresh did not return valid Authorization header")
|
|
559
|
+
except Exception as refresh_error:
|
|
560
|
+
logger.error("OAuth token refresh failed: %s", refresh_error)
|
|
561
|
+
else:
|
|
562
|
+
logger.warning("OAuth error detected but no refresh callback configured")
|
|
563
|
+
|
|
564
|
+
# Return error (original or from failed retry)
|
|
502
565
|
if self.enable_metrics:
|
|
503
566
|
self._update_metrics(time.time() - start_time, False)
|
|
504
567
|
|
|
505
|
-
return {"isError": True, "error":
|
|
568
|
+
return {"isError": True, "error": error_msg}
|
|
506
569
|
|
|
507
570
|
# Extract and normalize result using base class method
|
|
508
571
|
result = response.get("result", {})
|
|
@@ -527,14 +590,41 @@ class SSETransport(MCPBaseTransport):
|
|
|
527
590
|
|
|
528
591
|
def _update_metrics(self, response_time: float, success: bool) -> None:
|
|
529
592
|
"""Update performance metrics."""
|
|
530
|
-
if
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
593
|
+
if not self._metrics:
|
|
594
|
+
return
|
|
595
|
+
|
|
596
|
+
self._metrics.update_call_metrics(response_time, success)
|
|
597
|
+
|
|
598
|
+
def _is_oauth_error(self, error_msg: str) -> bool:
|
|
599
|
+
"""
|
|
600
|
+
Detect if error is OAuth-related per RFC 6750 and MCP OAuth spec.
|
|
601
|
+
|
|
602
|
+
Checks for:
|
|
603
|
+
- RFC 6750 Section 3.1 Bearer token errors (invalid_token, insufficient_scope)
|
|
604
|
+
- OAuth 2.1 token refresh errors (invalid_grant)
|
|
605
|
+
- MCP spec OAuth validation failures (401/403 responses)
|
|
606
|
+
"""
|
|
607
|
+
if not error_msg:
|
|
608
|
+
return False
|
|
534
609
|
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
610
|
+
error_lower = error_msg.lower()
|
|
611
|
+
oauth_indicators = [
|
|
612
|
+
# RFC 6750 Section 3.1 - Standard Bearer token errors
|
|
613
|
+
"invalid_token", # Token expired, revoked, malformed, or invalid
|
|
614
|
+
"insufficient_scope", # Request requires higher privileges (403 Forbidden)
|
|
615
|
+
# OAuth 2.1 token refresh errors
|
|
616
|
+
"invalid_grant", # Refresh token errors
|
|
617
|
+
# MCP spec - OAuth validation failures (401 Unauthorized)
|
|
618
|
+
"oauth validation",
|
|
619
|
+
"unauthorized",
|
|
620
|
+
# Common OAuth error descriptions
|
|
621
|
+
"expired token",
|
|
622
|
+
"token expired",
|
|
623
|
+
"authentication failed",
|
|
624
|
+
"invalid access token",
|
|
625
|
+
]
|
|
626
|
+
|
|
627
|
+
return any(indicator in error_lower for indicator in oauth_indicators)
|
|
538
628
|
|
|
539
629
|
async def list_resources(self) -> dict[str, Any]:
|
|
540
630
|
"""List available resources from the server."""
|
|
@@ -542,7 +632,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
542
632
|
return {}
|
|
543
633
|
|
|
544
634
|
try:
|
|
545
|
-
response = await self._send_request("resources/list", {}, timeout=
|
|
635
|
+
response = await self._send_request("resources/list", {}, timeout=self.timeout_config.operation)
|
|
546
636
|
if "error" in response:
|
|
547
637
|
logger.debug("Resources not supported: %s", response["error"])
|
|
548
638
|
return {}
|
|
@@ -557,7 +647,7 @@ class SSETransport(MCPBaseTransport):
|
|
|
557
647
|
return {}
|
|
558
648
|
|
|
559
649
|
try:
|
|
560
|
-
response = await self._send_request("prompts/list", {}, timeout=
|
|
650
|
+
response = await self._send_request("prompts/list", {}, timeout=self.timeout_config.operation)
|
|
561
651
|
if "error" in response:
|
|
562
652
|
logger.debug("Prompts not supported: %s", response["error"])
|
|
563
653
|
return {}
|
|
@@ -572,12 +662,12 @@ class SSETransport(MCPBaseTransport):
|
|
|
572
662
|
return
|
|
573
663
|
|
|
574
664
|
# Log final metrics
|
|
575
|
-
if self.enable_metrics and self._metrics
|
|
665
|
+
if self.enable_metrics and self._metrics and self._metrics.total_calls > 0:
|
|
576
666
|
logger.debug(
|
|
577
667
|
"SSE transport closing - Total calls: %d, Success rate: %.1f%%, Avg response time: %.3fs",
|
|
578
|
-
self._metrics
|
|
579
|
-
(self._metrics
|
|
580
|
-
self._metrics
|
|
668
|
+
self._metrics.total_calls,
|
|
669
|
+
(self._metrics.successful_calls / self._metrics.total_calls * 100),
|
|
670
|
+
self._metrics.avg_response_time,
|
|
581
671
|
)
|
|
582
672
|
|
|
583
673
|
await self._cleanup()
|
|
@@ -626,7 +716,10 @@ class SSETransport(MCPBaseTransport):
|
|
|
626
716
|
|
|
627
717
|
def get_metrics(self) -> dict[str, Any]:
|
|
628
718
|
"""Get performance and connection metrics with health info."""
|
|
629
|
-
|
|
719
|
+
if not self._metrics:
|
|
720
|
+
return {}
|
|
721
|
+
|
|
722
|
+
metrics = self._metrics.to_dict()
|
|
630
723
|
metrics.update(
|
|
631
724
|
{
|
|
632
725
|
"is_connected": self.is_connected(),
|
|
@@ -646,17 +739,20 @@ class SSETransport(MCPBaseTransport):
|
|
|
646
739
|
|
|
647
740
|
def reset_metrics(self) -> None:
|
|
648
741
|
"""Reset performance metrics."""
|
|
649
|
-
self._metrics
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
742
|
+
if not self._metrics:
|
|
743
|
+
return
|
|
744
|
+
|
|
745
|
+
# Preserve important historical values
|
|
746
|
+
preserved_last_ping = self._metrics.last_ping_time
|
|
747
|
+
preserved_init_time = self._metrics.initialization_time
|
|
748
|
+
preserved_discoveries = self._metrics.session_discoveries
|
|
749
|
+
|
|
750
|
+
# Create new metrics instance with preserved values
|
|
751
|
+
self._metrics = TransportMetrics(
|
|
752
|
+
last_ping_time=preserved_last_ping,
|
|
753
|
+
initialization_time=preserved_init_time,
|
|
754
|
+
session_discoveries=preserved_discoveries,
|
|
755
|
+
)
|
|
660
756
|
|
|
661
757
|
def get_streams(self) -> list[tuple]:
|
|
662
758
|
"""SSE transport doesn't expose raw streams."""
|
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
import asyncio
|
|
5
5
|
import json
|
|
6
6
|
import logging
|
|
7
|
+
import os
|
|
7
8
|
import time
|
|
8
9
|
from typing import Any
|
|
9
10
|
|
|
@@ -54,13 +55,28 @@ class StdioTransport(MCPBaseTransport):
|
|
|
54
55
|
"""
|
|
55
56
|
# Convert dict to StdioParameters if needed
|
|
56
57
|
if isinstance(server_params, dict):
|
|
58
|
+
# Merge provided env with system environment to ensure PATH is available
|
|
59
|
+
merged_env = os.environ.copy()
|
|
60
|
+
if server_params.get("env"):
|
|
61
|
+
merged_env.update(server_params["env"])
|
|
62
|
+
|
|
57
63
|
self.server_params = StdioParameters(
|
|
58
64
|
command=server_params.get("command", "python"),
|
|
59
65
|
args=server_params.get("args", []),
|
|
60
|
-
env=
|
|
66
|
+
env=merged_env,
|
|
61
67
|
)
|
|
62
68
|
else:
|
|
63
|
-
|
|
69
|
+
# Also handle StdioParameters object - merge env if provided
|
|
70
|
+
# Create a new StdioParameters with merged env (Pydantic models are immutable)
|
|
71
|
+
merged_env = os.environ.copy()
|
|
72
|
+
if hasattr(server_params, "env") and server_params.env:
|
|
73
|
+
merged_env.update(server_params.env)
|
|
74
|
+
|
|
75
|
+
self.server_params = StdioParameters(
|
|
76
|
+
command=server_params.command,
|
|
77
|
+
args=server_params.args,
|
|
78
|
+
env=merged_env,
|
|
79
|
+
)
|
|
64
80
|
|
|
65
81
|
self.connection_timeout = connection_timeout
|
|
66
82
|
self.default_timeout = default_timeout
|
|
@@ -184,7 +200,8 @@ class StdioTransport(MCPBaseTransport):
|
|
|
184
200
|
# Enhanced health verification (like SSE)
|
|
185
201
|
logger.debug("Verifying connection with ping...")
|
|
186
202
|
ping_start = time.time()
|
|
187
|
-
|
|
203
|
+
# Use default timeout for initial ping verification
|
|
204
|
+
ping_success = await asyncio.wait_for(send_ping(*self._streams), timeout=self.default_timeout)
|
|
188
205
|
ping_time = time.time() - ping_start
|
|
189
206
|
|
|
190
207
|
if ping_success:
|
|
@@ -204,7 +221,7 @@ class StdioTransport(MCPBaseTransport):
|
|
|
204
221
|
)
|
|
205
222
|
return True
|
|
206
223
|
else:
|
|
207
|
-
logger.
|
|
224
|
+
logger.debug("STDIO connection established but ping failed")
|
|
208
225
|
# Still consider it initialized
|
|
209
226
|
self._initialized = True
|
|
210
227
|
self._consecutive_failures = 1
|
|
@@ -212,7 +229,7 @@ class StdioTransport(MCPBaseTransport):
|
|
|
212
229
|
self._metrics["initialization_time"] = time.time() - start_time
|
|
213
230
|
return True
|
|
214
231
|
else:
|
|
215
|
-
logger.
|
|
232
|
+
logger.warning("STDIO initialization failed")
|
|
216
233
|
await self._cleanup()
|
|
217
234
|
return False
|
|
218
235
|
|
|
@@ -365,16 +382,47 @@ class StdioTransport(MCPBaseTransport):
|
|
|
365
382
|
async def get_tools(self) -> list[dict[str, Any]]:
|
|
366
383
|
"""Enhanced tools retrieval with recovery."""
|
|
367
384
|
if not self._initialized:
|
|
368
|
-
logger.
|
|
385
|
+
logger.debug("Cannot get tools: transport not initialized")
|
|
369
386
|
return []
|
|
370
387
|
|
|
371
388
|
start_time = time.time()
|
|
372
389
|
try:
|
|
373
390
|
response = await asyncio.wait_for(send_tools_list(*self._streams), timeout=self.default_timeout)
|
|
374
391
|
|
|
375
|
-
# Normalize response
|
|
376
|
-
if
|
|
377
|
-
|
|
392
|
+
# Normalize response - handle multiple formats including Pydantic models
|
|
393
|
+
# 1. Check if it's a Pydantic model with tools attribute (e.g., ListToolsResult from chuk_mcp)
|
|
394
|
+
if hasattr(response, "tools"):
|
|
395
|
+
tools = response.tools
|
|
396
|
+
# Convert Pydantic Tool models to dicts if needed
|
|
397
|
+
if tools and len(tools) > 0 and hasattr(tools[0], "model_dump"):
|
|
398
|
+
tools = [tool.model_dump() if hasattr(tool, "model_dump") else tool for tool in tools]
|
|
399
|
+
elif tools and len(tools) > 0 and hasattr(tools[0], "dict"):
|
|
400
|
+
# Older Pydantic versions use dict() instead of model_dump()
|
|
401
|
+
tools = [tool.dict() if hasattr(tool, "dict") else tool for tool in tools]
|
|
402
|
+
# 2. Check if it's a Pydantic model that can be dumped
|
|
403
|
+
elif hasattr(response, "model_dump"):
|
|
404
|
+
dumped = response.model_dump()
|
|
405
|
+
tools = dumped.get("tools", [])
|
|
406
|
+
# 3. Handle dict responses
|
|
407
|
+
elif isinstance(response, dict):
|
|
408
|
+
# Check for tools at top level
|
|
409
|
+
if "tools" in response:
|
|
410
|
+
tools = response["tools"]
|
|
411
|
+
# Check for nested result.tools (common in some MCP implementations)
|
|
412
|
+
elif "result" in response and isinstance(response["result"], dict):
|
|
413
|
+
tools = response["result"].get("tools", [])
|
|
414
|
+
# Check if response itself is the result with MCP structure
|
|
415
|
+
elif "jsonrpc" in response and "result" in response:
|
|
416
|
+
result = response["result"]
|
|
417
|
+
if isinstance(result, dict):
|
|
418
|
+
tools = result.get("tools", [])
|
|
419
|
+
elif isinstance(result, list):
|
|
420
|
+
tools = result
|
|
421
|
+
else:
|
|
422
|
+
tools = []
|
|
423
|
+
else:
|
|
424
|
+
tools = []
|
|
425
|
+
# 4. Handle list responses
|
|
378
426
|
elif isinstance(response, list):
|
|
379
427
|
tools = response
|
|
380
428
|
else:
|
|
@@ -426,7 +474,7 @@ class StdioTransport(MCPBaseTransport):
|
|
|
426
474
|
return {"isError": True, "error": "Failed to recover connection"}
|
|
427
475
|
|
|
428
476
|
response = await asyncio.wait_for(
|
|
429
|
-
send_tools_call(*self._streams, tool_name, arguments), timeout=tool_timeout
|
|
477
|
+
send_tools_call(*self._streams, tool_name, arguments, timeout=tool_timeout), timeout=tool_timeout
|
|
430
478
|
)
|
|
431
479
|
|
|
432
480
|
response_time = time.time() - start_time
|
|
@@ -1 +1,21 @@
|
|
|
1
1
|
# chuk_tool_processor/models/__init__.py
|
|
2
|
+
"""Data models for the tool processor."""
|
|
3
|
+
|
|
4
|
+
from chuk_tool_processor.models.execution_strategy import ExecutionStrategy
|
|
5
|
+
from chuk_tool_processor.models.streaming_tool import StreamingTool
|
|
6
|
+
from chuk_tool_processor.models.tool_call import ToolCall
|
|
7
|
+
from chuk_tool_processor.models.tool_result import ToolResult
|
|
8
|
+
from chuk_tool_processor.models.tool_spec import ToolCapability, ToolSpec, tool_spec
|
|
9
|
+
from chuk_tool_processor.models.validated_tool import ValidatedTool, with_validation
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ExecutionStrategy",
|
|
13
|
+
"StreamingTool",
|
|
14
|
+
"ToolCall",
|
|
15
|
+
"ToolResult",
|
|
16
|
+
"ToolSpec",
|
|
17
|
+
"ToolCapability",
|
|
18
|
+
"tool_spec",
|
|
19
|
+
"ValidatedTool",
|
|
20
|
+
"with_validation",
|
|
21
|
+
]
|
|
@@ -5,10 +5,12 @@ Model representing a tool call with arguments.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
+
import hashlib
|
|
9
|
+
import json
|
|
8
10
|
import uuid
|
|
9
11
|
from typing import Any
|
|
10
12
|
|
|
11
|
-
from pydantic import BaseModel, ConfigDict, Field
|
|
13
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
class ToolCall(BaseModel):
|
|
@@ -20,6 +22,7 @@ class ToolCall(BaseModel):
|
|
|
20
22
|
tool: Name of the tool to call
|
|
21
23
|
namespace: Namespace the tool belongs to
|
|
22
24
|
arguments: Arguments to pass to the tool
|
|
25
|
+
idempotency_key: Optional key for deduplicating duplicate calls (auto-generated)
|
|
23
26
|
"""
|
|
24
27
|
|
|
25
28
|
model_config = ConfigDict(extra="ignore")
|
|
@@ -28,6 +31,36 @@ class ToolCall(BaseModel):
|
|
|
28
31
|
tool: str = Field(..., min_length=1, description="Name of the tool to call; must be non-empty")
|
|
29
32
|
namespace: str = Field(default="default", description="Namespace the tool belongs to")
|
|
30
33
|
arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments to pass to the tool")
|
|
34
|
+
idempotency_key: str | None = Field(
|
|
35
|
+
None,
|
|
36
|
+
description="Idempotency key for deduplication. Auto-generated if not provided.",
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
@model_validator(mode="after")
|
|
40
|
+
def generate_idempotency_key(self) -> ToolCall:
|
|
41
|
+
"""Generate idempotency key if not provided."""
|
|
42
|
+
if self.idempotency_key is None:
|
|
43
|
+
self.idempotency_key = self._compute_idempotency_key()
|
|
44
|
+
return self
|
|
45
|
+
|
|
46
|
+
def _compute_idempotency_key(self) -> str:
|
|
47
|
+
"""
|
|
48
|
+
Compute a stable idempotency key from tool name, namespace, and arguments.
|
|
49
|
+
|
|
50
|
+
Uses SHA256 hash of the sorted JSON representation.
|
|
51
|
+
Returns first 16 characters of the hex digest for brevity.
|
|
52
|
+
"""
|
|
53
|
+
# Create a stable representation
|
|
54
|
+
payload = {
|
|
55
|
+
"tool": self.tool,
|
|
56
|
+
"namespace": self.namespace,
|
|
57
|
+
"arguments": self.arguments,
|
|
58
|
+
}
|
|
59
|
+
# Sort keys for stability
|
|
60
|
+
json_str = json.dumps(payload, sort_keys=True, default=str)
|
|
61
|
+
# Hash it
|
|
62
|
+
hash_obj = hashlib.sha256(json_str.encode(), usedforsecurity=False)
|
|
63
|
+
return hash_obj.hexdigest()[:16] # Use first 16 chars for brevity
|
|
31
64
|
|
|
32
65
|
async def to_dict(self) -> dict[str, Any]:
|
|
33
66
|
"""Convert to a dictionary for serialization."""
|