mcp-debugger 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,446 @@
1
+ """Core Replay Engine for replaying recorded MCP sessions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import sys
9
+ import time
10
+ from datetime import datetime, timezone
11
+ from typing import Any, Callable, Dict, List, Optional, Set
12
+ from pydantic import BaseModel
13
+
14
+ from mcp_debugger.storage.database import Database
15
+ from mcp_debugger.replay.diff import DiffNode, compare_json, render_diff
16
+
17
+ logger = logging.getLogger("mcp_debugger.replay")
18
+
19
+
20
class ReplayedMessage(BaseModel):
    """Result of replaying a single message."""

    # Identifier of the recorded message this entry was replayed from.
    original_message_id: int
    # JSON-RPC method name; the engine stores "" when the recorded method is None.
    method: str
    # The request object that was (or would have been) sent to the target server.
    request_sent: Dict[str, Any]
    # Response captured in the original session, if one was recorded.
    original_response: Optional[Dict[str, Any]] = None
    # Response received from the target server during replay; None for
    # notifications, timeouts and failures.
    replayed_response: Optional[Dict[str, Any]] = None
    # Human-readable failure description; None when the message was replayed
    # without error.
    error: Optional[str] = None
    # Time spent on this message, in milliseconds.
    latency_ms: float
    # True when the replayed outcome is considered equal to the original one.
    matches: bool
    # Structured diff between original and replayed response; only populated
    # by the engine when the message does not match.
    diff: Optional[List[DiffNode]] = None
    # Rendered (text) form of the diff above.
    diff_text: Optional[str] = None
33
+
34
+
35
class ReplayResult(BaseModel):
    """Overall summary of a replay session."""

    # Identifier returned by the database when the replay was persisted;
    # None when the result was not saved.
    replay_id: Optional[int] = None
    # Session whose recorded messages were replayed.
    session_id: int
    # Command used to launch the server the messages were replayed against.
    target_server_command: str
    # Timezone-aware (UTC) timestamps taken by the engine at start and end.
    started_at: datetime
    ended_at: datetime
    # Number of entries in `messages`.
    total_messages_replayed: int
    # Messages that produced a response and no error.
    successful_responses: int
    # Messages that ended with an error other than a timeout.
    failed_responses: int
    # Messages that produced a response which did not match the original.
    mismatched_responses: int
    # Messages whose error is a timeout.
    timed_out: int
    # Per-message outcomes, in replay order.
    messages: List[ReplayedMessage]
49
+
50
+
51
def deep_compare(val1: Any, val2: Any, ignore_keys: Optional[Set[str]] = None) -> bool:
    """Recursively compare two JSON-compatible values, ignoring specific keys.

    Args:
        val1: First value (dict, list or scalar).
        val2: Second value (dict, list or scalar).
        ignore_keys: Dictionary keys to skip at every nesting level. Any
            iterable of keys is accepted. Defaults to
            ``{"timestamp", "latency_ms"}`` when ``None``.

    Returns:
        True when both values are structurally equal once the ignored keys
        are removed, False otherwise.
    """
    if ignore_keys is None:
        ignore_keys = {"timestamp", "latency_ms"}

    if isinstance(val1, dict) and isinstance(val2, dict):
        # Key views are set-like and accept any iterable on the right-hand
        # side, so a list/tuple of ignored keys works as well as a set.
        keys1 = val1.keys() - ignore_keys
        keys2 = val2.keys() - ignore_keys
        if keys1 != keys2:
            return False
        return all(deep_compare(val1[k], val2[k], ignore_keys) for k in keys1)

    if isinstance(val1, list) and isinstance(val2, list):
        if len(val1) != len(val2):
            return False
        return all(deep_compare(a, b, ignore_keys) for a, b in zip(val1, val2))

    # In Python ``True == 1`` and ``False == 0``, but JSON booleans are not
    # numbers: a response that changed from ``true`` to ``1`` is a mismatch.
    if isinstance(val1, bool) or isinstance(val2, bool):
        return isinstance(val1, bool) and isinstance(val2, bool) and val1 == val2

    return bool(val1 == val2)
74
+
75
+
76
class ReplayEngine:
    """Loads recorded messages from a session and replays them to a target server."""

    # Buffer limit for the server's stdout/stderr StreamReaders. Raised to
    # 10MB to handle large tool outputs and schemas delivered on one line.
    _STREAM_LIMIT_BYTES = 10 * 1024 * 1024

    def __init__(self, db: Database) -> None:
        """Initialize the ReplayEngine with a Database instance."""
        self.db = db

    async def _reader_loop(
        self,
        reader: asyncio.StreamReader,
        pending_requests: Dict[str, asyncio.Future[Dict[str, Any]]],
    ) -> None:
        """Continuously read lines from the server stdout and resolve pending requests.

        Each non-empty line is parsed as JSON. A JSON object carrying an
        ``id`` resolves (and removes) the future stored under ``str(id)`` in
        ``pending_requests``; everything else is dropped. The loop ends on
        EOF, on cancellation, or on any unexpected error.
        """
        while True:
            try:
                line_bytes = await reader.readline()
                if not line_bytes:
                    # EOF: the server closed its stdout.
                    break

                line = line_bytes.decode("utf-8", errors="replace").strip()
                if not line:
                    continue

                try:
                    msg = json.loads(line)
                except json.JSONDecodeError:
                    # Non-JSON line: treat it as server log output and only
                    # record it at debug level.
                    logger.debug("Server log: %s", line)
                    continue

                if isinstance(msg, dict) and "id" in msg:
                    fut = pending_requests.pop(str(msg["id"]), None)
                    if fut is not None and not fut.done():
                        fut.set_result(msg)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Error in replay reader loop: %s", e)
                break

    async def _stderr_loop(self, reader: asyncio.StreamReader) -> None:
        """Continuously read lines from the server stderr and write them to sys.stderr."""
        while True:
            try:
                line_bytes = await reader.readline()
                if not line_bytes:
                    break
                line = line_bytes.decode("utf-8", errors="replace")
                sys.stderr.write(f"[server-stderr] {line}")
                sys.stderr.flush()
            except asyncio.CancelledError:
                break
            except Exception:
                # Forwarding stderr is best-effort; never let it break a replay.
                break

    @staticmethod
    def _build_payload(msg: Dict[str, Any]) -> Dict[str, Any]:
        """Rebuild the JSON-RPC 2.0 request object for a recorded message.

        ``id``, ``method`` and ``params`` are only included when the recorded
        value is not None. The id is sent as an integer when it can be
        converted to one, otherwise it is sent as recorded.
        """
        payload: Dict[str, Any] = {"jsonrpc": "2.0"}
        msg_id = msg["message_id"]
        if msg_id is not None:
            try:
                payload["id"] = int(msg_id)
            except (TypeError, ValueError):
                # Not an integer-like id (e.g. an opaque string): send as-is.
                payload["id"] = msg_id
        if msg["method"] is not None:
            payload["method"] = msg["method"]
        if msg["params"] is not None:
            payload["params"] = msg["params"]
        return payload

    async def _spawn_failure_result(
        self,
        session_id: int,
        target_server_command: str,
        original_msgs: List[Dict[str, Any]],
        started_at: datetime,
        persist: bool,
        exc: Exception,
    ) -> ReplayResult:
        """Build (and optionally persist) the result used when the server cannot be started.

        Every message is reported as failed with the spawn error attached.
        """
        ended_at = datetime.now(timezone.utc)
        failed_msgs = [
            ReplayedMessage(
                original_message_id=msg["original_message_id"],
                method=msg["method"] or "",
                request_sent={
                    "method": msg["method"],
                    "params": msg["params"],
                    "id": msg["message_id"],
                },
                original_response=msg["original_response"],
                replayed_response=None,
                error=f"Failed to start server: {exc}",
                latency_ms=0.0,
                matches=False,
            )
            for msg in original_msgs
        ]

        replay_id = None
        if persist:
            replay_id = await self.db.save_replay(
                source_session_id=session_id,
                target_server_command=target_server_command,
                status="failed",
                total_messages=len(failed_msgs),
                mismatches=len(failed_msgs),
                messages=[m.model_dump() for m in failed_msgs],
                started_at=started_at.isoformat(),
                ended_at=ended_at.isoformat(),
            )

        return ReplayResult(
            replay_id=replay_id,
            session_id=session_id,
            target_server_command=target_server_command,
            started_at=started_at,
            ended_at=ended_at,
            total_messages_replayed=len(failed_msgs),
            successful_responses=0,
            failed_responses=len(failed_msgs),
            mismatched_responses=0,
            timed_out=0,
            messages=failed_msgs,
        )

    @staticmethod
    async def _shutdown_server(
        process: asyncio.subprocess.Process,
        reader_task: asyncio.Task[None],
        stderr_task: Optional[asyncio.Task[None]],
    ) -> None:
        """Stop the reader tasks and the server process, best-effort.

        Every step tolerates failure on purpose: the process may already be
        gone, and cleanup problems must not mask the replay outcome.
        """
        reader_task.cancel()
        if stderr_task:
            stderr_task.cancel()

        if process.stdin:
            try:
                process.stdin.close()
                if hasattr(process.stdin, "wait_closed"):
                    await process.stdin.wait_closed()
            except Exception:
                pass

        try:
            process.terminate()
            await asyncio.wait_for(process.wait(), timeout=2.0)
        except Exception:
            # Did not exit in time (or terminate failed): force-kill.
            try:
                process.kill()
                await process.wait()
            except Exception:
                pass

        # Close the subprocess transport explicitly; this reaches into a
        # private attribute, so it is guarded and failures are ignored.
        transport = getattr(process, "_transport", None)
        if transport:
            try:
                transport.close()
            except Exception:
                pass

        try:
            await reader_task
        except asyncio.CancelledError:
            pass
        if stderr_task:
            try:
                await stderr_task
            except asyncio.CancelledError:
                pass

    async def replay(
        self,
        session_id: int,
        target_server_command: str,
        timeout_ms: int = 5000,
        replay_mode: str = "exact",
        message_filter: Optional[List[str]] = None,
        persist: bool = True,
        max_messages: Optional[int] = None,
        on_message_replayed: Optional[Callable[[int, int], None]] = None,
    ) -> ReplayResult:
        """Replay client messages from session_id to a new server.

        Args:
            session_id: Source session to replay.
            target_server_command: Command to launch the server for replay.
                NOTE: it is executed through the system shell, so it must come
                from a trusted source.
            timeout_ms: Max wait per request-response pair.
            replay_mode: "exact" (all messages) or "selective" (filtered).
            message_filter: List of method names to replay (if selective).
            persist: Whether to save the replay result in the database.
            max_messages: Maximum number of messages to replay; ``None`` or a
                non-positive value means no limit.
            on_message_replayed: Optional callback when a message is replayed,
                called with (current, total).

        Returns:
            ReplayResult containing original vs replayed responses, diff status.
        """
        started_at = datetime.now(timezone.utc)
        original_msgs = await self.db.get_replay_messages(session_id)

        # Apply selective filtering
        if replay_mode == "selective" and message_filter is not None:
            original_msgs = [m for m in original_msgs if m.get("method") in message_filter]

        # Apply max messages limit
        if max_messages is not None and max_messages > 0:
            original_msgs = original_msgs[:max_messages]

        if not original_msgs:
            # Nothing to replay: return an empty result without starting a server.
            return ReplayResult(
                session_id=session_id,
                target_server_command=target_server_command,
                started_at=started_at,
                ended_at=datetime.now(timezone.utc),
                total_messages_replayed=0,
                successful_responses=0,
                failed_responses=0,
                mismatched_responses=0,
                timed_out=0,
                messages=[],
            )

        process: Optional[asyncio.subprocess.Process] = None
        # Launch target server
        try:
            process = await asyncio.create_subprocess_shell(
                target_server_command,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                # Documented way to size the pipe StreamReaders.
                limit=self._STREAM_LIMIT_BYTES,
            )
        except Exception as e:
            return await self._spawn_failure_result(
                session_id, target_server_command, original_msgs, started_at, persist, e
            )

        assert process.stdin and process.stdout

        pending_requests: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
        reader_task = asyncio.create_task(self._reader_loop(process.stdout, pending_requests))
        stderr_task = None
        if process.stderr:
            stderr_task = asyncio.create_task(self._stderr_loop(process.stderr))

        replayed_messages: List[ReplayedMessage] = []
        server_terminated = False
        total = len(original_msgs)
        loop = asyncio.get_running_loop()

        try:
            for idx, msg in enumerate(original_msgs):
                method = msg["method"]
                msg_id = msg["message_id"]
                is_notification = msg["message_type"] == "notification" or msg_id is None

                payload = self._build_payload(msg)
                payload_str = json.dumps(payload) + "\n"
                request_key = str(msg_id)

                start_time = time.monotonic()
                replayed_resp = None
                error = None
                latency = 0.0
                matches = False
                fut: Optional[asyncio.Future[Dict[str, Any]]] = None

                if server_terminated:
                    error = "Server process terminated"
                else:
                    if not is_notification:
                        # Register the future BEFORE writing: drain() may
                        # suspend, and a fast server could otherwise answer
                        # before anyone is waiting, losing the response.
                        fut = loop.create_future()
                        pending_requests[request_key] = fut
                    try:
                        process.stdin.write(payload_str.encode("utf-8"))
                        await process.stdin.drain()
                    except Exception as e:
                        error = f"Write error: {e}"
                        server_terminated = True
                        if fut is not None:
                            pending_requests.pop(request_key, None)

                if error is None:
                    if fut is None:
                        # Notifications do not have responses
                        latency = (time.monotonic() - start_time) * 1000.0
                        matches = msg["original_response"] is None
                    else:
                        # Request: wait for response matching ID
                        try:
                            replayed_resp = await asyncio.wait_for(
                                fut, timeout=timeout_ms / 1000.0
                            )
                            latency = (time.monotonic() - start_time) * 1000.0
                            matches = deep_compare(replayed_resp, msg["original_response"])
                        except asyncio.TimeoutError:
                            latency = (time.monotonic() - start_time) * 1000.0
                            error = "Timeout waiting for response"
                            # Abort remaining messages on first timeout
                            server_terminated = True
                        except Exception as e:
                            latency = (time.monotonic() - start_time) * 1000.0
                            error = str(e)
                        finally:
                            # Never leave a stale entry behind, whatever happened.
                            pending_requests.pop(request_key, None)

                msg_diff = None
                msg_diff_text = None
                if not matches:
                    diff_node = compare_json(msg["original_response"], replayed_resp)
                    if diff_node is not None:
                        msg_diff = [diff_node]
                        msg_diff_text = render_diff(diff_node)

                replayed_messages.append(
                    ReplayedMessage(
                        original_message_id=msg["original_message_id"],
                        method=method or "",
                        request_sent=payload,
                        original_response=msg["original_response"],
                        replayed_response=replayed_resp,
                        error=error,
                        latency_ms=latency,
                        matches=matches,
                        diff=msg_diff,
                        diff_text=msg_diff_text,
                    )
                )
                if on_message_replayed:
                    on_message_replayed(idx + 1, total)

        finally:
            await self._shutdown_server(process, reader_task, stderr_task)
            # Drop the last reference and collect right away so the
            # subprocess objects are finalized here rather than later.
            process = None
            import gc

            gc.collect()

        ended_at = datetime.now(timezone.utc)

        # Calculate stats
        successful_responses = sum(
            1 for m in replayed_messages if m.replayed_response is not None and not m.error
        )
        failed_responses = sum(
            1 for m in replayed_messages if m.error is not None and "Timeout" not in m.error
        )
        timed_out = sum(
            1 for m in replayed_messages if m.error is not None and "Timeout" in m.error
        )
        mismatched_responses = sum(
            1 for m in replayed_messages if not m.matches and m.replayed_response is not None
        )

        # A timeout takes precedence over other failures in the overall status.
        status = "completed"
        if timed_out > 0:
            status = "timeout"
        elif failed_responses > 0:
            status = "failed"

        replay_id = None
        if persist:
            replay_id = await self.db.save_replay(
                source_session_id=session_id,
                target_server_command=target_server_command,
                status=status,
                total_messages=len(replayed_messages),
                mismatches=mismatched_responses + timed_out + failed_responses,
                messages=[m.model_dump() for m in replayed_messages],
                started_at=started_at.isoformat(),
                ended_at=ended_at.isoformat(),
            )

        return ReplayResult(
            replay_id=replay_id,
            session_id=session_id,
            target_server_command=target_server_command,
            started_at=started_at,
            ended_at=ended_at,
            total_messages_replayed=len(replayed_messages),
            successful_responses=successful_responses,
            failed_responses=failed_responses,
            mismatched_responses=mismatched_responses,
            timed_out=timed_out,
            messages=replayed_messages,
        )
File without changes