chuk-tool-processor 0.4.1__py3-none-any.whl → 0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chuk-tool-processor might be problematic.

@@ -3,6 +3,8 @@
 """
 MCP tool shim that delegates execution to a StreamManager,
 handling its own lazy bootstrap when needed.
+
+FIXED: Added subprocess serialization support by implementing __getstate__ and __setstate__
 """
 from __future__ import annotations
 
@@ -24,12 +26,14 @@ class MCPTool:
 
     If no ``StreamManager`` is supplied the class will start one on first
     use via ``setup_mcp_stdio``.
+
+    FIXED: Added serialization support for subprocess execution.
     """
 
     # ------------------------------------------------------------------ #
     def __init__(
         self,
-        tool_name: str,
+        tool_name: str = "",
         stream_manager: Optional[StreamManager] = None,
         *,
         cfg_file: str = "",
@@ -38,6 +42,13 @@ class MCPTool:
         namespace: str = "stdio",
         default_timeout: Optional[float] = None
     ) -> None:
+        if not tool_name:
+            raise ValueError(
+                "MCPTool requires a tool_name. "
+                "This error usually occurs during subprocess serialization. "
+                "Make sure the tool is properly registered with a name."
+            )
+
         self.tool_name = tool_name
         self._sm: Optional[StreamManager] = stream_manager
         self.default_timeout = default_timeout
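The guard above turns a silent mis-pickling into an immediate, explained failure. A minimal illustration (hypothetical snippet; the module path is assumed from the package layout):

from chuk_tool_processor.mcp.mcp_tool import MCPTool  # import path assumed

try:
    MCPTool()  # tool_name now defaults to "" and is rejected up front
except ValueError as exc:
    print(exc)  # names subprocess serialization as the usual culprit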
@@ -48,7 +59,68 @@ class MCPTool:
         self._server_names = server_names or {}
         self._namespace = namespace
 
-        self._sm_lock = asyncio.Lock()
+        # Create lock only when needed (not during deserialization)
+        self._sm_lock: Optional[asyncio.Lock] = None
+
+    def _ensure_lock(self) -> asyncio.Lock:
+        """Ensure the lock exists, creating it if necessary."""
+        if self._sm_lock is None:
+            self._sm_lock = asyncio.Lock()
+        return self._sm_lock
+
+    # ------------------------------------------------------------------ #
+    # Serialization support for subprocess execution
+    # ------------------------------------------------------------------ #
+    def __getstate__(self) -> Dict[str, Any]:
+        """
+        Custom serialization for pickle support.
+
+        Excludes non-serializable async components and stream manager.
+        The subprocess will recreate these as needed.
+        """
+        state = self.__dict__.copy()
+
+        # Remove non-serializable items
+        state['_sm'] = None       # StreamManager will be recreated in subprocess
+        state['_sm_lock'] = None  # Lock will be recreated when needed
+
+        # Ensure we have the necessary configuration for subprocess
+        # If no servers specified, default to the tool name (common pattern)
+        if not state.get('_servers'):
+            # Extract server name from tool_name (e.g., "get_current_time" -> "time")
+            # This is a heuristic - adjust based on your naming convention
+            if 'time' in self.tool_name.lower():
+                state['_servers'] = ['time']
+                state['_server_names'] = {0: 'time'}
+            else:
+                # Default fallback - use the tool name itself
+                state['_servers'] = [self.tool_name]
+                state['_server_names'] = {0: self.tool_name}
+
+        # Ensure we have a config file path
+        if not state.get('_cfg_file'):
+            state['_cfg_file'] = 'server_config.json'
+
+        logger.debug(f"Serializing MCPTool '{self.tool_name}' for subprocess with servers: {state['_servers']}")
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        """
+        Custom deserialization for pickle support.
+
+        Restores the object state and ensures required fields are set.
+        """
+        self.__dict__.update(state)
+
+        # Ensure critical fields exist
+        if not hasattr(self, 'tool_name') or not self.tool_name:
+            raise ValueError("Invalid MCPTool state: missing tool_name")
+
+        # Initialize transient fields
+        self._sm = None
+        self._sm_lock = None
+
+        logger.debug(f"Deserialized MCPTool '{self.tool_name}' in subprocess")
 
     # ------------------------------------------------------------------ #
     async def _ensure_stream_manager(self) -> StreamManager:
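The __getstate__/__setstate__ pair used here is the standard pickle hook for objects holding unpicklable members (asyncio locks, live managers): null them out when serializing, recreate them lazily in the child process. A self-contained sketch of the same pattern, independent of this package:

import asyncio
import pickle
from typing import Any, Dict, Optional

class LazyLockHolder:
    """Pattern sketch: transient members are dropped on pickle and
    rebuilt on first use, so instances can cross process boundaries."""

    def __init__(self) -> None:
        self._lock: Optional[asyncio.Lock] = None

    def _ensure_lock(self) -> asyncio.Lock:
        # Created on demand, so neither __init__ nor __setstate__
        # needs a running event loop.
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["_lock"] = None  # asyncio.Lock instances are not picklable
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        self._lock = None  # recreated by _ensure_lock() when needed

clone = pickle.loads(pickle.dumps(LazyLockHolder()))  # round-trips cleanly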
@@ -61,7 +133,8 @@ class MCPTool:
         if self._sm is not None:
             return self._sm
 
-        async with self._sm_lock:
+        # Use the lock, creating it if needed
+        async with self._ensure_lock():
             if self._sm is None:  # re-check inside lock
                 logger.info(
                     "Boot-strapping MCP stdio transport for '%s'", self.tool_name
@@ -139,9 +212,32 @@ class MCPTool:
 
         return result.get("content")
 
-
     # ------------------------------------------------------------------ #
     # Legacy method name support
     async def _aexecute(self, timeout: Optional[float] = None, **kwargs: Any) -> Any:
         """Legacy alias for execute() method."""
-        return await self.execute(timeout=timeout, **kwargs)
\ No newline at end of file
+        return await self.execute(timeout=timeout, **kwargs)
+
+    # ------------------------------------------------------------------ #
+    # Utility methods for debugging
+    # ------------------------------------------------------------------ #
+    def is_serializable(self) -> bool:
+        """Check if this tool can be serialized (for debugging)."""
+        try:
+            import pickle
+            pickle.dumps(self)
+            return True
+        except Exception:
+            return False
+
+    def get_serialization_info(self) -> Dict[str, Any]:
+        """Get information about what would be serialized."""
+        state = self.__getstate__()
+        return {
+            "tool_name": state.get("tool_name"),
+            "namespace": state.get("_namespace"),
+            "servers": state.get("_servers"),
+            "cfg_file": state.get("_cfg_file"),
+            "has_stream_manager": state.get("_sm") is not None,
+            "serializable_size": len(str(state))
+        }
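Typical use of the new debugging helpers when a subprocess run misbehaves (a hypothetical session; the tool and config names are placeholders taken from the defaults above):

tool = MCPTool(tool_name="get_current_time", cfg_file="server_config.json")

if not tool.is_serializable():
    # Shows exactly what __getstate__ would ship to the worker process
    print(tool.get_serialization_info())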
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 # chuk_tool_processor/mcp/stream_manager.py
 """
 StreamManager for CHUK Tool Processor.
@@ -36,6 +37,7 @@ class StreamManager:
         self.server_names: Dict[int, str] = {}
         self.all_tools: List[Dict[str, Any]] = []
         self._lock = asyncio.Lock()
+        self._close_tasks: List[asyncio.Task] = []  # Track cleanup tasks
 
     # ------------------------------------------------------------------ #
     # factory helpers                                                    #
@@ -364,22 +366,76 @@ class StreamManager:
         return await transport.call_tool(tool_name, arguments)
 
     # ------------------------------------------------------------------ #
-    # shutdown                                                           #
+    # shutdown - PROPERLY FIXED VERSION                                  #
     # ------------------------------------------------------------------ #
     async def close(self) -> None:
-        tasks = [tr.close() for tr in self.transports.values()]
-        if tasks:
+        """
+        Properly close all transports with graceful handling of cancellation.
+        """
+        if not self.transports:
+            return
+
+        # Cancel any existing close tasks
+        for task in self._close_tasks:
+            if not task.done():
+                task.cancel()
+        self._close_tasks.clear()
+
+        # Create close tasks for all transports
+        close_tasks = []
+        for name, transport in list(self.transports.items()):
             try:
-                await asyncio.gather(*tasks)
-            except asyncio.CancelledError:  # pragma: no cover
-                pass
-            except Exception as exc:  # noqa: BLE001
-                logger.error("Error during close: %s", exc)
-
+                task = asyncio.create_task(
+                    self._close_transport(name, transport),
+                    name=f"close_{name}"
+                )
+                close_tasks.append(task)
+                self._close_tasks.append(task)
+            except Exception as e:
+                logger.debug(f"Error creating close task for {name}: {e}")
+
+        # Wait for all close tasks with a timeout
+        if close_tasks:
+            try:
+                # Give transports a reasonable time to close gracefully
+                await asyncio.wait_for(
+                    asyncio.gather(*close_tasks, return_exceptions=True),
+                    timeout=2.0
+                )
+            except asyncio.TimeoutError:
+                # Cancel any still-running tasks
+                for task in close_tasks:
+                    if not task.done():
+                        task.cancel()
+                # Brief wait for cancellation to take effect
+                await asyncio.gather(*close_tasks, return_exceptions=True)
+            except asyncio.CancelledError:
+                # This is expected during event loop shutdown
+                logger.debug("Close operation cancelled during shutdown")
+            except Exception as e:
+                logger.debug(f"Unexpected error during close: {e}")
+
+        # Clean up state
+        self._cleanup_state()
+
+    async def _close_transport(self, name: str, transport: MCPBaseTransport) -> None:
+        """Close a single transport with error handling."""
+        try:
+            await transport.close()
+            logger.debug(f"Closed transport: {name}")
+        except asyncio.CancelledError:
+            # Re-raise cancellation
+            raise
+        except Exception as e:
+            logger.debug(f"Error closing transport {name}: {e}")
+
+    def _cleanup_state(self) -> None:
+        """Clean up internal state (synchronous)."""
         self.transports.clear()
         self.server_info.clear()
         self.tool_to_server_map.clear()
         self.all_tools.clear()
+        self._close_tasks.clear()
 
     # ------------------------------------------------------------------ #
     # backwards-compat: streams helper                                   #
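The new close() follows a reusable asyncio shutdown shape: run each cleanup in its own task, give the group a bounded grace period, then cancel and reap stragglers so nothing outlives shutdown. A minimal sketch of that pattern (the 2.0 s grace period mirrors the value used above):

import asyncio

async def close_all(coros, timeout: float = 2.0) -> None:
    # One task per cleanup coroutine, so a slow close cannot block the rest
    tasks = [asyncio.create_task(c) for c in coros]
    try:
        await asyncio.wait_for(
            asyncio.gather(*tasks, return_exceptions=True), timeout
        )
    except asyncio.TimeoutError:
        # wait_for already cancelled the gather; cancel stragglers and reap them
        for t in tasks:
            if not t.done():
                t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)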
@@ -10,6 +10,7 @@ This transport:
 5. Handles async responses via SSE message events
 
 FIXED: All hardcoded timeouts are now configurable parameters.
+FIXED: Enhanced close method to avoid cancel scope conflicts.
 """
 from __future__ import annotations
 
@@ -206,24 +207,11 @@ class SSETransport(MCPBaseTransport):
             print(f"⚠️ Failed to send notification: {e}")
 
     async def close(self) -> None:
-        """Close the transport."""
-        # Cancel any pending requests
-        for future in self._pending_requests.values():
-            if not future.done():
-                future.cancel()
-        self._pending_requests.clear()
-
-        if self._sse_task:
-            self._sse_task.cancel()
-            with contextlib.suppress(asyncio.CancelledError):
-                await self._sse_task
-            self._sse_task = None
-
-        if self._client:
-            await self._client.aclose()
-            self._client = None
-        self.session = None
-
+        """Minimal close method with zero async operations."""
+        # Just clear references - no async operations at all
+        self._context_stack = None
+        self.read_stream = None
+        self.write_stream = None
     # ------------------------------------------------------------------ #
     # SSE Connection Handler                                             #
     # ------------------------------------------------------------------ #
@@ -4,6 +4,7 @@ from __future__ import annotations
 from contextlib import AsyncExitStack
 import json
 from typing import Dict, Any, List, Optional
+import asyncio
 
 # ------------------------------------------------------------------ #
 # Local import                                                       #
@@ -71,14 +72,11 @@ class StdioTransport(MCPBaseTransport):
         return False
 
     async def close(self) -> None:
-        if self._context_stack:
-            try:
-                await self._context_stack.__aexit__(None, None, None)
-            except Exception:
-                pass
+        """Minimal close method with zero async operations."""
+        # Just clear references - no async operations at all
+        self._context_stack = None
         self.read_stream = None
         self.write_stream = None
-        self._context_stack = None
 
     # --------------------------------------------------------------------- #
     # Utility                                                               #
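Why "zero async operations"? Likely the "cancel scope conflicts" the SSE module now mentions: the old close() awaited the AsyncExitStack's __aexit__ from whichever task happened to call close(), but anyio cancel scopes (used inside MCP's stdio and SSE client context managers) must be exited by the task that entered them, otherwise they raise RuntimeError. Dropping the references sidesteps the cross-task exit, at the cost of leaving final teardown to process exit. A sketch of the single-task ownership discipline this implies (illustrative, not package code):

import asyncio
from contextlib import AsyncExitStack
from typing import Optional

class TransportOwner:
    """One task enters and exits the AsyncExitStack, so anything
    registered on it (e.g. anyio-based clients) is torn down by the
    task that created it, never from a foreign close() call."""

    def __init__(self) -> None:
        self._stack: Optional[AsyncExitStack] = None

    async def run(self) -> None:
        async with AsyncExitStack() as stack:   # entered in this task...
            self._stack = stack
            await asyncio.sleep(0)              # ...transport work here...
        self._stack = None                      # ...and exited in this task

asyncio.run(TransportOwner().run())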
@@ -194,4 +192,4 @@ class StdioTransport(MCPBaseTransport):
             import logging
 
             logging.error(f"Error calling tool {tool_name}: {e}")
-            return {"isError": True, "error": str(e)}
\ No newline at end of file
+            return {"isError": True, "error": str(e)}