chuk-tool-processor 0.6.13__py3-none-any.whl → 0.9.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; see the registry's advisory page for more details.

Files changed (35)
  1. chuk_tool_processor/core/__init__.py +31 -0
  2. chuk_tool_processor/core/exceptions.py +218 -12
  3. chuk_tool_processor/core/processor.py +38 -7
  4. chuk_tool_processor/execution/strategies/__init__.py +6 -0
  5. chuk_tool_processor/execution/strategies/subprocess_strategy.py +2 -1
  6. chuk_tool_processor/execution/wrappers/__init__.py +42 -0
  7. chuk_tool_processor/execution/wrappers/caching.py +48 -13
  8. chuk_tool_processor/execution/wrappers/circuit_breaker.py +370 -0
  9. chuk_tool_processor/execution/wrappers/rate_limiting.py +31 -1
  10. chuk_tool_processor/execution/wrappers/retry.py +93 -53
  11. chuk_tool_processor/logging/metrics.py +2 -2
  12. chuk_tool_processor/mcp/mcp_tool.py +5 -5
  13. chuk_tool_processor/mcp/setup_mcp_http_streamable.py +44 -2
  14. chuk_tool_processor/mcp/setup_mcp_sse.py +44 -2
  15. chuk_tool_processor/mcp/setup_mcp_stdio.py +2 -0
  16. chuk_tool_processor/mcp/stream_manager.py +130 -75
  17. chuk_tool_processor/mcp/transport/__init__.py +10 -0
  18. chuk_tool_processor/mcp/transport/http_streamable_transport.py +193 -108
  19. chuk_tool_processor/mcp/transport/models.py +100 -0
  20. chuk_tool_processor/mcp/transport/sse_transport.py +155 -59
  21. chuk_tool_processor/mcp/transport/stdio_transport.py +58 -10
  22. chuk_tool_processor/models/__init__.py +20 -0
  23. chuk_tool_processor/models/tool_call.py +34 -1
  24. chuk_tool_processor/models/tool_spec.py +350 -0
  25. chuk_tool_processor/models/validated_tool.py +22 -2
  26. chuk_tool_processor/observability/__init__.py +30 -0
  27. chuk_tool_processor/observability/metrics.py +312 -0
  28. chuk_tool_processor/observability/setup.py +105 -0
  29. chuk_tool_processor/observability/tracing.py +345 -0
  30. chuk_tool_processor/plugins/discovery.py +1 -1
  31. chuk_tool_processor-0.9.7.dist-info/METADATA +1813 -0
  32. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/RECORD +34 -27
  33. chuk_tool_processor-0.6.13.dist-info/METADATA +0 -698
  34. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/WHEEL +0 -0
  35. {chuk_tool_processor-0.6.13.dist-info → chuk_tool_processor-0.9.7.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ from typing import Any
19
19
  import httpx
20
20
 
21
21
  from .base_transport import MCPBaseTransport
22
+ from .models import TimeoutConfig, TransportMetrics
22
23
 
23
24
  logger = logging.getLogger(__name__)
24
25
 
@@ -38,6 +39,8 @@ class SSETransport(MCPBaseTransport):
38
39
  connection_timeout: float = 30.0,
39
40
  default_timeout: float = 60.0,
40
41
  enable_metrics: bool = True,
42
+ oauth_refresh_callback: Any | None = None,
43
+ timeout_config: TimeoutConfig | None = None,
41
44
  ):
42
45
  """
43
46
  Initialize SSE transport.
@@ -45,9 +48,16 @@ class SSETransport(MCPBaseTransport):
45
48
  self.url = url.rstrip("/")
46
49
  self.api_key = api_key
47
50
  self.configured_headers = headers or {}
48
- self.connection_timeout = connection_timeout
49
- self.default_timeout = default_timeout
50
51
  self.enable_metrics = enable_metrics
52
+ self.oauth_refresh_callback = oauth_refresh_callback
53
+
54
+ # Use timeout config or create from individual parameters
55
+ if timeout_config is None:
56
+ timeout_config = TimeoutConfig(connect=connection_timeout, operation=default_timeout)
57
+
58
+ self.timeout_config = timeout_config
59
+ self.connection_timeout = timeout_config.connect
60
+ self.default_timeout = timeout_config.operation
51
61
 
52
62
  logger.debug("SSE Transport initialized with URL: %s", self.url)
53
63
 
@@ -73,18 +83,8 @@ class SSETransport(MCPBaseTransport):
73
83
  self._connection_grace_period = 30.0 # NEW: Grace period after initialization
74
84
  self._initialization_time = None # NEW: Track when we initialized
75
85
 
76
- # Performance metrics
77
- self._metrics = {
78
- "total_calls": 0,
79
- "successful_calls": 0,
80
- "failed_calls": 0,
81
- "total_time": 0.0,
82
- "avg_response_time": 0.0,
83
- "last_ping_time": None,
84
- "initialization_time": None,
85
- "session_discoveries": 0,
86
- "stream_errors": 0,
87
- }
86
+ # Performance metrics - use Pydantic model
87
+ self._metrics = TransportMetrics() if enable_metrics else None
88
88
 
89
89
  def _construct_sse_url(self, base_url: str) -> str:
90
90
  """Construct the SSE endpoint URL from the base URL."""
@@ -110,8 +110,9 @@ class SSETransport(MCPBaseTransport):
110
110
  if self.configured_headers:
111
111
  headers.update(self.configured_headers)
112
112
 
113
- # Add API key as Bearer token if provided
114
- if self.api_key:
113
+ # Add API key as Bearer token if provided and no Authorization header exists
114
+ # This prevents clobbering OAuth tokens from configured_headers
115
+ if self.api_key and "Authorization" not in headers:
115
116
  headers["Authorization"] = f"Bearer {self.api_key}"
116
117
 
117
118
  return headers
@@ -173,7 +174,7 @@ class SSETransport(MCPBaseTransport):
173
174
 
174
175
  # Wait for session discovery
175
176
  logger.debug("Waiting for session discovery...")
176
- session_timeout = 10.0
177
+ session_timeout = self.timeout_config.connect
177
178
  session_start = time.time()
178
179
 
179
180
  while not self.message_url and (time.time() - session_start) < session_timeout:
@@ -183,17 +184,17 @@ class SSETransport(MCPBaseTransport):
183
184
  if self.sse_task.done():
184
185
  exception = self.sse_task.exception()
185
186
  if exception:
186
- logger.error(f"SSE task died during session discovery: {exception}")
187
+ logger.debug(f"SSE task died during session discovery: {exception}")
187
188
  await self._cleanup()
188
189
  return False
189
190
 
190
191
  if not self.message_url:
191
- logger.error("Failed to discover session endpoint within %.1fs", session_timeout)
192
+ logger.warning("Failed to discover session endpoint within %.1fs", session_timeout)
192
193
  await self._cleanup()
193
194
  return False
194
195
 
195
- if self.enable_metrics:
196
- self._metrics["session_discoveries"] += 1
196
+ if self.enable_metrics and self._metrics:
197
+ self._metrics.session_discoveries += 1
197
198
 
198
199
  logger.debug("Session endpoint discovered: %s", self.message_url)
199
200
 
@@ -210,7 +211,7 @@ class SSETransport(MCPBaseTransport):
210
211
  )
211
212
 
212
213
  if "error" in init_response:
213
- logger.error("MCP initialize failed: %s", init_response["error"])
214
+ logger.warning("MCP initialize failed: %s", init_response["error"])
214
215
  await self._cleanup()
215
216
  return False
216
217
 
@@ -223,9 +224,9 @@ class SSETransport(MCPBaseTransport):
223
224
  self._last_successful_ping = time.time()
224
225
  self._consecutive_failures = 0 # Reset failure count
225
226
 
226
- if self.enable_metrics:
227
+ if self.enable_metrics and self._metrics:
227
228
  init_time = time.time() - start_time
228
- self._metrics["initialization_time"] = init_time
229
+ self._metrics.initialization_time = init_time
229
230
 
230
231
  logger.debug("SSE transport initialized successfully in %.3fs", time.time() - start_time)
231
232
  return True
@@ -269,12 +270,30 @@ class SSETransport(MCPBaseTransport):
269
270
  # Extract session ID from URL if present
270
271
  if "session_id=" in data_part:
271
272
  self.session_id = data_part.split("session_id=")[1].split("&")[0]
273
+ elif "sessionId=" in data_part:
274
+ self.session_id = data_part.split("sessionId=")[1].split("&")[0]
272
275
  else:
273
276
  self.session_id = str(uuid.uuid4())
274
277
 
275
278
  logger.debug("Session endpoint discovered via event format: %s", self.message_url)
276
279
  continue
277
280
 
281
+ # RELATIVE PATH FORMAT: event: endpoint + data: /sse/message?sessionId=...
282
+ elif current_event == "endpoint" and data_part.startswith("/"):
283
+ endpoint_path = data_part
284
+ self.message_url = f"{self.url}{endpoint_path}"
285
+
286
+ # Extract session ID if present
287
+ if "session_id=" in endpoint_path:
288
+ self.session_id = endpoint_path.split("session_id=")[1].split("&")[0]
289
+ elif "sessionId=" in endpoint_path:
290
+ self.session_id = endpoint_path.split("sessionId=")[1].split("&")[0]
291
+ else:
292
+ self.session_id = str(uuid.uuid4())
293
+
294
+ logger.debug("Session endpoint discovered via relative path: %s", self.message_url)
295
+ continue
296
+
278
297
  # OLD FORMAT: data: /messages/... (backwards compatibility)
279
298
  elif "/messages/" in data_part:
280
299
  endpoint_path = data_part
@@ -315,8 +334,8 @@ class SSETransport(MCPBaseTransport):
315
334
  logger.debug("Non-JSON data in SSE stream (ignoring): %s", e)
316
335
 
317
336
  except Exception as e:
318
- if self.enable_metrics:
319
- self._metrics["stream_errors"] += 1
337
+ if self.enable_metrics and self._metrics:
338
+ self._metrics.stream_errors += 1
320
339
  logger.error("SSE stream processing error: %s", e)
321
340
  # FIXED: Don't increment consecutive failures for stream processing errors
322
341
  # These are often temporary and don't indicate connection health
@@ -400,7 +419,7 @@ class SSETransport(MCPBaseTransport):
400
419
  start_time = time.time()
401
420
  try:
402
421
  # Use tools/list as a lightweight ping since not all servers support ping
403
- response = await self._send_request("tools/list", {}, timeout=10.0)
422
+ response = await self._send_request("tools/list", {}, timeout=self.timeout_config.quick)
404
423
 
405
424
  success = "error" not in response
406
425
 
@@ -408,9 +427,9 @@ class SSETransport(MCPBaseTransport):
408
427
  self._last_successful_ping = time.time()
409
428
  # FIXED: Don't reset consecutive failures here - let tool calls do that
410
429
 
411
- if self.enable_metrics:
430
+ if self.enable_metrics and self._metrics:
412
431
  ping_time = time.time() - start_time
413
- self._metrics["last_ping_time"] = ping_time
432
+ self._metrics.last_ping_time = ping_time
414
433
  logger.debug("SSE ping completed in %.3fs: %s", ping_time, success)
415
434
 
416
435
  return success
@@ -457,7 +476,7 @@ class SSETransport(MCPBaseTransport):
457
476
  async def get_tools(self) -> list[dict[str, Any]]:
458
477
  """Get list of available tools from the server."""
459
478
  if not self._initialized:
460
- logger.error("Cannot get tools: transport not initialized")
479
+ logger.debug("Cannot get tools: transport not initialized")
461
480
  return []
462
481
 
463
482
  start_time = time.time()
@@ -465,7 +484,7 @@ class SSETransport(MCPBaseTransport):
465
484
  response = await self._send_request("tools/list", {})
466
485
 
467
486
  if "error" in response:
468
- logger.error("Error getting tools: %s", response["error"])
487
+ logger.warning("Error getting tools: %s", response["error"])
469
488
  return []
470
489
 
471
490
  tools = response.get("result", {}).get("tools", [])
@@ -488,8 +507,8 @@ class SSETransport(MCPBaseTransport):
488
507
  return {"isError": True, "error": "Transport not initialized"}
489
508
 
490
509
  start_time = time.time()
491
- if self.enable_metrics:
492
- self._metrics["total_calls"] += 1
510
+ if self.enable_metrics and self._metrics:
511
+ self._metrics.total_calls += 1
493
512
 
494
513
  try:
495
514
  logger.debug("Calling tool '%s' with arguments: %s", tool_name, arguments)
@@ -498,11 +517,55 @@ class SSETransport(MCPBaseTransport):
498
517
  "tools/call", {"name": tool_name, "arguments": arguments}, timeout=timeout
499
518
  )
500
519
 
520
+ # Check for errors
501
521
  if "error" in response:
522
+ error_msg = response["error"].get("message", "Unknown error")
523
+
524
+ # NEW: Check for OAuth errors and attempt refresh if callback is available
525
+ if self._is_oauth_error(error_msg):
526
+ logger.warning("OAuth error detected: %s", error_msg)
527
+
528
+ if self.oauth_refresh_callback:
529
+ logger.debug("Attempting OAuth token refresh...")
530
+ try:
531
+ # Call the refresh callback
532
+ new_headers = await self.oauth_refresh_callback()
533
+
534
+ if new_headers and "Authorization" in new_headers:
535
+ # Update configured headers with new token
536
+ self.configured_headers.update(new_headers)
537
+ logger.debug("OAuth token refreshed, retrying tool call...")
538
+
539
+ # Retry the tool call once with new token
540
+ response = await self._send_request(
541
+ "tools/call", {"name": tool_name, "arguments": arguments}, timeout=timeout
542
+ )
543
+
544
+ # Check if retry succeeded
545
+ if "error" not in response:
546
+ logger.debug("Tool call succeeded after token refresh")
547
+ result = response.get("result", {})
548
+ normalized_result = self._normalize_mcp_response({"result": result})
549
+
550
+ if self.enable_metrics:
551
+ self._update_metrics(time.time() - start_time, True)
552
+
553
+ return normalized_result
554
+ else:
555
+ error_msg = response["error"].get("message", "Unknown error")
556
+ logger.error("Tool call failed after token refresh: %s", error_msg)
557
+ else:
558
+ logger.warning("Token refresh did not return valid Authorization header")
559
+ except Exception as refresh_error:
560
+ logger.error("OAuth token refresh failed: %s", refresh_error)
561
+ else:
562
+ logger.warning("OAuth error detected but no refresh callback configured")
563
+
564
+ # Return error (original or from failed retry)
502
565
  if self.enable_metrics:
503
566
  self._update_metrics(time.time() - start_time, False)
504
567
 
505
- return {"isError": True, "error": response["error"].get("message", "Unknown error")}
568
+ return {"isError": True, "error": error_msg}
506
569
 
507
570
  # Extract and normalize result using base class method
508
571
  result = response.get("result", {})
@@ -527,14 +590,41 @@ class SSETransport(MCPBaseTransport):
527
590
 
528
591
  def _update_metrics(self, response_time: float, success: bool) -> None:
529
592
  """Update performance metrics."""
530
- if success:
531
- self._metrics["successful_calls"] += 1
532
- else:
533
- self._metrics["failed_calls"] += 1
593
+ if not self._metrics:
594
+ return
595
+
596
+ self._metrics.update_call_metrics(response_time, success)
597
+
598
+ def _is_oauth_error(self, error_msg: str) -> bool:
599
+ """
600
+ Detect if error is OAuth-related per RFC 6750 and MCP OAuth spec.
601
+
602
+ Checks for:
603
+ - RFC 6750 Section 3.1 Bearer token errors (invalid_token, insufficient_scope)
604
+ - OAuth 2.1 token refresh errors (invalid_grant)
605
+ - MCP spec OAuth validation failures (401/403 responses)
606
+ """
607
+ if not error_msg:
608
+ return False
534
609
 
535
- self._metrics["total_time"] += response_time
536
- if self._metrics["total_calls"] > 0:
537
- self._metrics["avg_response_time"] = self._metrics["total_time"] / self._metrics["total_calls"]
610
+ error_lower = error_msg.lower()
611
+ oauth_indicators = [
612
+ # RFC 6750 Section 3.1 - Standard Bearer token errors
613
+ "invalid_token", # Token expired, revoked, malformed, or invalid
614
+ "insufficient_scope", # Request requires higher privileges (403 Forbidden)
615
+ # OAuth 2.1 token refresh errors
616
+ "invalid_grant", # Refresh token errors
617
+ # MCP spec - OAuth validation failures (401 Unauthorized)
618
+ "oauth validation",
619
+ "unauthorized",
620
+ # Common OAuth error descriptions
621
+ "expired token",
622
+ "token expired",
623
+ "authentication failed",
624
+ "invalid access token",
625
+ ]
626
+
627
+ return any(indicator in error_lower for indicator in oauth_indicators)
538
628
 
539
629
  async def list_resources(self) -> dict[str, Any]:
540
630
  """List available resources from the server."""
@@ -542,7 +632,7 @@ class SSETransport(MCPBaseTransport):
542
632
  return {}
543
633
 
544
634
  try:
545
- response = await self._send_request("resources/list", {}, timeout=10.0)
635
+ response = await self._send_request("resources/list", {}, timeout=self.timeout_config.operation)
546
636
  if "error" in response:
547
637
  logger.debug("Resources not supported: %s", response["error"])
548
638
  return {}
@@ -557,7 +647,7 @@ class SSETransport(MCPBaseTransport):
557
647
  return {}
558
648
 
559
649
  try:
560
- response = await self._send_request("prompts/list", {}, timeout=10.0)
650
+ response = await self._send_request("prompts/list", {}, timeout=self.timeout_config.operation)
561
651
  if "error" in response:
562
652
  logger.debug("Prompts not supported: %s", response["error"])
563
653
  return {}
@@ -572,12 +662,12 @@ class SSETransport(MCPBaseTransport):
572
662
  return
573
663
 
574
664
  # Log final metrics
575
- if self.enable_metrics and self._metrics["total_calls"] > 0:
665
+ if self.enable_metrics and self._metrics and self._metrics.total_calls > 0:
576
666
  logger.debug(
577
667
  "SSE transport closing - Total calls: %d, Success rate: %.1f%%, Avg response time: %.3fs",
578
- self._metrics["total_calls"],
579
- (self._metrics["successful_calls"] / self._metrics["total_calls"] * 100),
580
- self._metrics["avg_response_time"],
668
+ self._metrics.total_calls,
669
+ (self._metrics.successful_calls / self._metrics.total_calls * 100),
670
+ self._metrics.avg_response_time,
581
671
  )
582
672
 
583
673
  await self._cleanup()
@@ -626,7 +716,10 @@ class SSETransport(MCPBaseTransport):
626
716
 
627
717
  def get_metrics(self) -> dict[str, Any]:
628
718
  """Get performance and connection metrics with health info."""
629
- metrics = self._metrics.copy()
719
+ if not self._metrics:
720
+ return {}
721
+
722
+ metrics = self._metrics.to_dict()
630
723
  metrics.update(
631
724
  {
632
725
  "is_connected": self.is_connected(),
@@ -646,17 +739,20 @@ class SSETransport(MCPBaseTransport):
646
739
 
647
740
  def reset_metrics(self) -> None:
648
741
  """Reset performance metrics."""
649
- self._metrics = {
650
- "total_calls": 0,
651
- "successful_calls": 0,
652
- "failed_calls": 0,
653
- "total_time": 0.0,
654
- "avg_response_time": 0.0,
655
- "last_ping_time": self._metrics.get("last_ping_time"),
656
- "initialization_time": self._metrics.get("initialization_time"),
657
- "session_discoveries": self._metrics.get("session_discoveries", 0),
658
- "stream_errors": 0,
659
- }
742
+ if not self._metrics:
743
+ return
744
+
745
+ # Preserve important historical values
746
+ preserved_last_ping = self._metrics.last_ping_time
747
+ preserved_init_time = self._metrics.initialization_time
748
+ preserved_discoveries = self._metrics.session_discoveries
749
+
750
+ # Create new metrics instance with preserved values
751
+ self._metrics = TransportMetrics(
752
+ last_ping_time=preserved_last_ping,
753
+ initialization_time=preserved_init_time,
754
+ session_discoveries=preserved_discoveries,
755
+ )
660
756
 
661
757
  def get_streams(self) -> list[tuple]:
662
758
  """SSE transport doesn't expose raw streams."""
@@ -4,6 +4,7 @@ from __future__ import annotations
4
4
  import asyncio
5
5
  import json
6
6
  import logging
7
+ import os
7
8
  import time
8
9
  from typing import Any
9
10
 
@@ -54,13 +55,28 @@ class StdioTransport(MCPBaseTransport):
54
55
  """
55
56
  # Convert dict to StdioParameters if needed
56
57
  if isinstance(server_params, dict):
58
+ # Merge provided env with system environment to ensure PATH is available
59
+ merged_env = os.environ.copy()
60
+ if server_params.get("env"):
61
+ merged_env.update(server_params["env"])
62
+
57
63
  self.server_params = StdioParameters(
58
64
  command=server_params.get("command", "python"),
59
65
  args=server_params.get("args", []),
60
- env=server_params.get("env"),
66
+ env=merged_env,
61
67
  )
62
68
  else:
63
- self.server_params = server_params
69
+ # Also handle StdioParameters object - merge env if provided
70
+ # Create a new StdioParameters with merged env (Pydantic models are immutable)
71
+ merged_env = os.environ.copy()
72
+ if hasattr(server_params, "env") and server_params.env:
73
+ merged_env.update(server_params.env)
74
+
75
+ self.server_params = StdioParameters(
76
+ command=server_params.command,
77
+ args=server_params.args,
78
+ env=merged_env,
79
+ )
64
80
 
65
81
  self.connection_timeout = connection_timeout
66
82
  self.default_timeout = default_timeout
@@ -184,7 +200,8 @@ class StdioTransport(MCPBaseTransport):
184
200
  # Enhanced health verification (like SSE)
185
201
  logger.debug("Verifying connection with ping...")
186
202
  ping_start = time.time()
187
- ping_success = await asyncio.wait_for(send_ping(*self._streams), timeout=10.0)
203
+ # Use default timeout for initial ping verification
204
+ ping_success = await asyncio.wait_for(send_ping(*self._streams), timeout=self.default_timeout)
188
205
  ping_time = time.time() - ping_start
189
206
 
190
207
  if ping_success:
@@ -204,7 +221,7 @@ class StdioTransport(MCPBaseTransport):
204
221
  )
205
222
  return True
206
223
  else:
207
- logger.warning("STDIO connection established but ping failed")
224
+ logger.debug("STDIO connection established but ping failed")
208
225
  # Still consider it initialized
209
226
  self._initialized = True
210
227
  self._consecutive_failures = 1
@@ -212,7 +229,7 @@ class StdioTransport(MCPBaseTransport):
212
229
  self._metrics["initialization_time"] = time.time() - start_time
213
230
  return True
214
231
  else:
215
- logger.error("STDIO initialization failed")
232
+ logger.warning("STDIO initialization failed")
216
233
  await self._cleanup()
217
234
  return False
218
235
 
@@ -365,16 +382,47 @@ class StdioTransport(MCPBaseTransport):
365
382
  async def get_tools(self) -> list[dict[str, Any]]:
366
383
  """Enhanced tools retrieval with recovery."""
367
384
  if not self._initialized:
368
- logger.error("Cannot get tools: transport not initialized")
385
+ logger.debug("Cannot get tools: transport not initialized")
369
386
  return []
370
387
 
371
388
  start_time = time.time()
372
389
  try:
373
390
  response = await asyncio.wait_for(send_tools_list(*self._streams), timeout=self.default_timeout)
374
391
 
375
- # Normalize response
376
- if isinstance(response, dict):
377
- tools = response.get("tools", [])
392
+ # Normalize response - handle multiple formats including Pydantic models
393
+ # 1. Check if it's a Pydantic model with tools attribute (e.g., ListToolsResult from chuk_mcp)
394
+ if hasattr(response, "tools"):
395
+ tools = response.tools
396
+ # Convert Pydantic Tool models to dicts if needed
397
+ if tools and len(tools) > 0 and hasattr(tools[0], "model_dump"):
398
+ tools = [tool.model_dump() if hasattr(tool, "model_dump") else tool for tool in tools]
399
+ elif tools and len(tools) > 0 and hasattr(tools[0], "dict"):
400
+ # Older Pydantic versions use dict() instead of model_dump()
401
+ tools = [tool.dict() if hasattr(tool, "dict") else tool for tool in tools]
402
+ # 2. Check if it's a Pydantic model that can be dumped
403
+ elif hasattr(response, "model_dump"):
404
+ dumped = response.model_dump()
405
+ tools = dumped.get("tools", [])
406
+ # 3. Handle dict responses
407
+ elif isinstance(response, dict):
408
+ # Check for tools at top level
409
+ if "tools" in response:
410
+ tools = response["tools"]
411
+ # Check for nested result.tools (common in some MCP implementations)
412
+ elif "result" in response and isinstance(response["result"], dict):
413
+ tools = response["result"].get("tools", [])
414
+ # Check if response itself is the result with MCP structure
415
+ elif "jsonrpc" in response and "result" in response:
416
+ result = response["result"]
417
+ if isinstance(result, dict):
418
+ tools = result.get("tools", [])
419
+ elif isinstance(result, list):
420
+ tools = result
421
+ else:
422
+ tools = []
423
+ else:
424
+ tools = []
425
+ # 4. Handle list responses
378
426
  elif isinstance(response, list):
379
427
  tools = response
380
428
  else:
@@ -426,7 +474,7 @@ class StdioTransport(MCPBaseTransport):
426
474
  return {"isError": True, "error": "Failed to recover connection"}
427
475
 
428
476
  response = await asyncio.wait_for(
429
- send_tools_call(*self._streams, tool_name, arguments), timeout=tool_timeout
477
+ send_tools_call(*self._streams, tool_name, arguments, timeout=tool_timeout), timeout=tool_timeout
430
478
  )
431
479
 
432
480
  response_time = time.time() - start_time
@@ -1 +1,21 @@
1
1
  # chuk_tool_processor/models/__init__.py
2
+ """Data models for the tool processor."""
3
+
4
+ from chuk_tool_processor.models.execution_strategy import ExecutionStrategy
5
+ from chuk_tool_processor.models.streaming_tool import StreamingTool
6
+ from chuk_tool_processor.models.tool_call import ToolCall
7
+ from chuk_tool_processor.models.tool_result import ToolResult
8
+ from chuk_tool_processor.models.tool_spec import ToolCapability, ToolSpec, tool_spec
9
+ from chuk_tool_processor.models.validated_tool import ValidatedTool, with_validation
10
+
11
+ __all__ = [
12
+ "ExecutionStrategy",
13
+ "StreamingTool",
14
+ "ToolCall",
15
+ "ToolResult",
16
+ "ToolSpec",
17
+ "ToolCapability",
18
+ "tool_spec",
19
+ "ValidatedTool",
20
+ "with_validation",
21
+ ]
@@ -5,10 +5,12 @@ Model representing a tool call with arguments.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
+ import hashlib
9
+ import json
8
10
  import uuid
9
11
  from typing import Any
10
12
 
11
- from pydantic import BaseModel, ConfigDict, Field
13
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
12
14
 
13
15
 
14
16
  class ToolCall(BaseModel):
@@ -20,6 +22,7 @@ class ToolCall(BaseModel):
20
22
  tool: Name of the tool to call
21
23
  namespace: Namespace the tool belongs to
22
24
  arguments: Arguments to pass to the tool
25
+ idempotency_key: Optional key for deduplicating duplicate calls (auto-generated)
23
26
  """
24
27
 
25
28
  model_config = ConfigDict(extra="ignore")
@@ -28,6 +31,36 @@ class ToolCall(BaseModel):
28
31
  tool: str = Field(..., min_length=1, description="Name of the tool to call; must be non-empty")
29
32
  namespace: str = Field(default="default", description="Namespace the tool belongs to")
30
33
  arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments to pass to the tool")
34
+ idempotency_key: str | None = Field(
35
+ None,
36
+ description="Idempotency key for deduplication. Auto-generated if not provided.",
37
+ )
38
+
39
+ @model_validator(mode="after")
40
+ def generate_idempotency_key(self) -> ToolCall:
41
+ """Generate idempotency key if not provided."""
42
+ if self.idempotency_key is None:
43
+ self.idempotency_key = self._compute_idempotency_key()
44
+ return self
45
+
46
+ def _compute_idempotency_key(self) -> str:
47
+ """
48
+ Compute a stable idempotency key from tool name, namespace, and arguments.
49
+
50
+ Uses SHA256 hash of the sorted JSON representation.
51
+ Returns first 16 characters of the hex digest for brevity.
52
+ """
53
+ # Create a stable representation
54
+ payload = {
55
+ "tool": self.tool,
56
+ "namespace": self.namespace,
57
+ "arguments": self.arguments,
58
+ }
59
+ # Sort keys for stability
60
+ json_str = json.dumps(payload, sort_keys=True, default=str)
61
+ # Hash it
62
+ hash_obj = hashlib.sha256(json_str.encode(), usedforsecurity=False)
63
+ return hash_obj.hexdigest()[:16] # Use first 16 chars for brevity
31
64
 
32
65
  async def to_dict(self) -> dict[str, Any]:
33
66
  """Convert to a dictionary for serialization."""