chuk-tool-processor 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; see the registry's advisory for this release for more details.

@@ -5,299 +5,289 @@ StreamManager for CHUK Tool Processor.
5
5
  from __future__ import annotations
6
6
 
7
7
  import asyncio
8
- import json
9
- from typing import Dict, List, Optional, Any
8
+ from typing import Any, Dict, List, Optional, Tuple
10
9
 
11
- # tool processor imports
10
+ # --------------------------------------------------------------------------- #
11
+ # CHUK imports #
12
+ # --------------------------------------------------------------------------- #
12
13
  from chuk_mcp.config import load_config
13
- from chuk_tool_processor.mcp.transport import MCPBaseTransport, StdioTransport, SSETransport
14
+ from chuk_tool_processor.mcp.transport import (
15
+ MCPBaseTransport,
16
+ StdioTransport,
17
+ SSETransport,
18
+ )
14
19
  from chuk_tool_processor.logging import get_logger
15
20
 
16
- # logger
17
21
  logger = get_logger("chuk_tool_processor.mcp.stream_manager")
18
22
 
23
+
class StreamManager:
    """
    Manager for MCP server streams with support for multiple transport types.

    One transport is kept per connected server.  Tool definitions reported by
    every server are aggregated into ``all_tools`` and ``call_tool`` routes a
    request to the server that registered the tool name.
    """

    # ------------------------------------------------------------------ #
    # construction                                                       #
    # ------------------------------------------------------------------ #
    def __init__(self) -> None:
        # server name -> transport whose initialize() succeeded
        self.transports: Dict[str, MCPBaseTransport] = {}
        # one {"id", "name", "tools", "status"} record per registered server
        self.server_info: List[Dict[str, Any]] = []
        # tool name -> owning server name (a later server overwrites an earlier one)
        self.tool_to_server_map: Dict[str, str] = {}
        # caller-supplied index -> display-name mapping (only stored by this class)
        self.server_names: Dict[int, str] = {}
        # every tool definition from every server, in registration order
        self.all_tools: List[Dict[str, Any]] = []
        # serialises the initialise* methods
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------ #
    # factory helpers                                                    #
    # ------------------------------------------------------------------ #
    @classmethod
    async def create(
        cls,
        config_file: str,
        servers: List[str],
        server_names: Optional[Dict[int, str]] = None,
        transport_type: str = "stdio",
    ) -> "StreamManager":
        """
        Build a manager and connect it to *servers* described in *config_file*.

        Args:
            config_file: Path passed to ``load_config`` for each server.
            servers: Names of the servers to connect to.
            server_names: Optional index -> display-name mapping.
            transport_type: ``"stdio"`` or ``"sse"``.

        Returns:
            The initialised :class:`StreamManager`.
        """
        inst = cls()
        await inst.initialize(config_file, servers, server_names, transport_type)
        return inst

    @classmethod
    async def create_with_sse(
        cls,
        servers: List[Dict[str, str]],
        server_names: Optional[Dict[int, str]] = None,
    ) -> "StreamManager":
        """
        Build a manager and connect it to SSE *servers*.

        Args:
            servers: Dicts with ``"name"`` and ``"url"`` (and optional
                ``"api_key"``) keys.
            server_names: Optional index -> display-name mapping.

        Returns:
            The initialised :class:`StreamManager`.
        """
        inst = cls()
        await inst.initialize_with_sse(servers, server_names)
        return inst

    # ------------------------------------------------------------------ #
    # initialisation – stdio / sse                                       #
    # ------------------------------------------------------------------ #
    async def _register_server(
        self, idx: int, name: str, transport: MCPBaseTransport
    ) -> int:
        """
        Record an initialised *transport* under *name* and index its tools.

        The transport is stored first, so a later failure in ping or tool
        listing still leaves it reachable by :meth:`close`.

        Returns:
            The number of tools the server reported.
        """
        self.transports[name] = transport

        status = "Up" if await transport.send_ping() else "Down"
        tools = await transport.get_tools()

        for tool in tools:
            tool_name = tool.get("name")
            if tool_name:
                self.tool_to_server_map[tool_name] = name
        self.all_tools.extend(tools)

        self.server_info.append(
            {"id": idx, "name": name, "tools": len(tools), "status": status}
        )
        return len(tools)

    async def initialize(
        self,
        config_file: str,
        servers: List[str],
        server_names: Optional[Dict[int, str]] = None,
        transport_type: str = "stdio",
    ) -> None:
        """
        Connect to each server in *servers* using *transport_type*.

        Errors are handled per server.  An unsupported transport type, a
        failed ``initialize()`` or any exception while configuring, pinging or
        listing tools is logged, and the loop moves on to the next server.

        Args:
            config_file: Path passed to ``load_config`` (stdio only).
            servers: Names of the servers to connect to.
            server_names: Optional index -> display-name mapping.
            transport_type: ``"stdio"`` or ``"sse"``.
        """
        async with self._lock:
            self.server_names = server_names or {}

            for idx, server_name in enumerate(servers):
                try:
                    if transport_type == "stdio":
                        params = await load_config(config_file, server_name)
                        transport: MCPBaseTransport = StdioTransport(params)
                    elif transport_type == "sse":
                        # NOTE(review): placeholder endpoint – the config file is
                        # not consulted for SSE servers on this code path.
                        transport = SSETransport("http://localhost:8000")
                    else:
                        logger.error("Unsupported transport type: %s", transport_type)
                        continue

                    if not await transport.initialize():
                        logger.error("Failed to init %s", server_name)
                        continue

                    tool_count = await self._register_server(idx, server_name, transport)
                    logger.info("Initialised %s – %d tool(s)", server_name, tool_count)
                except Exception as exc:  # noqa: BLE001
                    logger.error("Error initialising %s: %s", server_name, exc)

            logger.info(
                "StreamManager ready – %d server(s), %d tool(s)",
                len(self.transports),
                len(self.all_tools),
            )

    async def initialize_with_sse(
        self,
        servers: List[Dict[str, str]],
        server_names: Optional[Dict[int, str]] = None,
    ) -> None:
        """
        Connect to SSE servers described by *servers*.

        Entries lacking ``"name"`` or ``"url"`` are logged and skipped.  As in
        :meth:`initialize`, any other failure is logged for that server only.

        Args:
            servers: Dicts with ``"name"``, ``"url"`` and optional ``"api_key"``.
            server_names: Optional index -> display-name mapping.
        """
        async with self._lock:
            self.server_names = server_names or {}

            for idx, cfg in enumerate(servers):
                name, url = cfg.get("name"), cfg.get("url")
                if not (name and url):
                    logger.error("Bad server config: %s", cfg)
                    continue

                try:
                    transport = SSETransport(url, cfg.get("api_key"))
                    if not await transport.initialize():
                        logger.error("Failed to init SSE %s", name)
                        continue

                    tool_count = await self._register_server(idx, name, transport)
                    logger.info("Initialised SSE %s – %d tool(s)", name, tool_count)
                except Exception as exc:  # noqa: BLE001
                    logger.error("Error initialising SSE %s: %s", name, exc)

            logger.info(
                "StreamManager ready – %d SSE server(s), %d tool(s)",
                len(self.transports),
                len(self.all_tools),
            )

    # ------------------------------------------------------------------ #
    # queries                                                            #
    # ------------------------------------------------------------------ #
    def get_all_tools(self) -> List[Dict[str, Any]]:
        """Return the aggregated tool definitions (the live internal list)."""
        return self.all_tools

    def get_server_for_tool(self, tool_name: str) -> Optional[str]:
        """Return the server registered for *tool_name*, or ``None``."""
        return self.tool_to_server_map.get(tool_name)

    def get_server_info(self) -> List[Dict[str, Any]]:
        """Return the per-server status records (the live internal list)."""
        return self.server_info

    # ------------------------------------------------------------------ #
    # EXTRA HELPERS – ping / resources / prompts                         #
    # ------------------------------------------------------------------ #
    async def ping_servers(self) -> List[Dict[str, Any]]:
        """
        Ping every connected server concurrently.

        Returns:
            One ``{"server": name, "ok": bool}`` dict per transport, in
            registration order.  A ping that raises is reported as ``ok=False``.
        """

        async def _ping_one(name: str, tr: MCPBaseTransport) -> Dict[str, Any]:
            try:
                ok = await tr.send_ping()
            except Exception:  # pragma: no cover
                ok = False
            return {"server": name, "ok": ok}

        return list(
            await asyncio.gather(*(_ping_one(n, t) for n, t in self.transports.items()))
        )

    async def _collect_listing(
        self, method: str, key: str, label: str
    ) -> List[Dict[str, Any]]:
        """
        Call *method* on every transport concurrently and merge the items.

        A transport may answer with ``{key: [...]}`` or with a bare list.
        Every item is copied and tagged with a ``"server"`` entry.  Transports
        without *method*, or whose call fails, contribute nothing; failures are
        logged at debug level as ``"<label> failed for ..."``.  The merged
        list follows transport registration order, not completion order.
        """

        async def _one(name: str, tr: MCPBaseTransport) -> List[Dict[str, Any]]:
            if not hasattr(tr, method):
                return []
            try:
                res = await getattr(tr, method)()
                raw = res.get(key, []) if isinstance(res, dict) else res
                tagged: List[Dict[str, Any]] = []
                for item in raw or []:  # tolerate a None catalogue
                    entry = dict(item)
                    entry["server"] = name
                    tagged.append(entry)
                return tagged
            except Exception as exc:  # noqa: BLE001
                logger.debug("%s failed for %s: %s", label, name, exc)
                return []

        per_server = await asyncio.gather(
            *(_one(n, t) for n, t in self.transports.items())
        )
        return [item for items in per_server for item in items]

    async def list_resources(self) -> List[Dict[str, Any]]:
        """Return every server's resources, each tagged with ``"server"``."""
        return await self._collect_listing("list_resources", "resources", "resources/list")

    async def list_prompts(self) -> List[Dict[str, Any]]:
        """Return every server's prompts, each tagged with ``"server"``."""
        return await self._collect_listing("list_prompts", "prompts", "prompts/list")

    # ------------------------------------------------------------------ #
    # tool execution                                                     #
    # ------------------------------------------------------------------ #
    async def call_tool(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        server_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Execute *tool_name* on its server (or on *server_name* if given).

        Returns:
            The transport's result.  If no connected server can be resolved, it
            returns an ``{"isError": True, "error": ...}`` dict instead of raising.
        """
        server_name = server_name or self.get_server_for_tool(tool_name)
        if not server_name or server_name not in self.transports:
            # wording kept exactly for unit-test expectation
            return {
                "isError": True,
                "error": f"No server found for tool: {tool_name}",
            }
        return await self.transports[server_name].call_tool(tool_name, arguments)

    # ------------------------------------------------------------------ #
    # shutdown                                                           #
    # ------------------------------------------------------------------ #
    async def close(self) -> None:
        """
        Close every transport, then forget all servers and tools.

        All closes are awaited even when some of them fail; each failure is
        logged with its server name.  Cancellation of this call is swallowed,
        as before.
        """
        names = list(self.transports)
        tasks = [tr.close() for tr in self.transports.values()]
        if tasks:
            try:
                # return_exceptions=True: one failing close must not leave the
                # others un-awaited with unretrieved exceptions.
                results = await asyncio.gather(*tasks, return_exceptions=True)
            except asyncio.CancelledError:  # pragma: no cover
                pass
            else:
                for name, result in zip(names, results):
                    if isinstance(result, Exception):
                        logger.error("Error during close of %s: %s", name, result)

        self.transports.clear()
        self.server_info.clear()
        self.tool_to_server_map.clear()
        self.all_tools.clear()

    # ------------------------------------------------------------------ #
    # backwards-compat: streams helper                                   #
    # ------------------------------------------------------------------ #
    def get_streams(self) -> List[Tuple[Any, Any]]:
        """
        Return a list of ``(read_stream, write_stream)`` tuples for **all**
        transports.  Older CLI commands rely on this helper.

        A transport's own ``get_streams()`` is used when it yields pairs.  When
        it is absent or returns nothing (the base-class default is ``[]``), the
        transport's ``read_stream``/``write_stream`` attributes are used if
        both are set.
        """
        pairs: List[Tuple[Any, Any]] = []

        for tr in self.transports.values():
            getter = getattr(tr, "get_streams", None)
            if callable(getter):
                provided = getter()
                if provided:
                    pairs.extend(provided)
                    continue

            rd = getattr(tr, "read_stream", None)
            wr = getattr(tr, "write_stream", None)
            if rd and wr:
                pairs.append((rd, wr))

        return pairs

    # convenience alias
    @property
    def streams(self) -> List[Tuple[Any, Any]]:  # pragma: no cover
        return self.get_streams()
@@ -2,63 +2,102 @@
2
2
  """
3
3
  Abstract transport layer for MCP communication.
4
4
  """
5
+ from __future__ import annotations
6
+
5
7
  from abc import ABC, abstractmethod
6
8
  from typing import Any, Dict, List
7
9
 
10
+
class MCPBaseTransport(ABC):
    """
    Abstract base class for MCP transport mechanisms.

    Subclasses must implement the connection lifecycle (``initialize``,
    ``close``), ``send_ping`` and tool handling (``get_tools``, ``call_tool``).
    ``list_resources``, ``list_prompts`` and ``get_streams`` are optional.
    Their defaults report an empty catalogue or no streams, so transports
    written before these methods existed can still be instantiated.
    """

    # ------------------------------------------------------------------ #
    # connection lifecycle                                               #
    # ------------------------------------------------------------------ #
    @abstractmethod
    async def initialize(self) -> bool:
        """
        Establish the connection.

        Returns
        -------
        bool
            ``True`` if the connection was initialised successfully.
        """
        raise NotImplementedError

    @abstractmethod
    async def close(self) -> None:
        """Tear down the connection and release all resources."""
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # diagnostics                                                        #
    # ------------------------------------------------------------------ #
    @abstractmethod
    async def send_ping(self) -> bool:
        """
        Send a **ping** request.

        Returns
        -------
        bool
            ``True`` on success, ``False`` otherwise.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # tool handling                                                      #
    # ------------------------------------------------------------------ #
    @abstractmethod
    async def get_tools(self) -> List[Dict[str, Any]]:
        """
        Return a list with *all* tool definitions exposed by the server.
        """
        raise NotImplementedError

    @abstractmethod
    async def call_tool(
        self, tool_name: str, arguments: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Execute *tool_name* with *arguments* and return the normalised result.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # resources & prompts (optional – safe defaults)                     #
    # ------------------------------------------------------------------ #
    async def list_resources(self) -> Dict[str, Any]:
        """
        Retrieve the server’s resources catalogue.

        Expected shape::

            { "resources": [ {...}, ... ], "nextCursor": "…", … }

        The default reports an empty catalogue.  Override it in transports
        that can query the server.  It is deliberately not abstract, so that
        existing subclasses without it still instantiate.
        """
        return {"resources": []}

    async def list_prompts(self) -> Dict[str, Any]:
        """
        Retrieve the server’s prompt catalogue.

        Expected shape::

            { "prompts": [ {...}, ... ], "nextCursor": "…", … }

        The default reports an empty catalogue.  Override it in transports
        that can query the server.  It is deliberately not abstract, so that
        existing subclasses without it still instantiate.
        """
        return {"prompts": []}

    # ------------------------------------------------------------------ #
    # optional helper (non-abstract)                                     #
    # ------------------------------------------------------------------ #
    def get_streams(self) -> List[tuple]:
        """
        Return a list of ``(read_stream, write_stream)`` tuples.

        Transports that do not expose their low-level streams can simply leave
        the default implementation (which returns an empty list).
        """
        return []