chuk-tool-processor 0.6.7__py3-none-any.whl → 0.6.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic. (The original page linked to further details here; that link was not preserved in this copy.)

@@ -1,14 +1,24 @@
1
1
  # chuk_tool_processor/mcp/transport/sse_transport.py
2
2
  """
3
- Fixed SSE transport that matches your server's actual behavior.
4
- Based on your working debug script with smart URL detection.
3
+ SSE transport for MCP communication.
4
+
5
+ Implements Server-Sent Events transport with two-step async pattern:
6
+ 1. POST messages to /messages endpoint
7
+ 2. Receive responses via SSE stream
8
+
9
+ Note: This transport is deprecated in favor of HTTP Streamable (spec 2025-03-26)
10
+ but remains supported for backward compatibility.
11
+
12
+ FIXED: Updated to support both old format (/messages/) and new event-based format
13
+ (event: endpoint + data: https://...) for session discovery.
5
14
  """
6
15
  from __future__ import annotations
7
16
 
8
17
  import asyncio
9
18
  import json
19
+ import time
10
20
  import uuid
11
- from typing import Dict, Any, List, Optional, Tuple
21
+ from typing import Dict, Any, List, Optional
12
22
  import logging
13
23
 
14
24
  import httpx
@@ -20,28 +30,45 @@ logger = logging.getLogger(__name__)
20
30
 
21
31
  class SSETransport(MCPBaseTransport):
22
32
  """
23
- SSE transport that works with your server's two-step async pattern:
24
- 1. POST messages to /messages endpoint
25
- 2. Receive responses via SSE stream
33
+ SSE transport implementing the MCP protocol over Server-Sent Events.
34
+
35
+ This transport uses a dual-connection approach:
36
+ - SSE stream for receiving responses
37
+ - HTTP POST for sending requests
38
+
39
+ FIXED: Supports both old and new session discovery formats.
26
40
  """
27
41
 
28
42
  def __init__(self, url: str, api_key: Optional[str] = None,
29
43
  headers: Optional[Dict[str, str]] = None,
30
- connection_timeout: float = 30.0, default_timeout: float = 30.0):
31
- """Initialize SSE transport."""
44
+ connection_timeout: float = 30.0,
45
+ default_timeout: float = 30.0,
46
+ enable_metrics: bool = True):
47
+ """
48
+ Initialize SSE transport.
49
+
50
+ Args:
51
+ url: Base URL for the MCP server
52
+ api_key: Optional API key for authentication
53
+ headers: Optional custom headers
54
+ connection_timeout: Timeout for initial connection setup
55
+ default_timeout: Default timeout for operations
56
+ enable_metrics: Whether to track performance metrics
57
+ """
32
58
  self.url = url.rstrip('/')
33
59
  self.api_key = api_key
34
60
  self.configured_headers = headers or {}
35
61
  self.connection_timeout = connection_timeout
36
62
  self.default_timeout = default_timeout
63
+ self.enable_metrics = enable_metrics
37
64
 
38
- # DEBUG: Log what we received
39
- logger.debug("SSE Transport initialized with:")
40
- logger.debug(" URL: %s", self.url)
41
- logger.debug(" API Key: %s", "***" if api_key else None)
42
- logger.debug(" Headers: %s", {k: v[:10] + "..." if len(v) > 10 else v for k, v in self.configured_headers.items()})
65
+ logger.debug("SSE Transport initialized with URL: %s", self.url)
66
+ if self.api_key:
67
+ logger.debug("API key configured for authentication")
68
+ if self.configured_headers:
69
+ logger.debug("Custom headers configured: %s", list(self.configured_headers.keys()))
43
70
 
44
- # State
71
+ # Connection state
45
72
  self.session_id = None
46
73
  self.message_url = None
47
74
  self.pending_requests: Dict[str, asyncio.Future] = {}
@@ -51,10 +78,23 @@ class SSETransport(MCPBaseTransport):
51
78
  self.stream_client = None
52
79
  self.send_client = None
53
80
 
54
- # SSE stream
81
+ # SSE stream management
55
82
  self.sse_task = None
56
83
  self.sse_response = None
57
84
  self.sse_stream_context = None
85
+
86
+ # Performance metrics (consistent with other transports)
87
+ self._metrics = {
88
+ "total_calls": 0,
89
+ "successful_calls": 0,
90
+ "failed_calls": 0,
91
+ "total_time": 0.0,
92
+ "avg_response_time": 0.0,
93
+ "last_ping_time": None,
94
+ "initialization_time": None,
95
+ "session_discoveries": 0,
96
+ "stream_errors": 0
97
+ }
58
98
 
59
99
  def _construct_sse_url(self, base_url: str) -> str:
60
100
  """
@@ -62,53 +102,48 @@ class SSETransport(MCPBaseTransport):
62
102
 
63
103
  Smart detection to avoid double-appending /sse if already present.
64
104
  """
65
- # Remove trailing slashes
66
105
  base_url = base_url.rstrip('/')
67
106
 
68
- # Check if URL already ends with /sse
69
107
  if base_url.endswith('/sse'):
70
- # Already has /sse, use as-is
71
108
  logger.debug("URL already contains /sse endpoint: %s", base_url)
72
109
  return base_url
73
110
 
74
- # Append /sse to the base URL
75
111
  sse_url = f"{base_url}/sse"
76
- logger.debug("Appending /sse to base URL: %s -> %s", base_url, sse_url)
112
+ logger.debug("Constructed SSE URL: %s -> %s", base_url, sse_url)
77
113
  return sse_url
78
114
 
79
115
  def _get_headers(self) -> Dict[str, str]:
80
- """Get headers with auth if available."""
116
+ """Get headers with authentication and custom headers."""
81
117
  headers = {}
82
118
 
83
119
  # Add configured headers first
84
120
  if self.configured_headers:
85
121
  headers.update(self.configured_headers)
86
122
 
87
- # Add API key as Bearer token if provided (this will override any Authorization header)
123
+ # Add API key as Bearer token if provided (overrides Authorization header)
88
124
  if self.api_key:
89
125
  headers['Authorization'] = f'Bearer {self.api_key}'
90
126
 
91
- # DEBUG: Log what headers we're sending
92
- logger.debug("Sending headers: %s", {k: v[:10] + "..." if len(v) > 10 else v for k, v in headers.items()})
93
-
94
127
  return headers
95
128
 
96
129
  async def initialize(self) -> bool:
97
- """Initialize SSE connection and MCP handshake."""
130
+ """Initialize SSE connection and perform MCP handshake."""
98
131
  if self._initialized:
99
132
  logger.warning("Transport already initialized")
100
133
  return True
101
134
 
135
+ start_time = time.time()
136
+
102
137
  try:
103
138
  logger.debug("Initializing SSE transport...")
104
139
 
105
- # Create HTTP clients
140
+ # Create HTTP clients with appropriate timeouts
106
141
  self.stream_client = httpx.AsyncClient(timeout=self.connection_timeout)
107
142
  self.send_client = httpx.AsyncClient(timeout=self.default_timeout)
108
143
 
109
- # Connect to SSE stream with smart URL construction
144
+ # Connect to SSE stream
110
145
  sse_url = self._construct_sse_url(self.url)
111
- logger.debug("Connecting to SSE: %s", sse_url)
146
+ logger.debug("Connecting to SSE endpoint: %s", sse_url)
112
147
 
113
148
  self.sse_stream_context = self.stream_client.stream(
114
149
  'GET', sse_url, headers=self._get_headers()
@@ -116,28 +151,37 @@ class SSETransport(MCPBaseTransport):
116
151
  self.sse_response = await self.sse_stream_context.__aenter__()
117
152
 
118
153
  if self.sse_response.status_code != 200:
119
- logger.error("SSE connection failed: %s", self.sse_response.status_code)
154
+ logger.error("SSE connection failed with status: %s", self.sse_response.status_code)
155
+ await self._cleanup()
120
156
  return False
121
157
 
122
158
  logger.debug("SSE streaming connection established")
123
159
 
124
160
  # Start SSE processing task
125
- self.sse_task = asyncio.create_task(self._process_sse_stream())
161
+ self.sse_task = asyncio.create_task(
162
+ self._process_sse_stream(),
163
+ name="sse_stream_processor"
164
+ )
126
165
 
127
- # Wait for session discovery
166
+ # Wait for session discovery with timeout
128
167
  logger.debug("Waiting for session discovery...")
129
- for i in range(50): # 5 seconds max
130
- if self.message_url:
131
- break
168
+ session_timeout = 5.0 # 5 seconds max for session discovery
169
+ session_start = time.time()
170
+
171
+ while not self.message_url and (time.time() - session_start) < session_timeout:
132
172
  await asyncio.sleep(0.1)
133
173
 
134
174
  if not self.message_url:
135
- logger.error("Failed to get session info from SSE")
175
+ logger.error("Failed to discover session endpoint within %.1fs", session_timeout)
176
+ await self._cleanup()
136
177
  return False
137
178
 
138
- logger.debug("Session ready: %s", self.session_id)
179
+ if self.enable_metrics:
180
+ self._metrics["session_discoveries"] += 1
181
+
182
+ logger.debug("Session endpoint discovered: %s", self.session_id)
139
183
 
140
- # Now do MCP initialization
184
+ # Perform MCP initialization handshake
141
185
  try:
142
186
  init_response = await self._send_request("initialize", {
143
187
  "protocolVersion": "2024-11-05",
@@ -149,18 +193,25 @@ class SSETransport(MCPBaseTransport):
149
193
  })
150
194
 
151
195
  if 'error' in init_response:
152
- logger.error("Initialize failed: %s", init_response['error'])
196
+ logger.error("MCP initialize failed: %s", init_response['error'])
197
+ await self._cleanup()
153
198
  return False
154
199
 
155
200
  # Send initialized notification
156
201
  await self._send_notification("notifications/initialized")
157
202
 
158
203
  self._initialized = True
159
- logger.debug("SSE transport initialized successfully")
204
+
205
+ if self.enable_metrics:
206
+ init_time = time.time() - start_time
207
+ self._metrics["initialization_time"] = init_time
208
+
209
+ logger.debug("SSE transport initialized successfully in %.3fs", time.time() - start_time)
160
210
  return True
161
211
 
162
212
  except Exception as e:
163
- logger.error("MCP initialization failed: %s", e)
213
+ logger.error("MCP handshake failed: %s", e)
214
+ await self._cleanup()
164
215
  return False
165
216
 
166
217
  except Exception as e:
@@ -169,58 +220,94 @@ class SSETransport(MCPBaseTransport):
169
220
  return False
170
221
 
171
222
  async def _process_sse_stream(self):
172
- """Process the persistent SSE stream."""
223
+ """
224
+ Process the persistent SSE stream for responses and session discovery.
225
+
226
+ FIXED: Supports both old format (/messages/) and new event-based format
227
+ (event: endpoint + data: https://...) for session discovery.
228
+ """
173
229
  try:
174
230
  logger.debug("Starting SSE stream processing...")
175
231
 
232
+ current_event = None # Track current event type
233
+
176
234
  async for line in self.sse_response.aiter_lines():
177
235
  line = line.strip()
178
236
  if not line:
179
237
  continue
180
238
 
181
- # Handle session endpoint discovery
182
- if not self.message_url and line.startswith('data:') and '/messages/' in line:
183
- endpoint_path = line.split(':', 1)[1].strip()
184
- self.message_url = f"{self.url}{endpoint_path}"
239
+ # Handle event type declarations
240
+ if line.startswith('event:'):
241
+ current_event = line.split(':', 1)[1].strip()
242
+ logger.debug("SSE event type: %s", current_event)
243
+ continue
244
+
245
+ # Handle session endpoint discovery (BOTH FORMATS)
246
+ if not self.message_url and line.startswith('data:'):
247
+ data_part = line.split(':', 1)[1].strip()
185
248
 
186
- if 'session_id=' in endpoint_path:
187
- self.session_id = endpoint_path.split('session_id=')[1].split('&')[0]
249
+ # NEW FORMAT: event: endpoint + data: https://...
250
+ if current_event == "endpoint" and data_part.startswith('http'):
251
+ self.message_url = data_part
252
+
253
+ # Extract session ID from URL if present
254
+ if 'session_id=' in data_part:
255
+ self.session_id = data_part.split('session_id=')[1].split('&')[0]
256
+
257
+ logger.debug("Session endpoint discovered via event format: %s", self.session_id)
258
+ continue
188
259
 
189
- logger.debug("Got session info: %s", self.session_id)
190
- continue
260
+ # OLD FORMAT: data: /messages/... (backwards compatibility)
261
+ elif '/messages/' in data_part:
262
+ endpoint_path = data_part
263
+ self.message_url = f"{self.url}{endpoint_path}"
264
+
265
+ # Extract session ID if present
266
+ if 'session_id=' in endpoint_path:
267
+ self.session_id = endpoint_path.split('session_id=')[1].split('&')[0]
268
+
269
+ logger.debug("Session endpoint discovered via old format: %s", self.session_id)
270
+ continue
191
271
 
192
272
  # Handle JSON-RPC responses
193
273
  if line.startswith('data:'):
194
274
  data_part = line.split(':', 1)[1].strip()
195
275
 
196
- # Skip pings and empty data
197
- if not data_part or data_part.startswith('ping'):
276
+ # Skip keepalive pings and empty data
277
+ if not data_part or data_part.startswith('ping') or data_part in ('{}', '[]'):
198
278
  continue
199
279
 
200
280
  try:
201
281
  response_data = json.loads(data_part)
202
282
 
283
+ # Handle JSON-RPC responses with request IDs
203
284
  if 'jsonrpc' in response_data and 'id' in response_data:
204
285
  request_id = str(response_data['id'])
205
286
 
206
- # Resolve pending request
287
+ # Resolve pending request if found
207
288
  if request_id in self.pending_requests:
208
289
  future = self.pending_requests.pop(request_id)
209
290
  if not future.done():
210
291
  future.set_result(response_data)
211
- logger.debug("Resolved request: %s", request_id)
292
+ logger.debug("Resolved request ID: %s", request_id)
212
293
 
213
- except json.JSONDecodeError:
214
- pass # Not JSON, ignore
294
+ except json.JSONDecodeError as e:
295
+ logger.debug("Non-JSON data in SSE stream (ignoring): %s", e)
296
+
297
+ # Reset event type after processing data (only if we processed JSON-RPC)
298
+ if line.startswith('data:') and current_event not in ("endpoint",):
299
+ current_event = None
215
300
 
216
301
  except Exception as e:
217
- logger.error("SSE stream error: %s", e)
302
+ if self.enable_metrics:
303
+ self._metrics["stream_errors"] += 1
304
+ logger.error("SSE stream processing error: %s", e)
218
305
 
219
306
  async def _send_request(self, method: str, params: Dict[str, Any] = None,
220
307
  timeout: Optional[float] = None) -> Dict[str, Any]:
221
- """Send request and wait for async response."""
308
+ """Send JSON-RPC request and wait for async response via SSE."""
222
309
  if not self.message_url:
223
- raise RuntimeError("Not connected")
310
+ raise RuntimeError("SSE transport not connected - no message URL")
224
311
 
225
312
  request_id = str(uuid.uuid4())
226
313
  message = {
@@ -230,12 +317,12 @@ class SSETransport(MCPBaseTransport):
230
317
  "params": params or {}
231
318
  }
232
319
 
233
- # Create future for response
320
+ # Create future for async response
234
321
  future = asyncio.Future()
235
322
  self.pending_requests[request_id] = future
236
323
 
237
324
  try:
238
- # Send message
325
+ # Send HTTP POST request
239
326
  headers = {
240
327
  'Content-Type': 'application/json',
241
328
  **self._get_headers()
@@ -248,9 +335,9 @@ class SSETransport(MCPBaseTransport):
248
335
  )
249
336
 
250
337
  if response.status_code == 202:
251
- # Wait for async response
252
- timeout = timeout or self.default_timeout
253
- result = await asyncio.wait_for(future, timeout=timeout)
338
+ # Async response - wait for result via SSE
339
+ request_timeout = timeout or self.default_timeout
340
+ result = await asyncio.wait_for(future, timeout=request_timeout)
254
341
  return result
255
342
  elif response.status_code == 200:
256
343
  # Immediate response
@@ -258,7 +345,7 @@ class SSETransport(MCPBaseTransport):
258
345
  return response.json()
259
346
  else:
260
347
  self.pending_requests.pop(request_id, None)
261
- raise RuntimeError(f"Request failed: {response.status_code}")
348
+ raise RuntimeError(f"HTTP request failed with status: {response.status_code}")
262
349
 
263
350
  except asyncio.TimeoutError:
264
351
  self.pending_requests.pop(request_id, None)
@@ -268,9 +355,9 @@ class SSETransport(MCPBaseTransport):
268
355
  raise
269
356
 
270
357
  async def _send_notification(self, method: str, params: Dict[str, Any] = None):
271
- """Send notification (no response expected)."""
358
+ """Send JSON-RPC notification (no response expected)."""
272
359
  if not self.message_url:
273
- raise RuntimeError("Not connected")
360
+ raise RuntimeError("SSE transport not connected - no message URL")
274
361
 
275
362
  message = {
276
363
  "jsonrpc": "2.0",
@@ -283,30 +370,46 @@ class SSETransport(MCPBaseTransport):
283
370
  **self._get_headers()
284
371
  }
285
372
 
286
- await self.send_client.post(
373
+ response = await self.send_client.post(
287
374
  self.message_url,
288
375
  headers=headers,
289
376
  json=message
290
377
  )
378
+
379
+ if response.status_code not in (200, 202):
380
+ logger.warning("Notification failed with status: %s", response.status_code)
291
381
 
292
382
  async def send_ping(self) -> bool:
293
- """Send ping to check connection."""
383
+ """Send ping to check connection health."""
294
384
  if not self._initialized:
295
385
  return False
296
386
 
387
+ start_time = time.time()
297
388
  try:
298
- # Your server might not support ping, so we'll just check if we can list tools
389
+ # Use tools/list as a lightweight ping since not all servers support ping
299
390
  response = await self._send_request("tools/list", {}, timeout=5.0)
391
+
392
+ if self.enable_metrics:
393
+ ping_time = time.time() - start_time
394
+ self._metrics["last_ping_time"] = ping_time
395
+ logger.debug("SSE ping completed in %.3fs", ping_time)
396
+
300
397
  return 'error' not in response
301
- except Exception:
398
+ except Exception as e:
399
+ logger.debug("SSE ping failed: %s", e)
302
400
  return False
303
401
 
402
+ def is_connected(self) -> bool:
403
+ """Check if the transport is connected and ready."""
404
+ return self._initialized and self.session_id is not None
405
+
304
406
  async def get_tools(self) -> List[Dict[str, Any]]:
305
- """Get tools list."""
407
+ """Get list of available tools from the server."""
306
408
  if not self._initialized:
307
409
  logger.error("Cannot get tools: transport not initialized")
308
410
  return []
309
411
 
412
+ start_time = time.time()
310
413
  try:
311
414
  response = await self._send_request("tools/list", {})
312
415
 
@@ -315,7 +418,11 @@ class SSETransport(MCPBaseTransport):
315
418
  return []
316
419
 
317
420
  tools = response.get('result', {}).get('tools', [])
318
- logger.debug("Retrieved %d tools", len(tools))
421
+
422
+ if self.enable_metrics:
423
+ response_time = time.time() - start_time
424
+ logger.debug("Retrieved %d tools in %.3fs", len(tools), response_time)
425
+
319
426
  return tools
320
427
 
321
428
  except Exception as e:
@@ -324,15 +431,19 @@ class SSETransport(MCPBaseTransport):
324
431
 
325
432
  async def call_tool(self, tool_name: str, arguments: Dict[str, Any],
326
433
  timeout: Optional[float] = None) -> Dict[str, Any]:
327
- """Call a tool."""
434
+ """Execute a tool with the given arguments."""
328
435
  if not self._initialized:
329
436
  return {
330
437
  "isError": True,
331
438
  "error": "Transport not initialized"
332
439
  }
333
440
 
441
+ start_time = time.time()
442
+ if self.enable_metrics:
443
+ self._metrics["total_calls"] += 1
444
+
334
445
  try:
335
- logger.debug("Calling tool %s with args: %s", tool_name, arguments)
446
+ logger.debug("Calling tool '%s' with arguments: %s", tool_name, arguments)
336
447
 
337
448
  response = await self._send_request(
338
449
  "tools/call",
@@ -344,58 +455,56 @@ class SSETransport(MCPBaseTransport):
344
455
  )
345
456
 
346
457
  if 'error' in response:
458
+ if self.enable_metrics:
459
+ self._update_metrics(time.time() - start_time, False)
460
+
347
461
  return {
348
462
  "isError": True,
349
463
  "error": response['error'].get('message', 'Unknown error')
350
464
  }
351
465
 
352
- # Extract result
466
+ # Extract and normalize result using base class method
353
467
  result = response.get('result', {})
468
+ normalized_result = self._normalize_mcp_response({"result": result})
354
469
 
355
- # Handle content format
356
- if 'content' in result:
357
- content = result['content']
358
- if isinstance(content, list) and len(content) == 1:
359
- content_item = content[0]
360
- if isinstance(content_item, dict) and content_item.get('type') == 'text':
361
- text_content = content_item.get('text', '')
362
- try:
363
- # Try to parse as JSON
364
- parsed_content = json.loads(text_content)
365
- return {
366
- "isError": False,
367
- "content": parsed_content
368
- }
369
- except json.JSONDecodeError:
370
- return {
371
- "isError": False,
372
- "content": text_content
373
- }
374
-
375
- return {
376
- "isError": False,
377
- "content": content
378
- }
470
+ if self.enable_metrics:
471
+ self._update_metrics(time.time() - start_time, True)
379
472
 
380
- return {
381
- "isError": False,
382
- "content": result
383
- }
473
+ return normalized_result
384
474
 
385
475
  except asyncio.TimeoutError:
476
+ if self.enable_metrics:
477
+ self._update_metrics(time.time() - start_time, False)
478
+
386
479
  return {
387
480
  "isError": True,
388
481
  "error": "Tool execution timed out"
389
482
  }
390
483
  except Exception as e:
391
- logger.error("Error calling tool %s: %s", tool_name, e)
484
+ if self.enable_metrics:
485
+ self._update_metrics(time.time() - start_time, False)
486
+
487
+ logger.error("Error calling tool '%s': %s", tool_name, e)
392
488
  return {
393
489
  "isError": True,
394
490
  "error": str(e)
395
491
  }
396
492
 
493
+ def _update_metrics(self, response_time: float, success: bool) -> None:
494
+ """Update performance metrics."""
495
+ if success:
496
+ self._metrics["successful_calls"] += 1
497
+ else:
498
+ self._metrics["failed_calls"] += 1
499
+
500
+ self._metrics["total_time"] += response_time
501
+ if self._metrics["total_calls"] > 0:
502
+ self._metrics["avg_response_time"] = (
503
+ self._metrics["total_time"] / self._metrics["total_calls"]
504
+ )
505
+
397
506
  async def list_resources(self) -> Dict[str, Any]:
398
- """List resources."""
507
+ """List available resources from the server."""
399
508
  if not self._initialized:
400
509
  return {}
401
510
 
@@ -405,11 +514,12 @@ class SSETransport(MCPBaseTransport):
405
514
  logger.debug("Resources not supported: %s", response['error'])
406
515
  return {}
407
516
  return response.get('result', {})
408
- except Exception:
517
+ except Exception as e:
518
+ logger.debug("Error listing resources: %s", e)
409
519
  return {}
410
520
 
411
521
  async def list_prompts(self) -> Dict[str, Any]:
412
- """List prompts."""
522
+ """List available prompts from the server."""
413
523
  if not self._initialized:
414
524
  return {}
415
525
 
@@ -419,59 +529,104 @@ class SSETransport(MCPBaseTransport):
419
529
  logger.debug("Prompts not supported: %s", response['error'])
420
530
  return {}
421
531
  return response.get('result', {})
422
- except Exception:
532
+ except Exception as e:
533
+ logger.debug("Error listing prompts: %s", e)
423
534
  return {}
424
535
 
425
536
  async def close(self) -> None:
426
- """Close the transport."""
537
+ """Close the transport and clean up resources."""
538
+ if not self._initialized:
539
+ return
540
+
541
+ # Log final metrics
542
+ if self.enable_metrics and self._metrics["total_calls"] > 0:
543
+ logger.debug(
544
+ "SSE transport closing - Total calls: %d, Success rate: %.1f%%, Avg response time: %.3fs",
545
+ self._metrics["total_calls"],
546
+ (self._metrics["successful_calls"] / self._metrics["total_calls"] * 100),
547
+ self._metrics["avg_response_time"]
548
+ )
549
+
427
550
  await self._cleanup()
428
551
 
429
552
  async def _cleanup(self) -> None:
430
- """Clean up resources."""
431
- if self.sse_task:
553
+ """Clean up all resources and reset state."""
554
+ # Cancel SSE processing task
555
+ if self.sse_task and not self.sse_task.done():
432
556
  self.sse_task.cancel()
433
557
  try:
434
558
  await self.sse_task
435
559
  except asyncio.CancelledError:
436
560
  pass
437
561
 
562
+ # Close SSE stream context
438
563
  if self.sse_stream_context:
439
564
  try:
440
565
  await self.sse_stream_context.__aexit__(None, None, None)
441
- except Exception:
442
- pass
566
+ except Exception as e:
567
+ logger.debug("Error closing SSE stream: %s", e)
443
568
 
569
+ # Close HTTP clients
444
570
  if self.stream_client:
445
571
  await self.stream_client.aclose()
446
572
 
447
573
  if self.send_client:
448
574
  await self.send_client.aclose()
449
575
 
576
+ # Cancel any pending requests
577
+ for request_id, future in self.pending_requests.items():
578
+ if not future.done():
579
+ future.cancel()
580
+
581
+ # Reset state
450
582
  self._initialized = False
451
583
  self.session_id = None
452
584
  self.message_url = None
453
585
  self.pending_requests.clear()
586
+ self.sse_task = None
587
+ self.sse_response = None
588
+ self.sse_stream_context = None
589
+ self.stream_client = None
590
+ self.send_client = None
591
+
592
+ # ------------------------------------------------------------------ #
593
+ # Metrics and monitoring (consistent with other transports) #
594
+ # ------------------------------------------------------------------ #
595
+ def get_metrics(self) -> Dict[str, Any]:
596
+ """Get performance and connection metrics."""
597
+ return self._metrics.copy()
598
+
599
+ def reset_metrics(self) -> None:
600
+ """Reset performance metrics."""
601
+ self._metrics = {
602
+ "total_calls": 0,
603
+ "successful_calls": 0,
604
+ "failed_calls": 0,
605
+ "total_time": 0.0,
606
+ "avg_response_time": 0.0,
607
+ "last_ping_time": self._metrics.get("last_ping_time"),
608
+ "initialization_time": self._metrics.get("initialization_time"),
609
+ "session_discoveries": self._metrics.get("session_discoveries", 0),
610
+ "stream_errors": 0
611
+ }
454
612
 
613
+ # ------------------------------------------------------------------ #
614
+ # Backward compatibility #
615
+ # ------------------------------------------------------------------ #
455
616
  def get_streams(self) -> List[tuple]:
456
- """Not applicable for this transport."""
617
+ """SSE transport doesn't expose raw streams."""
457
618
  return []
458
619
 
459
- def is_connected(self) -> bool:
460
- """Check if connected."""
461
- return self._initialized and self.session_id is not None
462
-
620
+ # ------------------------------------------------------------------ #
621
+ # Context manager support #
622
+ # ------------------------------------------------------------------ #
463
623
  async def __aenter__(self):
464
- """Context manager support."""
624
+ """Context manager entry."""
465
625
  success = await self.initialize()
466
626
  if not success:
467
- raise RuntimeError("Failed to initialize SSE transport")
627
+ raise RuntimeError("Failed to initialize SSETransport")
468
628
  return self
469
629
 
470
630
  async def __aexit__(self, exc_type, exc_val, exc_tb):
471
631
  """Context manager cleanup."""
472
- await self.close()
473
-
474
- def __repr__(self) -> str:
475
- """String representation."""
476
- status = "initialized" if self._initialized else "not initialized"
477
- return f"SSETransport(status={status}, url={self.url}, session={self.session_id})"
632
+ await self.close()