chuk-tool-processor 0.1.7__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; see the registry's advisory page for this release for more details.

@@ -1,11 +1,13 @@
1
1
  #!/usr/bin/env python
2
2
  # chuk_tool_processor/execution/strategies/inprocess_strategy.py
3
3
  """
4
- In-process execution strategy for tools with true streaming support.
4
+ In-process execution strategy for tools with proper timeout handling.
5
5
 
6
6
  This strategy executes tools concurrently in the same process using asyncio.
7
7
  It has special support for streaming tools, accessing their stream_execute method
8
8
  directly to enable true item-by-item streaming.
9
+
10
+ FIXED: Ensures consistent timeout handling across all execution paths.
9
11
  """
10
12
  from __future__ import annotations
11
13
 
@@ -36,7 +38,7 @@ async def _noop_cm():
36
38
 
37
39
  # --------------------------------------------------------------------------- #
38
40
  class InProcessStrategy(ExecutionStrategy):
39
- """Execute tools in the local event-loop with optional concurrency cap."""
41
+ """Execute tools in the local event-loop with optional concurrency cap and consistent timeout handling."""
40
42
 
41
43
  def __init__(
42
44
  self,
@@ -53,7 +55,7 @@ class InProcessStrategy(ExecutionStrategy):
53
55
  max_concurrency: Maximum number of concurrent executions
54
56
  """
55
57
  self.registry = registry
56
- self.default_timeout = default_timeout
58
+ self.default_timeout = default_timeout or 30.0 # Always have a default
57
59
  self._sem = asyncio.Semaphore(max_concurrency) if max_concurrency else None
58
60
 
59
61
  # Task tracking for cleanup
@@ -64,6 +66,9 @@ class InProcessStrategy(ExecutionStrategy):
64
66
  # Tracking for which calls are being handled directly by the executor
65
67
  # to prevent duplicate streaming results
66
68
  self._direct_streaming_calls = set()
69
+
70
+ logger.debug("InProcessStrategy initialized with timeout: %ss, max_concurrency: %s",
71
+ self.default_timeout, max_concurrency)
67
72
 
68
73
  # ------------------------------------------------------------------ #
69
74
  def mark_direct_streaming(self, call_ids: Set[str]) -> None:
@@ -116,11 +121,15 @@ class InProcessStrategy(ExecutionStrategy):
116
121
  """
117
122
  if not calls:
118
123
  return []
124
+
125
+ # Use default_timeout if no timeout specified
126
+ effective_timeout = timeout if timeout is not None else self.default_timeout
127
+ logger.debug("Executing %d calls with %ss timeout each", len(calls), effective_timeout)
119
128
 
120
129
  tasks = []
121
130
  for call in calls:
122
131
  task = asyncio.create_task(
123
- self._execute_single_call(call, timeout or self.default_timeout)
132
+ self._execute_single_call(call, effective_timeout) # Always pass timeout
124
133
  )
125
134
  self._active_tasks.add(task)
126
135
  task.add_done_callback(self._active_tasks.discard)
@@ -142,10 +151,13 @@ class InProcessStrategy(ExecutionStrategy):
142
151
  if not calls:
143
152
  return
144
153
 
154
+ # Use default_timeout if no timeout specified
155
+ effective_timeout = timeout if timeout is not None else self.default_timeout
156
+
145
157
  queue: asyncio.Queue[ToolResult] = asyncio.Queue()
146
158
  tasks = {
147
159
  asyncio.create_task(
148
- self._stream_tool_call(call, queue, timeout or self.default_timeout)
160
+ self._stream_tool_call(call, queue, effective_timeout) # Always pass timeout
149
161
  )
150
162
  for call in calls
151
163
  if call.id not in self._direct_streaming_calls
@@ -170,7 +182,7 @@ class InProcessStrategy(ExecutionStrategy):
170
182
  self,
171
183
  call: ToolCall,
172
184
  queue: asyncio.Queue,
173
- timeout: Optional[float],
185
+ timeout: float, # Make timeout required
174
186
  ) -> None:
175
187
  """
176
188
  Execute a tool call with streaming support.
@@ -181,7 +193,7 @@ class InProcessStrategy(ExecutionStrategy):
181
193
  Args:
182
194
  call: The tool call to execute
183
195
  queue: Queue to put results into
184
- timeout: Optional timeout in seconds
196
+ timeout: Timeout in seconds (required)
185
197
  """
186
198
  # Skip if call is being handled directly by the executor
187
199
  if call.id in self._direct_streaming_calls:
@@ -269,7 +281,7 @@ class InProcessStrategy(ExecutionStrategy):
269
281
  tool: Any,
270
282
  call: ToolCall,
271
283
  queue: asyncio.Queue,
272
- timeout: Optional[float]
284
+ timeout: float, # Make timeout required
273
285
  ) -> None:
274
286
  """
275
287
  Stream results from a streaming tool with timeout support.
@@ -281,12 +293,14 @@ class InProcessStrategy(ExecutionStrategy):
281
293
  tool: The tool instance
282
294
  call: Tool call data
283
295
  queue: Queue to put results into
284
- timeout: Optional timeout in seconds
296
+ timeout: Timeout in seconds (required)
285
297
  """
286
298
  start_time = datetime.now(timezone.utc)
287
299
  machine = os.uname().nodename
288
300
  pid = os.getpid()
289
301
 
302
+ logger.debug("Streaming %s with %ss timeout", call.tool, timeout)
303
+
290
304
  # Define the streaming task
291
305
  async def streamer():
292
306
  try:
@@ -318,15 +332,17 @@ class InProcessStrategy(ExecutionStrategy):
318
332
  await queue.put(error_result)
319
333
 
320
334
  try:
321
- # Execute with timeout if specified
322
- if timeout:
323
- await asyncio.wait_for(streamer(), timeout)
324
- else:
325
- await streamer()
335
+ # Always execute with timeout
336
+ await asyncio.wait_for(streamer(), timeout)
337
+ logger.debug("%s streaming completed within %ss", call.tool, timeout)
326
338
 
327
339
  except asyncio.TimeoutError:
328
340
  # Handle timeout
329
341
  now = datetime.now(timezone.utc)
342
+ actual_duration = (now - start_time).total_seconds()
343
+ logger.debug("%s streaming timed out after %.3fs (limit: %ss)",
344
+ call.tool, actual_duration, timeout)
345
+
330
346
  timeout_result = ToolResult(
331
347
  tool=call.tool,
332
348
  result=None,
@@ -341,6 +357,8 @@ class InProcessStrategy(ExecutionStrategy):
341
357
  except Exception as e:
342
358
  # Handle other errors
343
359
  now = datetime.now(timezone.utc)
360
+ logger.debug("%s streaming failed: %s", call.tool, e)
361
+
344
362
  error_result = ToolResult(
345
363
  tool=call.tool,
346
364
  result=None,
@@ -356,7 +374,7 @@ class InProcessStrategy(ExecutionStrategy):
356
374
  self,
357
375
  call: ToolCall,
358
376
  queue: asyncio.Queue,
359
- timeout: Optional[float],
377
+ timeout: float, # Make timeout required
360
378
  ) -> None:
361
379
  """Execute a single call and put the result in the queue."""
362
380
  # Skip if call is being handled directly by the executor
@@ -370,17 +388,17 @@ class InProcessStrategy(ExecutionStrategy):
370
388
  async def _execute_single_call(
371
389
  self,
372
390
  call: ToolCall,
373
- timeout: Optional[float],
391
+ timeout: float, # Make timeout required, not optional
374
392
  ) -> ToolResult:
375
393
  """
376
- Execute a single tool call.
394
+ Execute a single tool call with guaranteed timeout.
377
395
 
378
396
  The entire invocation – including argument validation – is wrapped
379
397
  by the semaphore to honour *max_concurrency*.
380
398
 
381
399
  Args:
382
400
  call: Tool call to execute
383
- timeout: Optional timeout in seconds
401
+ timeout: Timeout in seconds (required)
384
402
 
385
403
  Returns:
386
404
  Tool execution result
@@ -389,6 +407,8 @@ class InProcessStrategy(ExecutionStrategy):
389
407
  machine = os.uname().nodename
390
408
  start = datetime.now(timezone.utc)
391
409
 
410
+ logger.debug("Executing %s with %ss timeout", call.tool, timeout)
411
+
392
412
  # Early exit if shutting down
393
413
  if self._shutting_down:
394
414
  return ToolResult(
@@ -464,19 +484,18 @@ class InProcessStrategy(ExecutionStrategy):
464
484
  self,
465
485
  tool: Any,
466
486
  call: ToolCall,
467
- timeout: float | None,
487
+ timeout: float, # Make timeout required, not optional
468
488
  start: datetime,
469
489
  machine: str,
470
490
  pid: int,
471
491
  ) -> ToolResult:
472
492
  """
473
- Resolve the correct async entry-point and invoke it with an optional
474
- timeout.
493
+ Resolve the correct async entry-point and invoke it with a guaranteed timeout.
475
494
 
476
495
  Args:
477
496
  tool: Tool instance
478
497
  call: Tool call data
479
- timeout: Optional timeout in seconds
498
+ timeout: Timeout in seconds (required)
480
499
  start: Start time for the execution
481
500
  machine: Machine name
482
501
  pid: Process ID
@@ -507,62 +526,46 @@ class InProcessStrategy(ExecutionStrategy):
507
526
  )
508
527
 
509
528
  try:
510
- if timeout:
511
- # Use a task with explicit cancellation
512
- task = asyncio.create_task(fn(**call.arguments))
529
+ # Always apply timeout
530
+ logger.debug("Applying %ss timeout to %s", timeout, call.tool)
531
+
532
+ try:
533
+ result_val = await asyncio.wait_for(fn(**call.arguments), timeout=timeout)
534
+
535
+ end_time = datetime.now(timezone.utc)
536
+ actual_duration = (end_time - start).total_seconds()
537
+ logger.debug("%s completed in %.3fs (limit: %ss)",
538
+ call.tool, actual_duration, timeout)
513
539
 
514
- try:
515
- # Wait for the task with timeout
516
- result_val = await asyncio.wait_for(task, timeout)
517
-
518
- return ToolResult(
519
- tool=call.tool,
520
- result=result_val,
521
- error=None,
522
- start_time=start,
523
- end_time=datetime.now(timezone.utc),
524
- machine=machine,
525
- pid=pid,
526
- )
527
- except asyncio.TimeoutError:
528
- # Cancel the task if it times out
529
- if not task.done():
530
- task.cancel()
531
-
532
- # Wait for cancellation to complete
533
- try:
534
- await task
535
- except asyncio.CancelledError:
536
- # Expected - we just cancelled it
537
- pass
538
- except Exception:
539
- # Ignore any other exceptions during cancellation
540
- pass
541
-
542
- # Return a timeout error
543
- return ToolResult(
544
- tool=call.tool,
545
- result=None,
546
- error=f"Timeout after {timeout}s",
547
- start_time=start,
548
- end_time=datetime.now(timezone.utc),
549
- machine=machine,
550
- pid=pid,
551
- )
552
- else:
553
- # No timeout
554
- result_val = await fn(**call.arguments)
555
540
  return ToolResult(
556
541
  tool=call.tool,
557
542
  result=result_val,
558
543
  error=None,
559
544
  start_time=start,
560
- end_time=datetime.now(timezone.utc),
545
+ end_time=end_time,
561
546
  machine=machine,
562
547
  pid=pid,
563
548
  )
549
+ except asyncio.TimeoutError:
550
+ # Handle timeout
551
+ end_time = datetime.now(timezone.utc)
552
+ actual_duration = (end_time - start).total_seconds()
553
+ logger.debug("%s timed out after %.3fs (limit: %ss)",
554
+ call.tool, actual_duration, timeout)
555
+
556
+ return ToolResult(
557
+ tool=call.tool,
558
+ result=None,
559
+ error=f"Timeout after {timeout}s",
560
+ start_time=start,
561
+ end_time=end_time,
562
+ machine=machine,
563
+ pid=pid,
564
+ )
565
+
564
566
  except asyncio.CancelledError:
565
567
  # Handle cancellation explicitly
568
+ logger.debug("%s was cancelled", call.tool)
566
569
  return ToolResult(
567
570
  tool=call.tool,
568
571
  result=None,
@@ -574,12 +577,16 @@ class InProcessStrategy(ExecutionStrategy):
574
577
  )
575
578
  except Exception as exc:
576
579
  logger.exception("Error executing %s: %s", call.tool, exc)
580
+ end_time = datetime.now(timezone.utc)
581
+ actual_duration = (end_time - start).total_seconds()
582
+ logger.debug("%s failed after %.3fs: %s", call.tool, actual_duration, exc)
583
+
577
584
  return ToolResult(
578
585
  tool=call.tool,
579
586
  result=None,
580
587
  error=str(exc),
581
588
  start_time=start,
582
- end_time=datetime.now(timezone.utc),
589
+ end_time=end_time,
583
590
  machine=machine,
584
591
  pid=pid,
585
592
  )
@@ -4,6 +4,8 @@ Subprocess execution strategy - truly runs tools in separate OS processes.
4
4
 
5
5
  This strategy executes tools in separate Python processes using a process pool,
6
6
  providing isolation and potentially better parallelism on multi-core systems.
7
+
8
+ FIXED: Ensures consistent timeout handling across all execution paths.
7
9
  """
8
10
  from __future__ import annotations
9
11
 
@@ -133,7 +135,7 @@ def _process_worker(
133
135
 
134
136
  try:
135
137
  # Execute the tool with timeout
136
- if timeout:
138
+ if timeout is not None and timeout > 0:
137
139
  result_value = loop.run_until_complete(
138
140
  asyncio.wait_for(execute_fn(**arguments), timeout)
139
141
  )
@@ -192,7 +194,7 @@ class SubprocessStrategy(ExecutionStrategy):
192
194
  """
193
195
  self.registry = registry
194
196
  self.max_workers = max_workers
195
- self.default_timeout = default_timeout
197
+ self.default_timeout = default_timeout or 30.0 # Always have a default
196
198
  self.worker_init_timeout = worker_init_timeout
197
199
 
198
200
  # Process pool (initialized lazily)
@@ -204,6 +206,9 @@ class SubprocessStrategy(ExecutionStrategy):
204
206
  self._shutdown_event = asyncio.Event()
205
207
  self._shutting_down = False
206
208
 
209
+ logger.debug("SubprocessStrategy initialized with timeout: %ss, max_workers: %d",
210
+ self.default_timeout, max_workers)
211
+
207
212
  # Register shutdown handler if in main thread
208
213
  try:
209
214
  loop = asyncio.get_running_loop()
@@ -238,12 +243,12 @@ class SubprocessStrategy(ExecutionStrategy):
238
243
  loop.run_in_executor(self._process_pool, _pool_test_func),
239
244
  timeout=self.worker_init_timeout
240
245
  )
241
- logger.info(f"Process pool initialized with {self.max_workers} workers")
246
+ logger.info("Process pool initialized with %d workers", self.max_workers)
242
247
  except Exception as e:
243
248
  # Clean up on initialization error
244
249
  self._process_pool.shutdown(wait=False)
245
250
  self._process_pool = None
246
- logger.error(f"Failed to initialize process pool: {e}")
251
+ logger.error("Failed to initialize process pool: %s", e)
247
252
  raise RuntimeError(f"Failed to initialize process pool: {e}") from e
248
253
 
249
254
  # ------------------------------------------------------------------ #
@@ -296,12 +301,16 @@ class SubprocessStrategy(ExecutionStrategy):
296
301
  )
297
302
  for call in calls
298
303
  ]
304
+
305
+ # Use default_timeout if no timeout specified
306
+ effective_timeout = timeout if timeout is not None else self.default_timeout
307
+ logger.debug("Executing %d calls in subprocesses with %ss timeout each", len(calls), effective_timeout)
299
308
 
300
309
  # Create tasks for each call
301
310
  tasks = []
302
311
  for call in calls:
303
312
  task = asyncio.create_task(self._execute_single_call(
304
- call, timeout or self.default_timeout
313
+ call, effective_timeout # Always pass concrete timeout
305
314
  ))
306
315
  self._active_tasks.add(task)
307
316
  task.add_done_callback(self._active_tasks.discard)
@@ -342,6 +351,9 @@ class SubprocessStrategy(ExecutionStrategy):
342
351
  pid=os.getpid(),
343
352
  )
344
353
  return
354
+
355
+ # Use default_timeout if no timeout specified
356
+ effective_timeout = timeout if timeout is not None else self.default_timeout
345
357
 
346
358
  # Create a queue for results
347
359
  queue = asyncio.Queue()
@@ -350,7 +362,7 @@ class SubprocessStrategy(ExecutionStrategy):
350
362
  pending = set()
351
363
  for call in calls:
352
364
  task = asyncio.create_task(self._execute_to_queue(
353
- call, queue, timeout or self.default_timeout
365
+ call, queue, effective_timeout # Always pass concrete timeout
354
366
  ))
355
367
  self._active_tasks.add(task)
356
368
  task.add_done_callback(self._active_tasks.discard)
@@ -372,13 +384,13 @@ class SubprocessStrategy(ExecutionStrategy):
372
384
  try:
373
385
  await task
374
386
  except Exception as e:
375
- logger.exception(f"Error in task: {e}")
387
+ logger.exception("Error in task: %s", e)
376
388
 
377
389
  async def _execute_to_queue(
378
390
  self,
379
391
  call: ToolCall,
380
392
  queue: asyncio.Queue,
381
- timeout: Optional[float],
393
+ timeout: float, # Make timeout required
382
394
  ) -> None:
383
395
  """Execute a single call and put the result in the queue."""
384
396
  result = await self._execute_single_call(call, timeout)
@@ -387,20 +399,22 @@ class SubprocessStrategy(ExecutionStrategy):
387
399
  async def _execute_single_call(
388
400
  self,
389
401
  call: ToolCall,
390
- timeout: Optional[float],
402
+ timeout: float, # Make timeout required
391
403
  ) -> ToolResult:
392
404
  """
393
405
  Execute a single tool call in a separate process.
394
406
 
395
407
  Args:
396
408
  call: Tool call to execute
397
- timeout: Optional timeout in seconds
409
+ timeout: Timeout in seconds (required)
398
410
 
399
411
  Returns:
400
412
  Tool execution result
401
413
  """
402
414
  start_time = datetime.now(timezone.utc)
403
415
 
416
+ logger.debug("Executing %s in subprocess with %ss timeout", call.tool, timeout)
417
+
404
418
  try:
405
419
  # Ensure pool is initialized
406
420
  await self._ensure_pool()
@@ -429,8 +443,8 @@ class SubprocessStrategy(ExecutionStrategy):
429
443
  # Execute in subprocess
430
444
  loop = asyncio.get_running_loop()
431
445
 
432
- # We need to add safety timeout here to handle process crashes
433
- safety_timeout = (timeout or self.default_timeout or 60.0) + 5.0
446
+ # Add safety timeout to handle process crashes (tool timeout + buffer)
447
+ safety_timeout = timeout + 5.0
434
448
 
435
449
  try:
436
450
  result_data = await asyncio.wait_for(
@@ -443,7 +457,7 @@ class SubprocessStrategy(ExecutionStrategy):
443
457
  module_name,
444
458
  class_name,
445
459
  call.arguments,
446
- timeout
460
+ timeout # Pass the actual timeout to worker
447
461
  )
448
462
  ),
449
463
  timeout=safety_timeout
@@ -458,25 +472,40 @@ class SubprocessStrategy(ExecutionStrategy):
458
472
  end_time_str = result_data["end_time"]
459
473
  result_data["end_time"] = datetime.fromisoformat(end_time_str)
460
474
 
475
+ end_time = datetime.now(timezone.utc)
476
+ actual_duration = (end_time - start_time).total_seconds()
477
+
478
+ if result_data.get("error"):
479
+ logger.debug("%s subprocess failed after %.3fs: %s",
480
+ call.tool, actual_duration, result_data["error"])
481
+ else:
482
+ logger.debug("%s subprocess completed in %.3fs (limit: %ss)",
483
+ call.tool, actual_duration, timeout)
484
+
461
485
  # Create ToolResult from worker data
462
486
  return ToolResult(
463
487
  tool=result_data.get("tool", call.tool),
464
488
  result=result_data.get("result"),
465
489
  error=result_data.get("error"),
466
490
  start_time=result_data.get("start_time", start_time),
467
- end_time=result_data.get("end_time", datetime.now(timezone.utc)),
491
+ end_time=result_data.get("end_time", end_time),
468
492
  machine=result_data.get("machine", os.uname().nodename),
469
493
  pid=result_data.get("pid", os.getpid()),
470
494
  )
471
495
 
472
496
  except asyncio.TimeoutError:
473
497
  # This happens if the worker process itself hangs
498
+ end_time = datetime.now(timezone.utc)
499
+ actual_duration = (end_time - start_time).total_seconds()
500
+ logger.debug("%s subprocess timed out after %.3fs (safety limit: %ss)",
501
+ call.tool, actual_duration, safety_timeout)
502
+
474
503
  return ToolResult(
475
504
  tool=call.tool,
476
505
  result=None,
477
506
  error=f"Worker process timed out after {safety_timeout}s",
478
507
  start_time=start_time,
479
- end_time=datetime.now(timezone.utc),
508
+ end_time=end_time,
480
509
  machine=os.uname().nodename,
481
510
  pid=os.getpid(),
482
511
  )
@@ -500,6 +529,7 @@ class SubprocessStrategy(ExecutionStrategy):
500
529
 
501
530
  except asyncio.CancelledError:
502
531
  # Handle cancellation
532
+ logger.debug("%s subprocess was cancelled", call.tool)
503
533
  return ToolResult(
504
534
  tool=call.tool,
505
535
  result=None,
@@ -512,13 +542,18 @@ class SubprocessStrategy(ExecutionStrategy):
512
542
 
513
543
  except Exception as e:
514
544
  # Handle any other errors
515
- logger.exception(f"Error executing {call.tool} in subprocess: {e}")
545
+ logger.exception("Error executing %s in subprocess: %s", call.tool, e)
546
+ end_time = datetime.now(timezone.utc)
547
+ actual_duration = (end_time - start_time).total_seconds()
548
+ logger.debug("%s subprocess setup failed after %.3fs: %s",
549
+ call.tool, actual_duration, e)
550
+
516
551
  return ToolResult(
517
552
  tool=call.tool,
518
553
  result=None,
519
554
  error=f"Error: {str(e)}",
520
555
  start_time=start_time,
521
- end_time=datetime.now(timezone.utc),
556
+ end_time=end_time,
522
557
  machine=os.uname().nodename,
523
558
  pid=os.getpid(),
524
559
  )
@@ -531,7 +566,7 @@ class SubprocessStrategy(ExecutionStrategy):
531
566
  async def _signal_handler(self, sig: int) -> None:
532
567
  """Handle termination signals."""
533
568
  signame = signal.Signals(sig).name
534
- logger.info(f"Received {signame}, shutting down process pool")
569
+ logger.info("Received %s, shutting down process pool", signame)
535
570
  await self.shutdown()
536
571
 
537
572
  async def shutdown(self) -> None:
@@ -549,7 +584,7 @@ class SubprocessStrategy(ExecutionStrategy):
549
584
  # Cancel all active tasks
550
585
  active_tasks = list(self._active_tasks)
551
586
  if active_tasks:
552
- logger.info(f"Cancelling {len(active_tasks)} active tool executions")
587
+ logger.info("Cancelling %d active tool executions", len(active_tasks))
553
588
  for task in active_tasks:
554
589
  task.cancel()
555
590
 
@@ -1,6 +1,13 @@
1
1
  # chuk_tool_processor/mcp/transport/sse_transport.py
2
2
  """
3
- Server-Sent Events (SSE) transport for MCP communication implemented with **httpx**.
3
+ Proper MCP SSE transport that follows the standard MCP SSE protocol.
4
+
5
+ This transport:
6
+ 1. Connects to /sse for SSE stream
7
+ 2. Listens for 'endpoint' event to get message URL
8
+ 3. Sends MCP initialize handshake FIRST
9
+ 4. Only then proceeds with tools/list and tool calls
10
+ 5. Handles async responses via SSE message events
4
11
  """
5
12
  from __future__ import annotations
6
13
 
@@ -16,7 +23,7 @@ from .base_transport import MCPBaseTransport
16
23
  # --------------------------------------------------------------------------- #
17
24
  # Helpers #
18
25
  # --------------------------------------------------------------------------- #
19
- DEFAULT_TIMEOUT = 5.0 # seconds
26
+ DEFAULT_TIMEOUT = 30.0 # Longer timeout for real servers
20
27
  HEADERS_JSON: Dict[str, str] = {"accept": "application/json"}
21
28
 
22
29
 
@@ -30,160 +37,399 @@ def _url(base: str, path: str) -> str:
30
37
  # --------------------------------------------------------------------------- #
31
38
  class SSETransport(MCPBaseTransport):
32
39
  """
33
- Minimal SSE/REST transport. It speaks a simple REST dialect:
34
-
35
- GET /ping 200 OK
36
- GET /tools/list {"tools": [...]}
37
- POST /tools/call → {"name": ..., "result": ...}
38
- GET /resources/list {"resources": [...]}
39
- GET /prompts/list → {"prompts": [...]}
40
- GET /events → <text/event-stream>
40
+ Proper MCP SSE transport that follows the standard protocol:
41
+
42
+ 1. GET /sse Establishes SSE connection
43
+ 2. Waits for 'endpoint' event Gets message URL
44
+ 3. Sends MCP initialize handshake → Establishes session
45
+ 4. POST to message URL Sends tool calls
46
+ 5. Waits for async responses via SSE message events
41
47
  """
42
48
 
43
- EVENTS_PATH = "/events"
44
-
45
- # ------------------------------------------------------------------ #
46
- # Construction #
47
- # ------------------------------------------------------------------ #
48
49
  def __init__(self, url: str, api_key: Optional[str] = None) -> None:
49
50
  self.base_url = url.rstrip("/")
50
51
  self.api_key = api_key
51
52
 
52
53
  # httpx client (None until initialise)
53
54
  self._client: httpx.AsyncClient | None = None
54
- self.session: httpx.AsyncClient | None = None # ← kept for legacy tests
55
+ self.session: httpx.AsyncClient | None = None
55
56
 
56
- # background reader
57
- self._reader_task: asyncio.Task | None = None
58
- self._incoming_queue: "asyncio.Queue[dict[str, Any]]" = asyncio.Queue()
57
+ # MCP SSE state
58
+ self._message_url: Optional[str] = None
59
+ self._session_id: Optional[str] = None
60
+ self._sse_task: Optional[asyncio.Task] = None
61
+ self._connected = asyncio.Event()
62
+ self._initialized = asyncio.Event() # NEW: Track MCP initialization
63
+
64
+ # Async message handling
65
+ self._pending_requests: Dict[str, asyncio.Future] = {}
66
+ self._message_lock = asyncio.Lock()
59
67
 
60
68
  # ------------------------------------------------------------------ #
61
69
  # Life-cycle #
62
70
  # ------------------------------------------------------------------ #
63
71
  async def initialize(self) -> bool:
64
- """Open the httpx client and start the /events consumer."""
65
- if self._client: # already initialised
72
+ """Initialize the MCP SSE transport."""
73
+ if self._client:
66
74
  return True
67
75
 
76
+ headers = {}
77
+ if self.api_key:
78
+ headers["authorization"] = self.api_key
79
+
68
80
  self._client = httpx.AsyncClient(
69
- headers={"authorization": self.api_key} if self.api_key else None,
81
+ headers=headers,
70
82
  timeout=DEFAULT_TIMEOUT,
71
83
  )
72
- self.session = self._client # legacy attribute for tests
84
+ self.session = self._client
73
85
 
74
- # spawn reader (best-effort reconnect)
75
- self._reader_task = asyncio.create_task(self._consume_events(), name="sse-reader")
86
+ # Start SSE connection and wait for endpoint
87
+ self._sse_task = asyncio.create_task(self._handle_sse_connection())
88
+
89
+ try:
90
+ # Wait for endpoint event (up to 10 seconds)
91
+ await asyncio.wait_for(self._connected.wait(), timeout=10.0)
92
+
93
+ # NEW: Send MCP initialize handshake
94
+ if await self._initialize_mcp_session():
95
+ return True
96
+ else:
97
+ print("❌ MCP initialization failed")
98
+ return False
99
+
100
+ except asyncio.TimeoutError:
101
+ print("❌ Timeout waiting for SSE endpoint event")
102
+ return False
103
+ except Exception as e:
104
+ print(f"❌ SSE initialization failed: {e}")
105
+ return False
106
+
107
+ async def _initialize_mcp_session(self) -> bool:
108
+ """Send the required MCP initialize handshake."""
109
+ if not self._message_url:
110
+ print("❌ No message URL available for initialization")
111
+ return False
112
+
113
+ try:
114
+ print("🔄 Sending MCP initialize handshake...")
115
+
116
+ # Required MCP initialize message
117
+ init_message = {
118
+ "jsonrpc": "2.0",
119
+ "id": "initialize",
120
+ "method": "initialize",
121
+ "params": {
122
+ "protocolVersion": "2024-11-05",
123
+ "capabilities": {
124
+ "tools": {},
125
+ "resources": {},
126
+ "prompts": {},
127
+ "sampling": {}
128
+ },
129
+ "clientInfo": {
130
+ "name": "chuk-tool-processor",
131
+ "version": "1.0.0"
132
+ }
133
+ }
134
+ }
135
+
136
+ response = await self._send_message(init_message)
137
+
138
+ if "result" in response:
139
+ server_info = response["result"]
140
+ print(f"✅ MCP initialized: {server_info.get('serverInfo', {}).get('name', 'Unknown Server')}")
141
+
142
+ # Send initialized notification (required by MCP spec)
143
+ notification = {
144
+ "jsonrpc": "2.0",
145
+ "method": "notifications/initialized"
146
+ }
147
+
148
+ # Send notification (don't wait for response)
149
+ await self._send_notification(notification)
150
+ self._initialized.set()
151
+ return True
152
+ else:
153
+ print(f"❌ MCP initialization failed: {response}")
154
+ return False
155
+
156
+ except Exception as e:
157
+ print(f"❌ MCP initialization error: {e}")
158
+ return False
76
159
 
77
- # verify connection
78
- return await self.send_ping()
160
+ async def _send_notification(self, notification: Dict[str, Any]) -> None:
161
+ """Send a JSON-RPC notification (no response expected)."""
162
+ if not self._client or not self._message_url:
163
+ return
164
+
165
+ try:
166
+ headers = {"Content-Type": "application/json"}
167
+ await self._client.post(
168
+ self._message_url,
169
+ json=notification,
170
+ headers=headers
171
+ )
172
+ except Exception as e:
173
+ print(f"⚠️ Failed to send notification: {e}")
79
174
 
80
175
  async def close(self) -> None:
81
- """Stop background reader and close the httpx client."""
82
- if self._reader_task:
83
- self._reader_task.cancel()
176
+ """Close the transport."""
177
+ # Cancel any pending requests
178
+ for future in self._pending_requests.values():
179
+ if not future.done():
180
+ future.cancel()
181
+ self._pending_requests.clear()
182
+
183
+ if self._sse_task:
184
+ self._sse_task.cancel()
84
185
  with contextlib.suppress(asyncio.CancelledError):
85
- await self._reader_task
86
- self._reader_task = None
186
+ await self._sse_task
187
+ self._sse_task = None
87
188
 
88
189
  if self._client:
89
190
  await self._client.aclose()
90
191
  self._client = None
91
- self.session = None # keep tests happy
192
+ self.session = None
92
193
 
93
194
  # ------------------------------------------------------------------ #
94
- # Internal helpers #
195
+ # SSE Connection Handler #
95
196
  # ------------------------------------------------------------------ #
96
- async def _get_json(self, path: str) -> Any:
197
+ async def _handle_sse_connection(self) -> None:
198
+ """Handle the SSE connection and extract the endpoint URL."""
97
199
  if not self._client:
98
- raise RuntimeError("Transport not initialised")
99
-
100
- resp = await self._client.get(_url(self.base_url, path), headers=HEADERS_JSON)
101
- resp.raise_for_status()
102
- return resp.json()
200
+ return
103
201
 
104
- async def _post_json(self, path: str, payload: Dict[str, Any]) -> Any:
105
- if not self._client:
106
- raise RuntimeError("Transport not initialised")
202
+ try:
203
+ headers = {
204
+ "Accept": "text/event-stream",
205
+ "Cache-Control": "no-cache"
206
+ }
207
+
208
+ async with self._client.stream(
209
+ "GET", f"{self.base_url}/sse", headers=headers
210
+ ) as response:
211
+ response.raise_for_status()
212
+
213
+ async for line in response.aiter_lines():
214
+ if not line:
215
+ continue
216
+
217
+ # Parse SSE events
218
+ if line.startswith("event: "):
219
+ event_type = line[7:].strip()
220
+
221
+ elif line.startswith("data: ") and 'event_type' in locals():
222
+ data = line[6:].strip()
223
+
224
+ if event_type == "endpoint":
225
+ # Got the endpoint URL for messages - construct full URL
226
+ self._message_url = f"{self.base_url}{data}"
227
+
228
+ # Extract session_id if present
229
+ if "session_id=" in data:
230
+ self._session_id = data.split("session_id=")[1].split("&")[0]
231
+
232
+ print(f"✅ Got message endpoint: {self._message_url}")
233
+ self._connected.set()
234
+
235
+ elif event_type == "message":
236
+ # Handle incoming JSON-RPC responses
237
+ try:
238
+ message = json.loads(data)
239
+ await self._handle_incoming_message(message)
240
+ except json.JSONDecodeError:
241
+ print(f"❌ Failed to parse message: {data}")
242
+
243
+ except asyncio.CancelledError:
244
+ pass
245
+ except Exception as e:
246
+ print(f"❌ SSE connection failed: {e}")
107
247
 
108
- resp = await self._client.post(
109
- _url(self.base_url, path), json=payload, headers=HEADERS_JSON
110
- )
111
- resp.raise_for_status()
112
- return resp.json()
248
+ async def _handle_incoming_message(self, message: Dict[str, Any]) -> None:
249
+ """Handle incoming JSON-RPC response messages."""
250
+ message_id = message.get("id")
251
+ if message_id and message_id in self._pending_requests:
252
+ # Complete the pending request
253
+ future = self._pending_requests.pop(message_id)
254
+ if not future.done():
255
+ future.set_result(message)
113
256
 
114
257
  # ------------------------------------------------------------------ #
115
- # Public API (implements MCPBaseTransport) #
258
+ # MCP Protocol Methods #
116
259
  # ------------------------------------------------------------------ #
117
260
  async def send_ping(self) -> bool:
118
- if not self._client:
119
- return False
120
- try:
121
- await self._get_json("/ping")
122
- return True
123
- except Exception: # pragma: no cover
124
- return False
261
+ """Test if we have a working and initialized connection."""
262
+ return self._message_url is not None and self._initialized.is_set()
125
263
 
126
264
  async def get_tools(self) -> List[Dict[str, Any]]:
127
- if not self._client:
265
+ """Get available tools using tools/list."""
266
+ # NEW: Wait for initialization before proceeding
267
+ if not self._initialized.is_set():
268
+ print("⏳ Waiting for MCP initialization...")
269
+ try:
270
+ await asyncio.wait_for(self._initialized.wait(), timeout=10.0)
271
+ except asyncio.TimeoutError:
272
+ print("❌ Timeout waiting for MCP initialization")
273
+ return []
274
+
275
+ if not self._message_url:
128
276
  return []
277
+
129
278
  try:
130
- data = await self._get_json("/tools/list")
131
- return data.get("tools", []) if isinstance(data, dict) else []
132
- except Exception: # pragma: no cover
133
- return []
279
+ message = {
280
+ "jsonrpc": "2.0",
281
+ "id": "tools_list",
282
+ "method": "tools/list",
283
+ "params": {}
284
+ }
285
+
286
+ response = await self._send_message(message)
287
+
288
+ if "result" in response and "tools" in response["result"]:
289
+ return response["result"]["tools"]
290
+
291
+ except Exception as e:
292
+ print(f"❌ Failed to get tools: {e}")
293
+
294
+ return []
134
295
 
135
296
  async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
136
- # ─── tests expect this specific message if *not* initialised ───
137
- if not self._client:
138
- return {"isError": True, "error": "SSE transport not implemented"}
297
+ """Execute a tool call using the MCP protocol."""
298
+ # NEW: Ensure initialization before tool calls
299
+ if not self._initialized.is_set():
300
+ return {"isError": True, "error": "MCP session not initialized"}
301
+
302
+ if not self._message_url:
303
+ return {"isError": True, "error": "No message endpoint available"}
304
+
305
+ try:
306
+ message = {
307
+ "jsonrpc": "2.0",
308
+ "id": f"call_{tool_name}",
309
+ "method": "tools/call",
310
+ "params": {
311
+ "name": tool_name,
312
+ "arguments": arguments
313
+ }
314
+ }
315
+
316
+ response = await self._send_message(message)
317
+
318
+ # Process MCP response
319
+ if "error" in response:
320
+ return {
321
+ "isError": True,
322
+ "error": response["error"].get("message", "Unknown error")
323
+ }
324
+
325
+ if "result" in response:
326
+ result = response["result"]
327
+
328
+ # Handle MCP tool response format
329
+ if "content" in result:
330
+ # Extract content from MCP format
331
+ content = result["content"]
332
+ if isinstance(content, list) and content:
333
+ # Take first content item
334
+ first_content = content[0]
335
+ if isinstance(first_content, dict) and "text" in first_content:
336
+ return {"isError": False, "content": first_content["text"]}
337
+
338
+ return {"isError": False, "content": content}
339
+
340
+ # Direct result
341
+ return {"isError": False, "content": result}
342
+
343
+ return {"isError": True, "error": "No result in response"}
344
+
345
+ except Exception as e:
346
+ return {"isError": True, "error": str(e)}
347
+
348
+ async def _send_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
349
+ """Send a JSON-RPC message to the server and wait for async response."""
350
+ if not self._client or not self._message_url:
351
+ raise RuntimeError("Transport not properly initialized")
352
+
353
+ message_id = message.get("id")
354
+ if not message_id:
355
+ raise ValueError("Message must have an ID")
356
+
357
+ # Create a future for this request
358
+ future = asyncio.Future()
359
+ async with self._message_lock:
360
+ self._pending_requests[message_id] = future
139
361
 
140
362
  try:
141
- payload = {"name": tool_name, "arguments": arguments}
142
- return await self._post_json("/tools/call", payload)
143
- except Exception as exc: # pragma: no cover
144
- return {"isError": True, "error": str(exc)}
363
+ headers = {"Content-Type": "application/json"}
364
+
365
+ # Send the request
366
+ response = await self._client.post(
367
+ self._message_url,
368
+ json=message,
369
+ headers=headers
370
+ )
371
+
372
+ # Check if server accepted the request
373
+ if response.status_code == 202:
374
+ # Server accepted - wait for async response via SSE
375
+ try:
376
+ response_message = await asyncio.wait_for(future, timeout=30.0)
377
+ return response_message
378
+ except asyncio.TimeoutError:
379
+ raise RuntimeError(f"Timeout waiting for response to message {message_id}")
380
+ else:
381
+ # Immediate response - parse and return
382
+ response.raise_for_status()
383
+ return response.json()
384
+
385
+ finally:
386
+ # Clean up pending request
387
+ async with self._message_lock:
388
+ self._pending_requests.pop(message_id, None)
145
389
 
146
- # ----------------------- extras used by StreamManager ------------- #
390
+ # ------------------------------------------------------------------ #
391
+ # Additional MCP methods #
392
+ # ------------------------------------------------------------------ #
147
393
  async def list_resources(self) -> List[Dict[str, Any]]:
148
- if not self._client:
394
+ """List available resources."""
395
+ if not self._initialized.is_set() or not self._message_url:
149
396
  return []
397
+
150
398
  try:
151
- data = await self._get_json("/resources/list")
152
- return data.get("resources", []) if isinstance(data, dict) else []
153
- except Exception: # pragma: no cover
154
- return []
399
+ message = {
400
+ "jsonrpc": "2.0",
401
+ "id": "resources_list",
402
+ "method": "resources/list",
403
+ "params": {}
404
+ }
405
+
406
+ response = await self._send_message(message)
407
+ if "result" in response and "resources" in response["result"]:
408
+ return response["result"]["resources"]
409
+
410
+ except Exception:
411
+ pass
412
+
413
+ return []
155
414
 
156
415
  async def list_prompts(self) -> List[Dict[str, Any]]:
157
- if not self._client:
416
+ """List available prompts."""
417
+ if not self._initialized.is_set() or not self._message_url:
158
418
  return []
419
+
159
420
  try:
160
- data = await self._get_json("/prompts/list")
161
- return data.get("prompts", []) if isinstance(data, dict) else []
162
- except Exception: # pragma: no cover
163
- return []
164
-
165
- # ------------------------------------------------------------------ #
166
- # Background event-stream reader #
167
- # ------------------------------------------------------------------ #
168
- async def _consume_events(self) -> None: # pragma: no cover
169
- """Continuously read `/events` and push JSON objects onto a queue."""
170
- if not self._client:
171
- return
172
-
173
- while True:
174
- try:
175
- async with self._client.stream(
176
- "GET", _url(self.base_url, self.EVENTS_PATH), headers=HEADERS_JSON
177
- ) as resp:
178
- resp.raise_for_status()
179
- async for line in resp.aiter_lines():
180
- if not line:
181
- continue
182
- try:
183
- await self._incoming_queue.put(json.loads(line))
184
- except json.JSONDecodeError:
185
- continue
186
- except asyncio.CancelledError:
187
- break
188
- except Exception:
189
- await asyncio.sleep(1.0) # back-off and retry
421
+ message = {
422
+ "jsonrpc": "2.0",
423
+ "id": "prompts_list",
424
+ "method": "prompts/list",
425
+ "params": {}
426
+ }
427
+
428
+ response = await self._send_message(message)
429
+ if "result" in response and "prompts" in response["result"]:
430
+ return response["result"]["prompts"]
431
+
432
+ except Exception:
433
+ pass
434
+
435
+ return []
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: chuk-tool-processor
3
- Version: 0.1.7
3
+ Version: 0.2
4
4
  Summary: Add your description here
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -5,8 +5,8 @@ chuk_tool_processor/core/processor.py,sha256=ttEYZTQHctXXiUP8gxAMCCSjbRvyOHojQe_
5
5
  chuk_tool_processor/execution/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  chuk_tool_processor/execution/tool_executor.py,sha256=NSzmvqGMMyKuVapJAmPr-YtNgGhZI3fcAxhilyGG5kY,12174
7
7
  chuk_tool_processor/execution/strategies/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- chuk_tool_processor/execution/strategies/inprocess_strategy.py,sha256=7Vx8zZ10DSK73tfinahAHElprctIJ1f4WKaR8lJf_Jk,21710
9
- chuk_tool_processor/execution/strategies/subprocess_strategy.py,sha256=6ByvqHhZ5fenrV7yPNRUHeid-htTiVu05Gn0n6ImXg4,20477
8
+ chuk_tool_processor/execution/strategies/inprocess_strategy.py,sha256=UJIv1g3Z9LpMsTYa9cqJB376StsI0up3cftH4OkqC2I,22582
9
+ chuk_tool_processor/execution/strategies/subprocess_strategy.py,sha256=Rb5GTffl-4dkAQG_zz8wjggqyWznVOr9gReLGHmE2io,22469
10
10
  chuk_tool_processor/execution/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  chuk_tool_processor/execution/wrappers/caching.py,sha256=1pSyouYT4H7AGkNcK_7wWAIT1d4AKnHJlKBODPO8tZw,20416
12
12
  chuk_tool_processor/execution/wrappers/rate_limiting.py,sha256=CBBsI1VLosjo8dZXLeJ3IaclGvy9VdjGyqgunY089KQ,9231
@@ -24,7 +24,7 @@ chuk_tool_processor/mcp/setup_mcp_stdio.py,sha256=P9qSgmxoNQbsOlGp83DlLLpN9BsG__
24
24
  chuk_tool_processor/mcp/stream_manager.py,sha256=mrmlG54P_xLbDYz_rBjdu-OPMnbi916dgyJg7BrIbjM,12798
25
25
  chuk_tool_processor/mcp/transport/__init__.py,sha256=7QQqeSKVKv0N9GcyJuYF0R4FDZeooii5RjggvFFg5GY,296
26
26
  chuk_tool_processor/mcp/transport/base_transport.py,sha256=1E29LjWw5vLQrPUDF_9TJt63P5dxAAN7n6E_KiZbGUY,3427
27
- chuk_tool_processor/mcp/transport/sse_transport.py,sha256=bryH9DOWOn5qr6LsimTriukDC4ix2kuRq6bUv9qOV20,7645
27
+ chuk_tool_processor/mcp/transport/sse_transport.py,sha256=AkEs02ef11dLbBju6mYIZwdMF6zm0tcME_I8LEVSmrQ,16710
28
28
  chuk_tool_processor/mcp/transport/stdio_transport.py,sha256=lFXL7p8ca4z_J0RBL8UCHrQ1UH7C2-LbC0tZhpya4V4,7763
29
29
  chuk_tool_processor/models/__init__.py,sha256=TC__rdVa0lQsmJHM_hbLDPRgToa_pQT_UxRcPZk6iVw,40
30
30
  chuk_tool_processor/models/execution_strategy.py,sha256=UVW35YIeMY2B3mpIKZD2rAkyOPayI6ckOOUALyf0YiQ,2115
@@ -52,7 +52,7 @@ chuk_tool_processor/registry/providers/__init__.py,sha256=eigwG_So11j7WbDGSWaKd3
52
52
  chuk_tool_processor/registry/providers/memory.py,sha256=LlpPUU9E7S8Se6Q3VyKxLwpNm82SvmP8GLUmI8MkHxQ,5188
53
53
  chuk_tool_processor/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
54
54
  chuk_tool_processor/utils/validation.py,sha256=fiTSsHq7zx-kyd755GaFCvPCa-EVasSpg0A1liNHkxU,4138
55
- chuk_tool_processor-0.1.7.dist-info/METADATA,sha256=vqD7WCOAdv5CWGyQv0Hyp3oxWKGDkslRby84vVdCvLw,10165
56
- chuk_tool_processor-0.1.7.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
57
- chuk_tool_processor-0.1.7.dist-info/top_level.txt,sha256=7lTsnuRx4cOW4U2sNJWNxl4ZTt_J1ndkjTbj3pHPY5M,20
58
- chuk_tool_processor-0.1.7.dist-info/RECORD,,
55
+ chuk_tool_processor-0.2.dist-info/METADATA,sha256=BF2f_DLVJk59zAMransa5Ca3wH5alCzPih6xFsNkscc,10163
56
+ chuk_tool_processor-0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
57
+ chuk_tool_processor-0.2.dist-info/top_level.txt,sha256=7lTsnuRx4cOW4U2sNJWNxl4ZTt_J1ndkjTbj3pHPY5M,20
58
+ chuk_tool_processor-0.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.4.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5