mtrx-cli 0.1.25 → 0.1.26

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -33,9 +33,17 @@ import httpx
 from matrx.cli.cursor_ca import CertCache, load_ca
 
 try:
-    from matrx.cli.cursor_reroute import is_ai_path, try_inject_context, try_reroute_to_matrx
+    from matrx.cli.cursor_reroute import (
+        classify_ai_request,
+        is_ai_path,
+        try_inject_context,
+        try_reroute_to_matrx,
+    )
 except ImportError:
     # Stubs when cursor_reroute not available (e.g. npm package omit).
+    def classify_ai_request(method: str, path: str, headers: dict[str, str] | None = None) -> dict[str, bool]:
+        return {"candidate": False, "reroutable": False}
+
     def is_ai_path(path: str) -> bool:
         return False
 
@@ -47,6 +55,8 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
+_MAX_BODY_BYTES = 50 * 1024 * 1024  # 50 MB hard limit for buffered request bodies
+
 DEFAULT_PORT = 8842
 PROXY_HOST = "127.0.0.1"
 HEALTH_PATH = "/__mtrx_health__"
@@ -58,8 +68,17 @@ _INTERCEPT_DOMAINS = {
     "api4.cursor.sh",
     "api5.cursor.sh",
     "agentn.global.api5.cursor.sh",
+    "api.anthropic.com",
+    "api.openai.com",
 }
 
+_PREWARM_DOMAINS = (
+    "api2.cursor.sh",
+    "api3.cursor.sh",
+    "api4.cursor.sh",
+    "api5.cursor.sh",
+)
+
 
 class MITMProxy:
     """Async MITM forward proxy with telemetry mirroring."""
@@ -80,10 +99,12 @@ class MITMProxy:
         self._telemetry_client: httpx.AsyncClient | None = None
         self._cert_cache: CertCache | None = None
         self._request_count = 0
+        self._connect_count = 0
 
     async def start(self) -> None:
         ca_key, ca_cert = load_ca()
         self._cert_cache = CertCache(ca_key, ca_cert)
+        self._cert_cache.prewarm(_PREWARM_DOMAINS)
         self._telemetry_client = httpx.AsyncClient(timeout=10)
         self._server = await asyncio.start_server(
             self._handle_client, self.host, self.port
@@ -122,7 +143,7 @@ class MITMProxy:
         except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
             pass
         except Exception:
-            logger.debug("proxy: connection error", exc_info=True)
+            logger.warning("proxy: connection error", exc_info=True)
         finally:
             try:
                 writer.close()
@@ -166,8 +187,10 @@ class MITMProxy:
             await writer.drain()
 
             if hostname in _INTERCEPT_DOMAINS:
+                logger.info("proxy: CONNECT %s:%d [intercept]", hostname, port)
                 await self._mitm_intercept(reader, writer, hostname, port)
             else:
+                logger.info("proxy: CONNECT %s:%d [tunnel]", hostname, port)
                 await self._tunnel_passthrough(reader, writer, hostname, port)
         elif method in ("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"):
             # Plain HTTP proxy request (non-CONNECT) -- handle health check
@@ -218,10 +241,22 @@ class MITMProxy:
         port: int,
     ) -> None:
         assert self._cert_cache is not None
+        self._connect_count += 1
+        conn_id = f"{hostname}:{self._connect_count}"
 
         # Use the hostname from the CONNECT request for the cert
         # (matches SNI in virtually all cases, avoids ClientHello peeking)
+        handshake_info = self._cert_cache.get_handshake_info(hostname)
         server_ctx = self._cert_cache.get_ssl_context(hostname)
+        logger.info(
+            "proxy: tls_prepare conn=%s host=%s serial=%s leaf_sha256=%s chain_len=%s cert=%s",
+            conn_id,
+            hostname,
+            handshake_info["leaf_serial"],
+            handshake_info["leaf_sha256"],
+            handshake_info["chain_length"],
+            handshake_info["cert_path"],
+        )
 
         # Upgrade client connection to TLS (we are the "server")
         loop = asyncio.get_running_loop()
@@ -232,8 +267,23 @@ class MITMProxy:
                 transport, protocol, server_ctx, server_side=True
             )
         except (ssl.SSLError, ConnectionError) as exc:
-            logger.debug("TLS handshake with client failed for %s: %s", hostname, exc)
+            logger.warning(
+                "TLS handshake with client failed for %s [conn=%s serial=%s leaf_sha256=%s chain_len=%s]: %s",
+                hostname,
+                conn_id,
+                handshake_info["leaf_serial"],
+                handshake_info["leaf_sha256"],
+                handshake_info["chain_length"],
+                exc,
+            )
             return
+        logger.info(
+            "proxy: tls_ready conn=%s host=%s serial=%s chain_len=%s",
+            conn_id,
+            hostname,
+            handshake_info["leaf_serial"],
+            handshake_info["chain_length"],
+        )
 
         tls_writer = asyncio.StreamWriter(new_transport, protocol, client_reader, loop)
 
@@ -246,6 +296,11 @@ class MITMProxy:
             )
         except Exception:
             logger.debug("Failed to connect to upstream %s:%d", hostname, port)
+            try:
+                tls_writer.write(b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n")
+                await tls_writer.drain()
+            except Exception:
+                pass
             return
 
         # Forward HTTP/1.1 traffic between decrypted client and upstream
@@ -280,9 +335,11 @@ class MITMProxy:
         while True:
             req_line = await client_reader.readline()
             if not req_line:
+                logger.info("proxy: %s — connection closed (no request line)", hostname)
                 break
             req_line_str = req_line.decode("utf-8", errors="replace").strip()
             if not req_line_str:
+                logger.info("proxy: %s — empty request line", hostname)
                 break
 
             parts = req_line_str.split(" ", 2)
@@ -290,64 +347,99 @@ class MITMProxy:
             path = parts[1] if len(parts) > 1 else "/"
 
             req_body_size = 0
-            _is_ai_req = method == "POST" and is_ai_path(path)
-            _req_session_id = str(uuid.uuid4()) if _is_ai_req else ""
-            # For AI paths: buffer request and try rerouting through MTRX (live injection)
-            if _is_ai_req:
+            _is_ai_req = False
+            _req_session_id = ""
+            req_headers: dict[str, str]
+            req_cl: int
+            req_chunked: bool
+
+            if method == "POST":
                 req_headers, req_cl, req_chunked = await self._read_headers_only(
                     client_reader
                 )
-                req_body = await self._read_body_to_bytes(
-                    client_reader, req_cl, req_chunked
+                ai_classification = classify_ai_request(method, path, req_headers)
+                _is_ai_req = ai_classification["candidate"]
+                _is_ai_reroutable = ai_classification["reroutable"]
+                _req_session_id = str(uuid.uuid4()) if _is_ai_req else ""
+                logger.info(
+                    "proxy: %s %s%s [ai=%s reroutable=%s ct=%s]",
+                    method,
+                    hostname,
+                    path,
+                    _is_ai_req,
+                    _is_ai_reroutable,
+                    req_headers.get("content-type", ""),
                 )
-                req_body_size = len(req_body)
-                result = await try_reroute_to_matrx(
-                    path=path,
-                    method=method,
-                    req_headers=req_headers,
-                    req_body=req_body,
-                    matrx_base_url=self.matrx_base_url,
-                    matrx_key=self.matrx_key,
-                    session_id=_req_session_id,
-                )
-                if result is not None:
-                    success, resp_headers, resp_body, is_streaming = result
-                    if success and resp_body is not None:
-                        self._request_count += 1
-                        self._write_http_response(
-                            client_writer, 200, resp_headers, resp_body
+                if _is_ai_req and not _is_ai_reroutable and "aiserver.v1." in path.lower():
+                    logger.info("proxy: candidate AI request not yet reroutable: %s%s", hostname, path)
+
+                # For AI paths: buffer request and try rerouting through MTRX (live injection)
+                if _is_ai_req:
+                    try:
+                        req_body = await self._read_body_to_bytes(
+                            client_reader, req_cl, req_chunked
                         )
-                        asyncio.create_task(
-                            self._ship_telemetry(
-                                hostname=hostname,
-                                method=method,
-                                path=path,
-                                status_code=200,
-                                req_body_size=len(req_body),
-                                resp_body_size=len(resp_body),
-                                elapsed_ms=0,
-                                content_type=resp_headers.get("content-type", ""),
-                                is_streaming=is_streaming,
+                    except ValueError:
+                        client_writer.write(b"HTTP/1.1 413 Content Too Large\r\nContent-Length: 0\r\n\r\n")
+                        await client_writer.drain()
+                        return
+                    req_body_size = len(req_body)
+                    result = await try_reroute_to_matrx(
+                        path=path,
+                        method=method,
+                        req_headers=req_headers,
+                        req_body=req_body,
+                        matrx_base_url=self.matrx_base_url,
+                        matrx_key=self.matrx_key,
+                        session_id=_req_session_id,
+                    )
+                    if result is not None:
+                        success, resp_headers, resp_body, is_streaming = result
+                        if success and resp_body is not None:
+                            self._request_count += 1
+                            self._write_http_response(
+                                client_writer, 200, resp_headers, resp_body
                             )
-                        )
-                        continue
-                # Reroute returned but failed — fall through to forward
-                # Inject MTRX memory context into request before forwarding
-                injected_body = await try_inject_context(
-                    req_body=req_body,
-                    req_headers=req_headers,
-                    matrx_base_url=self.matrx_base_url,
-                    matrx_key=self.matrx_key,
-                    session_id=_req_session_id,
-                )
-                body_to_forward = injected_body if injected_body is not None else req_body
-                fwd_headers = dict(req_headers)
-                fwd_headers["content-length"] = str(len(body_to_forward))
-                up_writer.write(req_line)
-                await self._write_headers(up_writer, fwd_headers)
-                up_writer.write(body_to_forward)
-                await up_writer.drain()
+                            asyncio.create_task(
+                                self._ship_telemetry(
+                                    hostname=hostname,
+                                    method=method,
+                                    path=path,
+                                    status_code=200,
+                                    req_body_size=len(req_body),
+                                    resp_body_size=len(resp_body),
+                                    elapsed_ms=0,
+                                    content_type=resp_headers.get("content-type", ""),
+                                    is_streaming=is_streaming,
+                                )
+                            )
+                            continue
+                    # Reroute returned but failed — fall through to forward
+                    # Inject MTRX memory context into request before forwarding
+                    injected_body = await try_inject_context(
+                        req_body=req_body,
+                        req_headers=req_headers,
+                        matrx_base_url=self.matrx_base_url,
+                        matrx_key=self.matrx_key,
+                        session_id=_req_session_id,
+                    )
+                    body_to_forward = injected_body if injected_body is not None else req_body
+                    fwd_headers = dict(req_headers)
+                    fwd_headers["content-length"] = str(len(body_to_forward))
+                    up_writer.write(req_line)
+                    self._write_headers(up_writer, fwd_headers)
+                    up_writer.write(body_to_forward)
+                    await up_writer.drain()
+                else:
+                    up_writer.write(req_line)
+                    self._write_headers(up_writer, req_headers)
+                    req_body_size = await self._forward_body(
+                        client_reader, up_writer, req_cl, req_chunked
+                    )
+                    if req_body_size == 0 and req_cl > 0:
+                        req_body_size = req_cl
             else:
+                logger.info("proxy: %s %s%s [ai=%s]", method, hostname, path, False)
                 up_writer.write(req_line)
                 req_headers, req_cl, req_chunked = await self._forward_headers(
                     client_reader, up_writer
@@ -416,10 +508,7 @@ class MITMProxy:
                 )
             )
 
-            conn_h = (
-                req_headers.get("connection", "")
-                + resp_headers.get("connection", "")
-            ).lower()
+            conn_h = resp_headers.get("connection", "").lower()
             if "close" in conn_h:
                 break
 
@@ -487,7 +576,25 @@ class MITMProxy:
             await writer.drain()
             return total, b"".join(parts)
 
-        return 0, b""
+        # No content-length, no chunked encoding — stream until the upstream closes.
+        # This covers Cursor's SSE AI responses that use raw HTTP/1.1 keep-alive streaming.
+        # Cap capture at 512 KB to bound memory; bytes beyond that are still forwarded.
+        _CAPTURE_LIMIT = 512 * 1024
+        parts = []
+        total = 0
+        capturing = True
+        while True:
+            chunk = await reader.read(65536)
+            if not chunk:
+                break
+            writer.write(chunk)
+            await writer.drain()
+            total += len(chunk)
+            if capturing:
+                parts.append(chunk)
+                if total >= _CAPTURE_LIMIT:
+                    capturing = False
+        return total, b"".join(parts)
 
     async def _extract_ai_response(
         self,
@@ -497,110 +604,63 @@ class MITMProxy:
     ) -> None:
         """Parse Connect frames from *resp_bytes* and ship response telemetry.
 
+        Tries compiled proto parsing first; falls back to raw wire-format parsing
+        so token counts are always extracted even without compiled proto files.
         Fire-and-forget — never raises, never blocks the forward path.
         """
         try:
-            from matrx.cli.cursor_connect import parse_all_frames
-            from matrx.cli.cursor_extraction import (
-                extract_from_response_frame,
-                parse_response_proto,
-                ship_ai_telemetry,
-            )
+            from matrx.cli.cursor_extraction import ship_ai_telemetry
+
+            import gzip as _gzip
+            body = resp_bytes
+            if len(body) >= 2 and body[:2] == b"\x1f\x8b":
+                try:
+                    body = _gzip.decompress(body)
+                except Exception:
+                    body = resp_bytes
 
-            frames = parse_all_frames(resp_bytes)
             accumulated: dict = {
                 "session_id": session_id,
                 "response_text": "",
                 "tool_calls": [],
                 "usage": None,
             }
-            for flags, payload in frames:
-                if flags == 0x02:  # end-of-stream trailer — stop
-                    break
-                resp_proto = parse_response_proto(payload)
-                frame_data = extract_from_response_frame(resp_proto)
-                if frame_data:
-                    accumulated["response_text"] = (
-                        accumulated.get("response_text", "") + frame_data.get("text", "")
-                    )
-                    accumulated["tool_calls"].extend(frame_data.get("tool_calls", []))
-                    if frame_data.get("usage"):
-                        accumulated["usage"] = frame_data["usage"]
+
+            if hostname == "api.anthropic.com":
+                from matrx.cli.cursor_extraction import extract_from_anthropic_sse_response
+                frame_data = extract_from_anthropic_sse_response(body)
+                accumulated["response_text"] = frame_data.get("text", "")
+                accumulated["tool_calls"] = frame_data.get("tool_calls", [])
+                accumulated["usage"] = frame_data.get("usage")
+            elif hostname == "api.openai.com":
+                from matrx.cli.cursor_extraction import extract_from_openai_sse_response
+                frame_data = extract_from_openai_sse_response(body)
+                accumulated["response_text"] = frame_data.get("text", "")
+                accumulated["tool_calls"] = frame_data.get("tool_calls", [])
+                accumulated["usage"] = frame_data.get("usage")
+            else:
+                # Cursor backend: Connect/gRPC protobuf frames
+                from matrx.cli.cursor_connect import parse_all_frames
+                from matrx.cli.cursor_extraction import (
+                    _raw_extract_response_frame,
+                    extract_from_response_frame,
+                    parse_response_proto,
+                )
+                for flags, payload in parse_all_frames(body):
+                    if flags == 0x02:
+                        break
+                    resp_proto = parse_response_proto(payload)
+                    frame_data = extract_from_response_frame(resp_proto) if resp_proto is not None else _raw_extract_response_frame(payload)
+                    if frame_data:
+                        accumulated["response_text"] += frame_data.get("text", "")
+                        accumulated["tool_calls"].extend(frame_data.get("tool_calls", []))
+                        if frame_data.get("usage"):
+                            accumulated["usage"] = frame_data["usage"]
 
             await ship_ai_telemetry(accumulated, self.matrx_base_url, self.matrx_key)
         except Exception:
             logger.debug("proxy: _extract_ai_response failed", exc_info=True)
 
-    async def _read_headers_only(
-        self, reader: asyncio.StreamReader
-    ) -> tuple[dict[str, str], int, bool]:
-        """Read headers without writing. Returns (headers_dict, content_length, is_chunked)."""
-        headers: dict[str, str] = {}
-        content_length = -1
-        chunked = False
-        while True:
-            line = await reader.readline()
-            decoded = line.decode("utf-8", errors="replace").strip()
-            if not decoded:
-                break
-            if ":" in decoded:
-                k, _, v = decoded.partition(":")
-                k_lower = k.strip().lower()
-                v_stripped = v.strip()
-                headers[k_lower] = v_stripped
-                if k_lower == "content-length":
-                    content_length = int(v_stripped)
-                elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
-                    chunked = True
-        return headers, content_length, chunked
-
-    async def _read_body_to_bytes(
-        self,
-        reader: asyncio.StreamReader,
-        content_length: int,
-        chunked: bool,
-    ) -> bytes:
-        """Read body into bytes (no writer)."""
-        if content_length > 0:
-            return await reader.read(content_length)
-        if chunked:
-            parts: list[bytes] = []
-            while True:
-                size_line = await reader.readline()
-                size_str = size_line.decode("utf-8", errors="replace").strip()
-                try:
-                    chunk_size = int(size_str.split(";")[0], 16)
-                except ValueError:
-                    break
-                if chunk_size == 0:
-                    await reader.readline()  # trailer
-                    break
-                parts.append(await reader.read(chunk_size))
-                await reader.readline()  # crlf
-            return b"".join(parts)
-        return b""
-
-    def _write_headers(
-        self, writer: asyncio.StreamWriter, headers: dict[str, str]
-    ) -> None:
-        """Write headers as HTTP lines (caller must drain)."""
-        for k, v in headers.items():
-            writer.write(f"{k}: {v}\r\n".encode())
-        writer.write(b"\r\n")
-
-    def _write_http_response(
-        self,
-        writer: asyncio.StreamWriter,
-        status: int,
-        resp_headers: dict[str, str],
-        resp_body: bytes,
-    ) -> None:
-        """Write a complete HTTP response."""
-        writer.write(f"HTTP/1.1 {status} OK\r\n".encode())
-        self._write_headers(writer, resp_headers)
-        writer.write(resp_body)
-        # Caller should drain
-
     async def _forward_headers(
         self,
         reader: asyncio.StreamReader,
@@ -625,7 +685,10 @@ class MITMProxy:
                 v_stripped = v.strip()
                 headers[k_lower] = v_stripped
                 if k_lower == "content-length":
-                    content_length = int(v_stripped)
+                    try:
+                        content_length = int(v_stripped)
+                    except ValueError:
+                        pass
                 elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
                     chunked = True
         await writer.drain()
@@ -717,7 +780,10 @@ class MITMProxy:
                 v_stripped = v.strip()
                 headers[k_lower] = v_stripped
                 if k_lower == "content-length":
-                    content_length = int(v_stripped)
+                    try:
+                        content_length = int(v_stripped)
+                    except ValueError:
+                        pass
                 elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
                     chunked = True
         return headers, content_length, chunked
@@ -730,6 +796,8 @@ class MITMProxy:
     ) -> bytes:
         """Read body into bytes."""
        if content_length > 0:
+            if content_length > _MAX_BODY_BYTES:
+                raise ValueError(f"Request body too large: {content_length} bytes")
             return await reader.readexactly(content_length)
         if chunked:
             parts: list[bytes] = []
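
Note on the oversize-body guard added in this release: it is split across three hunks above, the module-level _MAX_BODY_BYTES constant, the ValueError raised from _read_body_to_bytes, and the 413 Content Too Large reply written in the request loop. A minimal sketch of how those pieces compose, assuming a bare asyncio StreamReader/StreamWriter pair; the read_body and answer_post names are illustrative only and are not part of the mtrx-cli API:

    import asyncio

    _MAX_BODY_BYTES = 50 * 1024 * 1024  # same 50 MB cap used in the diff

    async def read_body(reader: asyncio.StreamReader, content_length: int) -> bytes:
        # Refuse to buffer a body larger than the hard limit.
        if content_length > _MAX_BODY_BYTES:
            raise ValueError(f"Request body too large: {content_length} bytes")
        return await reader.readexactly(content_length)

    async def answer_post(
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        content_length: int,
    ) -> bytes | None:
        try:
            return await read_body(reader, content_length)
        except ValueError:
            # Mirror the proxy's behavior: reply 413 and stop handling the request.
            writer.write(b"HTTP/1.1 413 Content Too Large\r\nContent-Length: 0\r\n\r\n")
            await writer.drain()
            return None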