mtrx-cli 0.1.24 → 0.1.26

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -33,17 +33,30 @@ import httpx
 from matrx.cli.cursor_ca import CertCache, load_ca
 
 try:
-    from matrx.cli.cursor_reroute import is_ai_path, try_reroute_to_matrx
+    from matrx.cli.cursor_reroute import (
+        classify_ai_request,
+        is_ai_path,
+        try_inject_context,
+        try_reroute_to_matrx,
+    )
 except ImportError:
     # Stubs when cursor_reroute not available (e.g. npm package omit).
+    def classify_ai_request(method: str, path: str, headers: dict[str, str] | None = None) -> dict[str, bool]:
+        return {"candidate": False, "reroutable": False}
+
     def is_ai_path(path: str) -> bool:
         return False
 
     async def try_reroute_to_matrx(*, path: str, method: str, **kwargs: Any) -> None:
         return None
 
+    async def try_inject_context(**kwargs: Any) -> None:
+        return None
+
 logger = logging.getLogger(__name__)
 
+_MAX_BODY_BYTES = 50 * 1024 * 1024  # 50 MB hard limit for buffered request bodies
+
 DEFAULT_PORT = 8842
 PROXY_HOST = "127.0.0.1"
 HEALTH_PATH = "/__mtrx_health__"
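
Note: the except-ImportError stubs above exist so the proxy still imports when the optional cursor_reroute module is stripped from a build (the "npm package omit" case). The stubs preserve each function's return contract, so call sites never feature-check. A minimal self-contained sketch of the same pattern, with hypothetical module and function names:

# Sketch of the optional-module stub pattern (hypothetical names, not this package's).
try:
    from mypkg.optional_mod import classify  # real implementation, if present
except ImportError:
    def classify(path: str) -> dict[str, bool]:
        # Conservative stub: same return shape, feature simply stays off.
        return {"candidate": False}

def handle(path: str) -> None:
    if classify(path)["candidate"]:
        print("reroute", path)
    else:
        print("forward", path)

handle("/v1/chat")  # behaves sensibly whether or not the optional module exists
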
@@ -55,8 +68,17 @@ _INTERCEPT_DOMAINS = {
     "api4.cursor.sh",
     "api5.cursor.sh",
     "agentn.global.api5.cursor.sh",
+    "api.anthropic.com",
+    "api.openai.com",
 }
 
+_PREWARM_DOMAINS = (
+    "api2.cursor.sh",
+    "api3.cursor.sh",
+    "api4.cursor.sh",
+    "api5.cursor.sh",
+)
+
 
 class MITMProxy:
     """Async MITM forward proxy with telemetry mirroring."""
@@ -77,10 +99,12 @@ class MITMProxy:
         self._telemetry_client: httpx.AsyncClient | None = None
         self._cert_cache: CertCache | None = None
         self._request_count = 0
+        self._connect_count = 0
 
     async def start(self) -> None:
         ca_key, ca_cert = load_ca()
         self._cert_cache = CertCache(ca_key, ca_cert)
+        self._cert_cache.prewarm(_PREWARM_DOMAINS)
         self._telemetry_client = httpx.AsyncClient(timeout=10)
         self._server = await asyncio.start_server(
             self._handle_client, self.host, self.port
@@ -119,7 +143,7 @@ class MITMProxy:
         except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
             pass
         except Exception:
-            logger.debug("proxy: connection error", exc_info=True)
+            logger.warning("proxy: connection error", exc_info=True)
         finally:
             try:
                 writer.close()
@@ -163,8 +187,10 @@ class MITMProxy:
             await writer.drain()
 
             if hostname in _INTERCEPT_DOMAINS:
+                logger.info("proxy: CONNECT %s:%d [intercept]", hostname, port)
                 await self._mitm_intercept(reader, writer, hostname, port)
             else:
+                logger.info("proxy: CONNECT %s:%d [tunnel]", hostname, port)
                 await self._tunnel_passthrough(reader, writer, hostname, port)
         elif method in ("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"):
             # Plain HTTP proxy request (non-CONNECT) -- handle health check
@@ -215,10 +241,22 @@ class MITMProxy:
         port: int,
     ) -> None:
         assert self._cert_cache is not None
+        self._connect_count += 1
+        conn_id = f"{hostname}:{self._connect_count}"
 
         # Use the hostname from the CONNECT request for the cert
         # (matches SNI in virtually all cases, avoids ClientHello peeking)
+        handshake_info = self._cert_cache.get_handshake_info(hostname)
         server_ctx = self._cert_cache.get_ssl_context(hostname)
+        logger.info(
+            "proxy: tls_prepare conn=%s host=%s serial=%s leaf_sha256=%s chain_len=%s cert=%s",
+            conn_id,
+            hostname,
+            handshake_info["leaf_serial"],
+            handshake_info["leaf_sha256"],
+            handshake_info["chain_length"],
+            handshake_info["cert_path"],
+        )
 
         # Upgrade client connection to TLS (we are the "server")
         loop = asyncio.get_running_loop()
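
Note: get_handshake_info is new here and not defined in this diff; the logged fields indicate it reports the leaf certificate's serial number, SHA-256 fingerprint, chain length, and on-disk path. Assuming the leaf is available as a PEM file, such a fingerprint is a digest over the DER bytes:

import hashlib
import ssl

def leaf_sha256(cert_path: str) -> str:
    """SHA-256 fingerprint of a PEM-encoded certificate (illustrative helper)."""
    with open(cert_path, encoding="ascii") as f:
        pem = f.read()
    der = ssl.PEM_cert_to_DER_cert(pem)  # strip the PEM armor, keep raw DER
    return hashlib.sha256(der).hexdigest()
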
@@ -229,8 +267,23 @@ class MITMProxy:
                 transport, protocol, server_ctx, server_side=True
             )
         except (ssl.SSLError, ConnectionError) as exc:
-            logger.debug("TLS handshake with client failed for %s: %s", hostname, exc)
+            logger.warning(
+                "TLS handshake with client failed for %s [conn=%s serial=%s leaf_sha256=%s chain_len=%s]: %s",
+                hostname,
+                conn_id,
+                handshake_info["leaf_serial"],
+                handshake_info["leaf_sha256"],
+                handshake_info["chain_length"],
+                exc,
+            )
             return
+        logger.info(
+            "proxy: tls_ready conn=%s host=%s serial=%s chain_len=%s",
+            conn_id,
+            hostname,
+            handshake_info["leaf_serial"],
+            handshake_info["chain_length"],
+        )
 
         tls_writer = asyncio.StreamWriter(new_transport, protocol, client_reader, loop)
 
@@ -243,6 +296,11 @@ class MITMProxy:
             )
         except Exception:
             logger.debug("Failed to connect to upstream %s:%d", hostname, port)
+            try:
+                tls_writer.write(b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n")
+                await tls_writer.drain()
+            except Exception:
+                pass
             return
 
         # Forward HTTP/1.1 traffic between decrypted client and upstream
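
Note: _tunnel_passthrough, taken for non-intercepted domains, is unchanged and not shown in this diff. A passthrough tunnel typically just pumps bytes both ways without terminating TLS; a generic asyncio sketch of that shape (not the package's actual implementation):

import asyncio

async def _pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    # Copy one direction until EOF, then close the write side.
    try:
        while chunk := await reader.read(65536):
            writer.write(chunk)
            await writer.drain()
    finally:
        writer.close()

async def tunnel(client_r: asyncio.StreamReader, client_w: asyncio.StreamWriter,
                 host: str, port: int) -> None:
    up_r, up_w = await asyncio.open_connection(host, port)
    # Run both directions concurrently; returns when both sides hit EOF.
    await asyncio.gather(_pump(client_r, up_w), _pump(up_r, client_w))
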
@@ -277,9 +335,11 @@ class MITMProxy:
         while True:
             req_line = await client_reader.readline()
             if not req_line:
+                logger.info("proxy: %s — connection closed (no request line)", hostname)
                 break
             req_line_str = req_line.decode("utf-8", errors="replace").strip()
             if not req_line_str:
+                logger.info("proxy: %s — empty request line", hostname)
                 break
 
             parts = req_line_str.split(" ", 2)
@@ -287,52 +347,99 @@ class MITMProxy:
             path = parts[1] if len(parts) > 1 else "/"
 
             req_body_size = 0
-            # For AI paths: buffer request and try rerouting through MTRX (live injection)
-            if method == "POST" and is_ai_path(path):
+            _is_ai_req = False
+            _req_session_id = ""
+            req_headers: dict[str, str]
+            req_cl: int
+            req_chunked: bool
+
+            if method == "POST":
                 req_headers, req_cl, req_chunked = await self._read_headers_only(
                     client_reader
                 )
-                req_body = await self._read_body_to_bytes(
-                    client_reader, req_cl, req_chunked
-                )
-                req_body_size = len(req_body)
-                result = await try_reroute_to_matrx(
-                    path=path,
-                    method=method,
-                    req_headers=req_headers,
-                    req_body=req_body,
-                    matrx_base_url=self.matrx_base_url,
-                    matrx_key=self.matrx_key,
-                    session_id=str(uuid.uuid4()),
+                ai_classification = classify_ai_request(method, path, req_headers)
+                _is_ai_req = ai_classification["candidate"]
+                _is_ai_reroutable = ai_classification["reroutable"]
+                _req_session_id = str(uuid.uuid4()) if _is_ai_req else ""
+                logger.info(
+                    "proxy: %s %s%s [ai=%s reroutable=%s ct=%s]",
+                    method,
+                    hostname,
+                    path,
+                    _is_ai_req,
+                    _is_ai_reroutable,
+                    req_headers.get("content-type", ""),
                 )
-                if result is not None:
-                    success, resp_headers, resp_body, is_streaming = result
-                    if success and resp_body is not None:
-                        self._request_count += 1
-                        self._write_http_response(
-                            client_writer, 200, resp_headers, resp_body
+                if _is_ai_req and not _is_ai_reroutable and "aiserver.v1." in path.lower():
+                    logger.info("proxy: candidate AI request not yet reroutable: %s%s", hostname, path)
+
+                # For AI paths: buffer request and try rerouting through MTRX (live injection)
+                if _is_ai_req:
+                    try:
+                        req_body = await self._read_body_to_bytes(
+                            client_reader, req_cl, req_chunked
                         )
-                        asyncio.create_task(
-                            self._ship_telemetry(
-                                hostname=hostname,
-                                method=method,
-                                path=path,
-                                status_code=200,
-                                req_body_size=len(req_body),
-                                resp_body_size=len(resp_body),
-                                elapsed_ms=0,
-                                content_type=resp_headers.get("content-type", ""),
-                                is_streaming=is_streaming,
+                    except ValueError:
+                        client_writer.write(b"HTTP/1.1 413 Content Too Large\r\nContent-Length: 0\r\n\r\n")
+                        await client_writer.drain()
+                        return
+                    req_body_size = len(req_body)
+                    result = await try_reroute_to_matrx(
+                        path=path,
+                        method=method,
+                        req_headers=req_headers,
+                        req_body=req_body,
+                        matrx_base_url=self.matrx_base_url,
+                        matrx_key=self.matrx_key,
+                        session_id=_req_session_id,
+                    )
+                    if result is not None:
+                        success, resp_headers, resp_body, is_streaming = result
+                        if success and resp_body is not None:
+                            self._request_count += 1
+                            self._write_http_response(
+                                client_writer, 200, resp_headers, resp_body
                             )
-                        )
-                        continue
-                    # Reroute returned but failed — fall through to forward
-                # Reroute not implemented or failed — forward to upstream
-                up_writer.write(req_line)
-                await self._write_headers(up_writer, req_headers)
-                up_writer.write(req_body)
-                await up_writer.drain()
+                            asyncio.create_task(
+                                self._ship_telemetry(
+                                    hostname=hostname,
+                                    method=method,
+                                    path=path,
+                                    status_code=200,
+                                    req_body_size=len(req_body),
+                                    resp_body_size=len(resp_body),
+                                    elapsed_ms=0,
+                                    content_type=resp_headers.get("content-type", ""),
+                                    is_streaming=is_streaming,
+                                )
+                            )
+                            continue
+                        # Reroute returned but failed — fall through to forward
+                    # Inject MTRX memory context into request before forwarding
+                    injected_body = await try_inject_context(
+                        req_body=req_body,
+                        req_headers=req_headers,
+                        matrx_base_url=self.matrx_base_url,
+                        matrx_key=self.matrx_key,
+                        session_id=_req_session_id,
+                    )
+                    body_to_forward = injected_body if injected_body is not None else req_body
+                    fwd_headers = dict(req_headers)
+                    fwd_headers["content-length"] = str(len(body_to_forward))
+                    up_writer.write(req_line)
+                    self._write_headers(up_writer, fwd_headers)
+                    up_writer.write(body_to_forward)
+                    await up_writer.drain()
+                else:
+                    up_writer.write(req_line)
+                    self._write_headers(up_writer, req_headers)
+                    req_body_size = await self._forward_body(
+                        client_reader, up_writer, req_cl, req_chunked
+                    )
+                    if req_body_size == 0 and req_cl > 0:
+                        req_body_size = req_cl
             else:
+                logger.info("proxy: %s %s%s [ai=%s]", method, hostname, path, False)
                 up_writer.write(req_line)
                 req_headers, req_cl, req_chunked = await self._forward_headers(
                     client_reader, up_writer
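
Note: try_inject_context's body is also outside this diff, but the call site above fixes its contract: return a rewritten request body, or None to forward the original unchanged. Because a rewrite changes the body length, the proxy rebuilds Content-Length before forwarding. A hypothetical sketch of that contract (the real function also receives matrx_base_url, matrx_key, and session_id, presumably to fetch the memory context from MTRX):

import json

async def inject_context_sketch(req_body: bytes, context: str) -> bytes | None:
    """Rewrite a chat-style JSON body, or return None to leave it untouched."""
    try:
        payload = json.loads(req_body)
    except ValueError:
        return None  # not JSON; caller forwards the original bytes
    messages = payload.get("messages")
    if not isinstance(messages, list):
        return None
    messages.insert(0, {"role": "system", "content": context})
    return json.dumps(payload).encode()  # caller must recompute Content-Length
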
@@ -369,9 +476,20 @@ class MITMProxy:
                 for t in ("text/event-stream", "grpc", "proto", "connect")
             )
 
-            resp_body_size = await self._forward_body(
-                up_reader, client_writer, resp_cl, resp_chunked
-            )
+            if _is_ai_req:
+                resp_body_size, resp_captured = await self._forward_body_with_capture(
+                    up_reader, client_writer, resp_cl, resp_chunked
+                )
+                if resp_captured:
+                    asyncio.create_task(
+                        self._extract_ai_response(
+                            resp_captured, _req_session_id, hostname
+                        )
+                    )
+            else:
+                resp_body_size = await self._forward_body(
+                    up_reader, client_writer, resp_cl, resp_chunked
+                )
 
             elapsed_ms = int((time.monotonic() - started) * 1000)
             self._request_count += 1
@@ -390,82 +508,158 @@ class MITMProxy:
                 )
             )
 
-            conn_h = (
-                req_headers.get("connection", "")
-                + resp_headers.get("connection", "")
-            ).lower()
+            conn_h = resp_headers.get("connection", "").lower()
             if "close" in conn_h:
                 break
 
-    async def _read_headers_only(
-        self, reader: asyncio.StreamReader
-    ) -> tuple[dict[str, str], int, bool]:
-        """Read headers without writing. Returns (headers_dict, content_length, is_chunked)."""
-        headers: dict[str, str] = {}
-        content_length = -1
-        chunked = False
-        while True:
-            line = await reader.readline()
-            decoded = line.decode("utf-8", errors="replace").strip()
-            if not decoded:
-                break
-            if ":" in decoded:
-                k, _, v = decoded.partition(":")
-                k_lower = k.strip().lower()
-                v_stripped = v.strip()
-                headers[k_lower] = v_stripped
-                if k_lower == "content-length":
-                    content_length = int(v_stripped)
-                elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
-                    chunked = True
-        return headers, content_length, chunked
-
-    async def _read_body_to_bytes(
+    async def _forward_body_with_capture(
         self,
         reader: asyncio.StreamReader,
+        writer: asyncio.StreamWriter,
         content_length: int,
         chunked: bool,
-    ) -> bytes:
-        """Read body into bytes (no writer)."""
+    ) -> tuple[int, bytes]:
+        """Forward body like ``_forward_body`` while also capturing a copy.
+
+        Returns ``(bytes_forwarded, captured_bytes)``. The capture enables
+        background response extraction without blocking the forward path.
+        """
+        parts: list[bytes] = []
+
         if content_length > 0:
-            return await reader.read(content_length)
+            total = 0
+            remaining = content_length
+            while remaining > 0:
+                chunk = await reader.read(min(remaining, 65536))
+                if not chunk:
+                    break
+                writer.write(chunk)
+                await writer.drain()
+                parts.append(chunk)
+                total += len(chunk)
+                remaining -= len(chunk)
+            return total, b"".join(parts)
+
         if chunked:
-            parts: list[bytes] = []
+            total = 0
             while True:
                 size_line = await reader.readline()
+                if not size_line:
+                    break
+                writer.write(size_line)
+                await writer.drain()
                 size_str = size_line.decode("utf-8", errors="replace").strip()
                 try:
                     chunk_size = int(size_str.split(";")[0], 16)
                 except ValueError:
                     break
                 if chunk_size == 0:
-                    await reader.readline()  # trailer
+                    trailer = await reader.readline()
+                    writer.write(trailer)
+                    await writer.drain()
                     break
-                parts.append(await reader.read(chunk_size))
-                await reader.readline()  # crlf
-            return b"".join(parts)
-        return b""
+                remaining = chunk_size
+                chunk_parts: list[bytes] = []
+                while remaining > 0:
+                    data = await reader.read(min(remaining, 65536))
+                    if not data:
+                        return total, b"".join(parts)
+                    writer.write(data)
+                    await writer.drain()
+                    chunk_parts.append(data)
+                    total += len(data)
+                    remaining -= len(data)
+                chunk_data = b"".join(chunk_parts)
+                parts.append(chunk_data)
+                crlf = await reader.readline()
+                writer.write(crlf)
+                await writer.drain()
+            return total, b"".join(parts)
 
-    def _write_headers(
-        self, writer: asyncio.StreamWriter, headers: dict[str, str]
-    ) -> None:
-        """Write headers as HTTP lines (caller must drain)."""
-        for k, v in headers.items():
-            writer.write(f"{k}: {v}\r\n".encode())
-        writer.write(b"\r\n")
+        # No content-length, no chunked encoding — stream until the upstream closes.
+        # This covers Cursor's SSE AI responses that use raw HTTP/1.1 keep-alive streaming.
+        # Cap capture at 512 KB to bound memory; bytes beyond that are still forwarded.
+        _CAPTURE_LIMIT = 512 * 1024
+        parts = []
+        total = 0
+        capturing = True
+        while True:
+            chunk = await reader.read(65536)
+            if not chunk:
+                break
+            writer.write(chunk)
+            await writer.drain()
+            total += len(chunk)
+            if capturing:
+                parts.append(chunk)
+                if total >= _CAPTURE_LIMIT:
+                    capturing = False
+        return total, b"".join(parts)
 
-    def _write_http_response(
+    async def _extract_ai_response(
         self,
-        writer: asyncio.StreamWriter,
-        status: int,
-        resp_headers: dict[str, str],
-        resp_body: bytes,
+        resp_bytes: bytes,
+        session_id: str,
+        hostname: str,
     ) -> None:
-        """Write a complete HTTP response."""
-        writer.write(f"HTTP/1.1 {status} OK\r\n".encode())
-        self._write_headers(writer, resp_headers)
-        writer.write(resp_body)
-        # Caller should drain
+        """Parse Connect frames from *resp_bytes* and ship response telemetry.
+
+        Tries compiled proto parsing first; falls back to raw wire-format parsing
+        so token counts are always extracted even without compiled proto files.
+        Fire-and-forget: never raises, never blocks the forward path.
+        """
+        try:
+            from matrx.cli.cursor_extraction import ship_ai_telemetry
+
+            import gzip as _gzip
+            body = resp_bytes
+            if len(body) >= 2 and body[:2] == b"\x1f\x8b":
+                try:
+                    body = _gzip.decompress(body)
+                except Exception:
+                    body = resp_bytes
+
+            accumulated: dict = {
+                "session_id": session_id,
+                "response_text": "",
+                "tool_calls": [],
+                "usage": None,
+            }
+
+            if hostname == "api.anthropic.com":
+                from matrx.cli.cursor_extraction import extract_from_anthropic_sse_response
+                frame_data = extract_from_anthropic_sse_response(body)
+                accumulated["response_text"] = frame_data.get("text", "")
+                accumulated["tool_calls"] = frame_data.get("tool_calls", [])
+                accumulated["usage"] = frame_data.get("usage")
+            elif hostname == "api.openai.com":
+                from matrx.cli.cursor_extraction import extract_from_openai_sse_response
+                frame_data = extract_from_openai_sse_response(body)
+                accumulated["response_text"] = frame_data.get("text", "")
+                accumulated["tool_calls"] = frame_data.get("tool_calls", [])
+                accumulated["usage"] = frame_data.get("usage")
+            else:
+                # Cursor backend: Connect/gRPC protobuf frames
+                from matrx.cli.cursor_connect import parse_all_frames
+                from matrx.cli.cursor_extraction import (
+                    _raw_extract_response_frame,
+                    extract_from_response_frame,
+                    parse_response_proto,
+                )
+                for flags, payload in parse_all_frames(body):
+                    if flags == 0x02:
+                        break
+                    resp_proto = parse_response_proto(payload)
+                    frame_data = extract_from_response_frame(resp_proto) if resp_proto is not None else _raw_extract_response_frame(payload)
+                    if frame_data:
+                        accumulated["response_text"] += frame_data.get("text", "")
+                        accumulated["tool_calls"].extend(frame_data.get("tool_calls", []))
+                        if frame_data.get("usage"):
+                            accumulated["usage"] = frame_data["usage"]
+
+            await ship_ai_telemetry(accumulated, self.matrx_base_url, self.matrx_key)
+        except Exception:
+            logger.debug("proxy: _extract_ai_response failed", exc_info=True)
 
     async def _forward_headers(
         self,
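
Note: the chunked branch of _forward_body_with_capture above tracks HTTP/1.1 chunked transfer encoding exactly: a hexadecimal size line (optionally carrying ;extensions), that many payload bytes, a trailing CRLF, repeated until a zero-size terminal chunk. A self-contained decoder of a canonical example, for reference:

import io

def decode_chunked(raw: bytes) -> bytes:
    """Decode an HTTP/1.1 chunked body (ignores trailer headers; illustrative)."""
    buf = io.BytesIO(raw)
    out = bytearray()
    while True:
        size_line = buf.readline()                      # e.g. b"4\r\n" or b"4;ext=1\r\n"
        chunk_size = int(size_line.split(b";")[0], 16)  # hex size; extensions ignored
        if chunk_size == 0:
            break                                       # terminal chunk
        out += buf.read(chunk_size)
        buf.readline()                                  # consume the chunk's CRLF
    return bytes(out)

assert decode_chunked(b"4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n") == b"Wikipedia"
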
@@ -491,7 +685,10 @@ class MITMProxy:
                 v_stripped = v.strip()
                 headers[k_lower] = v_stripped
                 if k_lower == "content-length":
-                    content_length = int(v_stripped)
+                    try:
+                        content_length = int(v_stripped)
+                    except ValueError:
+                        pass
                 elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
                     chunked = True
         await writer.drain()
@@ -583,7 +780,10 @@ class MITMProxy:
                 v_stripped = v.strip()
                 headers[k_lower] = v_stripped
                 if k_lower == "content-length":
-                    content_length = int(v_stripped)
+                    try:
+                        content_length = int(v_stripped)
+                    except ValueError:
+                        pass
                 elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
                     chunked = True
         return headers, content_length, chunked
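
Note: both header readers now guard the Content-Length parse, so a malformed value degrades to "no declared length" instead of raising ValueError out of the read loop and tearing down the connection. In miniature:

for value in ("1024", " 42 ", "abc", ""):
    content_length = -1  # the readers' default
    try:
        content_length = int(value.strip())
    except ValueError:
        pass  # keep the default, as the patched code does
    print(repr(value), "->", content_length)
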
@@ -596,6 +796,8 @@ class MITMProxy:
     ) -> bytes:
         """Read body into bytes."""
         if content_length > 0:
+            if content_length > _MAX_BODY_BYTES:
+                raise ValueError(f"Request body too large: {content_length} bytes")
             return await reader.readexactly(content_length)
         if chunked:
             parts: list[bytes] = []
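
Note: extract_from_anthropic_sse_response and extract_from_openai_sse_response, imported by _extract_ai_response above, are not included in this diff. Both providers stream responses as server-sent events, so a plausible reduction walks the data: lines and accumulates text deltas. A sketch against Anthropic's published streaming event shapes (the package's real parser may differ):

import json

def extract_text_from_sse(body: bytes) -> str:
    """Accumulate text deltas from an Anthropic-style SSE stream (sketch only)."""
    parts: list[str] = []
    for line in body.decode("utf-8", errors="replace").splitlines():
        if not line.startswith("data:"):
            continue  # skip event:/blank separator lines
        try:
            event = json.loads(line[5:].strip())
        except ValueError:
            continue
        if event.get("type") == "content_block_delta":
            delta = event.get("delta", {})
            if delta.get("type") == "text_delta":
                parts.append(delta.get("text", ""))
    return "".join(parts)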