mtrx-cli 0.1.10 → 0.1.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,119 +1,115 @@
1
1
  """
2
- Local HTTP proxy server for routing Cursor IDE requests through MTRX.
2
+ Local MITM forward proxy for transparent Cursor IDE traffic interception.
3
3
 
4
- Sits between Cursor and the MTRX API, injecting X-Matrx-* headers on every
5
- request so that Cursor gets the same full-featured proxy pipeline as
6
- ``mtrx claude`` and ``mtrx codex`` (injection, compression, memory, groups,
7
- observability).
4
+ Sits between Cursor and its backend servers (api2.cursor.sh, api3.cursor.sh,
5
+ etc.), forwarding all traffic unchanged while mirroring metadata to the MTRX
6
+ telemetry API for observability.
8
7
 
9
- Cursor's "Override OpenAI Base URL" sends ALL model requests (OpenAI,
10
- Anthropic, Google, etc.) in OpenAI Chat Completions format to the configured
11
- URL. The MTRX router auto-detects the provider from the model name.
8
+ Cursor is configured to route through this proxy via its ``http.proxy``
9
+ setting. TLS is terminated and re-encrypted using dynamically-signed
10
+ certificates from the MTRX CA (see ``cursor_ca.py``).
11
+
12
+ Design choices (informed by cursor-tap):
13
+ - Force HTTP/1.1 on both client and upstream sides to avoid HTTP/2
14
+ multiplexing complexity.
15
+ - Zero-buffer bidirectional streaming: never hold SSE/gRPC frames.
16
+ - Async telemetry shipping: never block the forwarding path.
12
17
  """
13
18
 
14
19
  from __future__ import annotations
15
20
 
16
21
  import asyncio
17
- import base64
18
- import json
19
22
  import logging
20
- import platform
21
23
  import os
22
- import shutil
23
24
  import signal
24
- import socket
25
- import threading
25
+ import ssl
26
+ import time
26
27
  import uuid
28
+ from pathlib import Path
27
29
  from typing import Any
28
30
 
29
31
  import httpx
30
32
 
31
- logger = logging.getLogger(__name__)
33
+ from matrx.cli.cursor_ca import CertCache, load_ca
32
34
 
33
- _PROXY_API_KEY = "mtrx-cursor-proxy"
34
- _FORWARDED_METHODS = {"POST", "GET", "PUT", "PATCH", "DELETE", "OPTIONS"}
35
+ logger = logging.getLogger(__name__)
35
36
 
37
+ DEFAULT_PORT = 8842
38
+ PROXY_HOST = "127.0.0.1"
39
+ HEALTH_PATH = "/__mtrx_health__"
36
40
 
37
- def _capture_env_snapshot() -> dict:
38
- out: dict = {
39
- "os": platform.system(),
40
- "shell": os.environ.get("SHELL", os.environ.get("COMSPEC", "")),
41
- "cwd": os.getcwd(),
42
- }
43
- out["venv"] = os.environ.get("VIRTUAL_ENV", "") or None
44
- node = shutil.which("node")
45
- if node:
46
- import subprocess as _sp
47
- try:
48
- r = _sp.run([node, "-v"], capture_output=True, text=True, timeout=2)
49
- out["node"] = r.stdout.strip() if r.returncode == 0 else None
50
- except Exception:
51
- out["node"] = None
52
- else:
53
- out["node"] = None
54
- return out
41
+ # Domains whose TLS we intercept for observability.
42
+ _INTERCEPT_DOMAINS = {
43
+ "api2.cursor.sh",
44
+ "api3.cursor.sh",
45
+ "api4.cursor.sh",
46
+ "api5.cursor.sh",
47
+ "agentn.global.api5.cursor.sh",
48
+ }
55
49
 
56
50
 
57
- class CursorProxyServer:
58
- """Async HTTP proxy that injects MTRX headers and forwards to the MTRX API."""
51
+ class MITMProxy:
52
+ """Async MITM forward proxy with telemetry mirroring."""
59
53
 
60
54
  def __init__(
61
55
  self,
62
56
  *,
63
57
  matrx_key: str,
64
58
  matrx_base_url: str,
65
- session_id: str | None = None,
66
- group_id: str = "",
67
- project_id: str = "",
68
- host: str = "127.0.0.1",
69
- port: int = 0,
59
+ host: str = PROXY_HOST,
60
+ port: int = DEFAULT_PORT,
70
61
  ):
71
62
  self.matrx_key = matrx_key
72
63
  self.matrx_base_url = matrx_base_url.rstrip("/")
73
- self.session_id = session_id or str(uuid.uuid4())
74
- self.group_id = group_id
75
- self.project_id = project_id
76
64
  self.host = host
77
65
  self.port = port
78
-
79
- env_snap = _capture_env_snapshot()
80
- self._env_b64 = (
81
- base64.b64encode(json.dumps(env_snap).encode()).decode()
82
- if env_snap
83
- else ""
84
- )
85
66
  self._server: asyncio.Server | None = None
86
- self._http_client: httpx.AsyncClient | None = None
87
- self._loop: asyncio.AbstractEventLoop | None = None
67
+ self._telemetry_client: httpx.AsyncClient | None = None
68
+ self._cert_cache: CertCache | None = None
69
+ self._request_count = 0
70
+
71
+ async def start(self) -> None:
72
+ ca_key, ca_cert = load_ca()
73
+ self._cert_cache = CertCache(ca_key, ca_cert)
74
+ self._telemetry_client = httpx.AsyncClient(timeout=10)
75
+ self._server = await asyncio.start_server(
76
+ self._handle_client, self.host, self.port
77
+ )
78
+ logger.info("MITM proxy listening on %s:%d", self.host, self.port)
79
+
80
+ async def serve_forever(self) -> None:
81
+ if self._server is None:
82
+ await self.start()
83
+ assert self._server is not None
84
+ async with self._server:
85
+ await self._server.serve_forever()
86
+
87
+ async def stop(self) -> None:
88
+ if self._server:
89
+ self._server.close()
90
+ await self._server.wait_closed()
91
+ if self._telemetry_client:
92
+ await self._telemetry_client.aclose()
88
93
 
89
94
  @property
90
- def url(self) -> str:
91
- return f"http://{self.host}:{self.port}/v1"
92
-
93
- def _build_matrx_headers(self) -> dict[str, str]:
94
- headers: dict[str, str] = {
95
- "X-Matrx-Key": self.matrx_key,
96
- "X-Matrx-Agent-Id": "cursor",
97
- "X-Matrx-Provider": "cursor",
98
- "X-Matrx-Session-Id": self.session_id,
99
- }
100
- if self.group_id:
101
- headers["X-Matrx-Group"] = self.group_id
102
- if self.project_id:
103
- headers["X-Matrx-Project-Id"] = self.project_id
104
- if self._env_b64:
105
- headers["X-Matrx-Env"] = self._env_b64
106
- return headers
107
-
108
- async def _handle_request(
95
+ def request_count(self) -> int:
96
+ return self._request_count
97
+
98
+ # -----------------------------------------------------------------
99
+ # Connection handling
100
+ # -----------------------------------------------------------------
101
+
102
+ async def _handle_client(
109
103
  self,
110
104
  reader: asyncio.StreamReader,
111
105
  writer: asyncio.StreamWriter,
112
106
  ) -> None:
113
107
  try:
114
- await self._process_http(reader, writer)
108
+ await self._process(reader, writer)
109
+ except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
110
+ pass
115
111
  except Exception:
116
- logger.debug("cursor_proxy: connection error", exc_info=True)
112
+ logger.debug("proxy: connection error", exc_info=True)
117
113
  finally:
118
114
  try:
119
115
  writer.close()
@@ -121,7 +117,7 @@ class CursorProxyServer:
121
117
  except Exception:
122
118
  pass
123
119
 
124
- async def _process_http(
120
+ async def _process(
125
121
  self,
126
122
  reader: asyncio.StreamReader,
127
123
  writer: asyncio.StreamWriter,
@@ -129,223 +125,425 @@ class CursorProxyServer:
129
125
  request_line = await reader.readline()
130
126
  if not request_line:
131
127
  return
132
- parts = request_line.decode("utf-8", errors="replace").strip().split(" ", 2)
128
+ parts = request_line.decode("utf-8", errors="replace").strip().split()
133
129
  if len(parts) < 3:
134
130
  return
135
- method, raw_path, _ = parts
131
+ method = parts[0].upper()
136
132
 
133
+ # Consume remaining request headers
137
134
  headers_raw: dict[str, str] = {}
138
135
  while True:
139
136
  line = await reader.readline()
140
- decoded = line.decode("utf-8", errors="replace").strip()
141
- if not decoded:
137
+ if not line or line in (b"\r\n", b"\n"):
142
138
  break
139
+ decoded = line.decode("utf-8", errors="replace").strip()
143
140
  if ":" in decoded:
144
- key, _, value = decoded.partition(":")
145
- headers_raw[key.strip().lower()] = value.strip()
141
+ k, _, v = decoded.partition(":")
142
+ headers_raw[k.strip().lower()] = v.strip()
143
+
144
+ if method == "CONNECT":
145
+ target = parts[1]
146
+ hostname, _, port_str = target.rpartition(":")
147
+ port = int(port_str) if port_str else 443
148
+ if not hostname:
149
+ hostname = target
150
+ port = 443
151
+
152
+ writer.write(b"HTTP/1.1 200 Connection Established\r\n\r\n")
153
+ await writer.drain()
154
+
155
+ if hostname in _INTERCEPT_DOMAINS:
156
+ await self._mitm_intercept(reader, writer, hostname, port)
157
+ else:
158
+ await self._tunnel_passthrough(reader, writer, hostname, port)
159
+ elif method in ("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"):
160
+ # Plain HTTP proxy request (non-CONNECT) -- handle health check
161
+ path = parts[1] if len(parts) > 1 else "/"
162
+ if HEALTH_PATH in path:
163
+ body = f'{{"status":"ok","requests":{self._request_count}}}'
164
+ resp = (
165
+ f"HTTP/1.1 200 OK\r\n"
166
+ f"Content-Type: application/json\r\n"
167
+ f"Content-Length: {len(body)}\r\n"
168
+ f"Connection: close\r\n\r\n{body}"
169
+ )
170
+ writer.write(resp.encode())
171
+ await writer.drain()
172
+ else:
173
+ writer.write(b"HTTP/1.1 405 Method Not Allowed\r\n\r\n")
174
+ await writer.drain()
175
+ else:
176
+ writer.write(b"HTTP/1.1 405 Method Not Allowed\r\n\r\n")
177
+ await writer.drain()
146
178
 
147
- content_length = int(headers_raw.get("content-length", "0"))
148
- body = b""
149
- if content_length > 0:
150
- body = await reader.readexactly(content_length)
179
+ # -----------------------------------------------------------------
180
+ # Opaque tunnel (non-intercepted domains)
181
+ # -----------------------------------------------------------------
151
182
 
152
- if method not in _FORWARDED_METHODS:
153
- self._send_response(writer, 405, {"error": "Method not allowed"})
183
+ async def _tunnel_passthrough(
184
+ self,
185
+ client_reader: asyncio.StreamReader,
186
+ client_writer: asyncio.StreamWriter,
187
+ hostname: str,
188
+ port: int,
189
+ ) -> None:
190
+ try:
191
+ up_reader, up_writer = await asyncio.open_connection(hostname, port)
192
+ except Exception:
154
193
  return
194
+ await self._pipe_bidirectional(client_reader, client_writer, up_reader, up_writer)
155
195
 
156
- upstream_path = raw_path
157
- if upstream_path.startswith("/v1/"):
158
- upstream_path = "/" + upstream_path[4:]
159
- elif upstream_path == "/v1":
160
- upstream_path = "/"
196
+ # -----------------------------------------------------------------
197
+ # MITM interception (Cursor AI endpoints)
198
+ # -----------------------------------------------------------------
161
199
 
162
- upstream_url = f"{self.matrx_base_url}/v1{upstream_path}"
200
+ async def _mitm_intercept(
201
+ self,
202
+ client_reader: asyncio.StreamReader,
203
+ client_writer: asyncio.StreamWriter,
204
+ hostname: str,
205
+ port: int,
206
+ ) -> None:
207
+ assert self._cert_cache is not None
163
208
 
164
- upstream_headers: dict[str, str] = {}
165
- for key, value in headers_raw.items():
166
- if key in {"host", "connection", "transfer-encoding"}:
167
- continue
168
- if key == "authorization":
169
- continue
170
- upstream_headers[key] = value
209
+ # Use the hostname from the CONNECT request for the cert
210
+ # (matches SNI in virtually all cases, avoids ClientHello peeking)
211
+ server_ctx = self._cert_cache.get_ssl_context(hostname)
171
212
 
172
- upstream_headers.update(self._build_matrx_headers())
173
- upstream_headers["Authorization"] = f"Bearer {self.matrx_key}"
213
+ # Upgrade client connection to TLS (we are the "server")
214
+ loop = asyncio.get_running_loop()
215
+ transport = client_writer.transport
216
+ protocol = transport.get_protocol()
217
+ try:
218
+ new_transport = await loop.start_tls(
219
+ transport, protocol, server_ctx, server_side=True
220
+ )
221
+ except (ssl.SSLError, ConnectionError) as exc:
222
+ logger.debug("TLS handshake with client failed for %s: %s", hostname, exc)
223
+ return
224
+
225
+ tls_writer = asyncio.StreamWriter(new_transport, protocol, client_reader, loop)
226
+
227
+ # Connect to real upstream with TLS
228
+ upstream_ctx = ssl.create_default_context()
229
+ upstream_ctx.set_alpn_protocols(["http/1.1"])
230
+ try:
231
+ up_reader, up_writer = await asyncio.open_connection(
232
+ hostname, port, ssl=upstream_ctx
233
+ )
234
+ except Exception:
235
+ logger.debug("Failed to connect to upstream %s:%d", hostname, port)
236
+ return
174
237
 
175
- is_stream = False
176
- if body and method == "POST":
238
+ # Forward HTTP/1.1 traffic between decrypted client and upstream
239
+ try:
240
+ await self._forward_http(
241
+ client_reader, tls_writer, up_reader, up_writer, hostname
242
+ )
243
+ finally:
177
244
  try:
178
- parsed = json.loads(body)
179
- is_stream = parsed.get("stream", False)
180
- except (json.JSONDecodeError, AttributeError):
245
+ up_writer.close()
246
+ await up_writer.wait_closed()
247
+ except Exception:
181
248
  pass
182
249
 
183
- assert self._http_client is not None
250
+ # -----------------------------------------------------------------
251
+ # HTTP/1.1 forwarding with telemetry
252
+ # -----------------------------------------------------------------
253
+
254
+ async def _forward_http(
255
+ self,
256
+ client_reader: asyncio.StreamReader,
257
+ client_writer: asyncio.StreamWriter,
258
+ up_reader: asyncio.StreamReader,
259
+ up_writer: asyncio.StreamWriter,
260
+ hostname: str,
261
+ ) -> None:
262
+ """Forward HTTP/1.1 request-response pairs, logging each to telemetry."""
263
+ while True:
264
+ req_line = await client_reader.readline()
265
+ if not req_line:
266
+ break
267
+ req_line_str = req_line.decode("utf-8", errors="replace").strip()
268
+ if not req_line_str:
269
+ break
270
+
271
+ parts = req_line_str.split(" ", 2)
272
+ method = parts[0] if parts else "?"
273
+ path = parts[1] if len(parts) > 1 else "/"
184
274
 
185
- if is_stream:
186
- await self._proxy_streaming(
187
- writer, method, upstream_url, upstream_headers, body
275
+ up_writer.write(req_line)
276
+
277
+ req_headers, req_cl, req_chunked = await self._forward_headers(
278
+ client_reader, up_writer
188
279
  )
189
- else:
190
- await self._proxy_buffered(
191
- writer, method, upstream_url, upstream_headers, body
280
+
281
+ req_body_size = await self._forward_body(
282
+ client_reader, up_writer, req_cl, req_chunked
283
+ )
284
+
285
+ started = time.monotonic()
286
+
287
+ resp_line = await up_reader.readline()
288
+ if not resp_line:
289
+ break
290
+ client_writer.write(resp_line)
291
+
292
+ resp_line_str = resp_line.decode("utf-8", errors="replace").strip()
293
+ status_code = 0
294
+ rp = resp_line_str.split(" ", 2)
295
+ if len(rp) >= 2:
296
+ try:
297
+ status_code = int(rp[1])
298
+ except ValueError:
299
+ pass
300
+
301
+ resp_headers, resp_cl, resp_chunked = await self._forward_headers(
302
+ up_reader, client_writer
303
+ )
304
+
305
+ content_type = resp_headers.get("content-type", "")
306
+ is_streaming = any(
307
+ t in content_type
308
+ for t in ("text/event-stream", "grpc", "proto", "connect")
309
+ )
310
+
311
+ resp_body_size = await self._forward_body(
312
+ up_reader, client_writer, resp_cl, resp_chunked
313
+ )
314
+
315
+ elapsed_ms = int((time.monotonic() - started) * 1000)
316
+ self._request_count += 1
317
+
318
+ asyncio.create_task(
319
+ self._ship_telemetry(
320
+ hostname=hostname,
321
+ method=method,
322
+ path=path,
323
+ status_code=status_code,
324
+ req_body_size=req_body_size,
325
+ resp_body_size=resp_body_size,
326
+ elapsed_ms=elapsed_ms,
327
+ content_type=content_type,
328
+ is_streaming=is_streaming,
329
+ )
192
330
  )
193
331
 
194
- async def _proxy_buffered(
332
+ conn_h = (
333
+ req_headers.get("connection", "")
334
+ + resp_headers.get("connection", "")
335
+ ).lower()
336
+ if "close" in conn_h:
337
+ break
338
+
339
+ async def _forward_headers(
195
340
  self,
341
+ reader: asyncio.StreamReader,
196
342
  writer: asyncio.StreamWriter,
197
- method: str,
198
- url: str,
199
- headers: dict[str, str],
200
- body: bytes,
201
- ) -> None:
202
- assert self._http_client is not None
203
- try:
204
- resp = await self._http_client.request(
205
- method, url, headers=headers, content=body, timeout=120
206
- )
207
- resp_headers = {
208
- "Content-Type": resp.headers.get("content-type", "application/json"),
209
- "Content-Length": str(len(resp.content)),
210
- }
211
- for h in ("x-matrx-request-id", "x-matrx-latency-ms", "x-matrx-tokens-saved"):
212
- if h in resp.headers:
213
- resp_headers[h] = resp.headers[h]
214
- self._send_raw_response(writer, resp.status_code, resp_headers, resp.content)
215
- except httpx.HTTPError as exc:
216
- logger.warning("cursor_proxy: upstream error: %s", exc)
217
- self._send_response(writer, 502, {"error": f"Upstream error: {exc}"})
218
-
219
- async def _proxy_streaming(
343
+ ) -> tuple[dict[str, str], int, bool]:
344
+ """Read headers from *reader*, write to *writer*.
345
+
346
+ Returns (headers_dict, content_length, is_chunked).
347
+ """
348
+ headers: dict[str, str] = {}
349
+ content_length = -1
350
+ chunked = False
351
+ while True:
352
+ line = await reader.readline()
353
+ writer.write(line)
354
+ decoded = line.decode("utf-8", errors="replace").strip()
355
+ if not decoded:
356
+ break
357
+ if ":" in decoded:
358
+ k, _, v = decoded.partition(":")
359
+ k_lower = k.strip().lower()
360
+ v_stripped = v.strip()
361
+ headers[k_lower] = v_stripped
362
+ if k_lower == "content-length":
363
+ content_length = int(v_stripped)
364
+ elif k_lower == "transfer-encoding" and "chunked" in v_stripped.lower():
365
+ chunked = True
366
+ await writer.drain()
367
+ return headers, content_length, chunked
368
+
369
+ async def _forward_body(
220
370
  self,
371
+ reader: asyncio.StreamReader,
221
372
  writer: asyncio.StreamWriter,
222
- method: str,
223
- url: str,
224
- headers: dict[str, str],
225
- body: bytes,
226
- ) -> None:
227
- assert self._http_client is not None
228
- try:
229
- async with self._http_client.stream(
230
- method, url, headers=headers, content=body, timeout=300
231
- ) as resp:
232
- resp_headers = {
233
- "Content-Type": resp.headers.get(
234
- "content-type", "text/event-stream"
235
- ),
236
- "Cache-Control": "no-cache",
237
- "Transfer-Encoding": "chunked",
238
- }
239
- for h in ("x-matrx-request-id", "x-matrx-latency-ms", "x-matrx-tokens-saved"):
240
- if h in resp.headers:
241
- resp_headers[h] = resp.headers[h]
242
-
243
- header_block = f"HTTP/1.1 {resp.status_code} OK\r\n"
244
- for k, v in resp_headers.items():
245
- header_block += f"{k}: {v}\r\n"
246
- header_block += "\r\n"
247
- writer.write(header_block.encode("utf-8"))
248
- await writer.drain()
249
-
250
- async for chunk in resp.aiter_bytes():
251
- chunk_header = f"{len(chunk):x}\r\n".encode("utf-8")
252
- writer.write(chunk_header + chunk + b"\r\n")
253
- await writer.drain()
254
-
255
- writer.write(b"0\r\n\r\n")
256
- await writer.drain()
257
- except httpx.HTTPError as exc:
258
- logger.warning("cursor_proxy: streaming error: %s", exc)
259
- self._send_response(writer, 502, {"error": f"Upstream error: {exc}"})
373
+ content_length: int,
374
+ chunked: bool,
375
+ ) -> int:
376
+ """Forward request/response body. Returns total bytes forwarded."""
377
+ if content_length > 0:
378
+ return await self._forward_fixed(reader, writer, content_length)
379
+ if chunked:
380
+ return await self._forward_chunked(reader, writer)
381
+ return 0
260
382
 
261
- @staticmethod
262
- def _send_response(
383
+ async def _forward_fixed(
384
+ self,
385
+ reader: asyncio.StreamReader,
263
386
  writer: asyncio.StreamWriter,
264
- status: int,
265
- body: dict[str, Any],
266
- ) -> None:
267
- content = json.dumps(body).encode("utf-8")
268
- CursorProxyServer._send_raw_response(
269
- writer,
270
- status,
271
- {"Content-Type": "application/json", "Content-Length": str(len(content))},
272
- content,
273
- )
387
+ length: int,
388
+ ) -> int:
389
+ total = 0
390
+ remaining = length
391
+ while remaining > 0:
392
+ chunk = await reader.read(min(remaining, 65536))
393
+ if not chunk:
394
+ break
395
+ writer.write(chunk)
396
+ await writer.drain()
397
+ total += len(chunk)
398
+ remaining -= len(chunk)
399
+ return total
274
400
 
275
- @staticmethod
276
- def _send_raw_response(
401
+ async def _forward_chunked(
402
+ self,
403
+ reader: asyncio.StreamReader,
277
404
  writer: asyncio.StreamWriter,
278
- status: int,
279
- headers: dict[str, str],
280
- content: bytes,
405
+ ) -> int:
406
+ total = 0
407
+ while True:
408
+ size_line = await reader.readline()
409
+ if not size_line:
410
+ break
411
+ writer.write(size_line)
412
+ await writer.drain()
413
+ size_str = size_line.decode("utf-8", errors="replace").strip()
414
+ try:
415
+ chunk_size = int(size_str.split(";")[0], 16)
416
+ except ValueError:
417
+ break
418
+ if chunk_size == 0:
419
+ trailer = await reader.readline()
420
+ writer.write(trailer)
421
+ await writer.drain()
422
+ break
423
+ remaining = chunk_size
424
+ while remaining > 0:
425
+ data = await reader.read(min(remaining, 65536))
426
+ if not data:
427
+ return total
428
+ writer.write(data)
429
+ await writer.drain()
430
+ total += len(data)
431
+ remaining -= len(data)
432
+ crlf = await reader.readline()
433
+ writer.write(crlf)
434
+ await writer.drain()
435
+ return total
436
+
437
+ # -----------------------------------------------------------------
438
+ # Raw bidirectional pipe (for opaque tunnels)
439
+ # -----------------------------------------------------------------
440
+
441
+ async def _pipe_bidirectional(
442
+ self,
443
+ r1: asyncio.StreamReader,
444
+ w1: asyncio.StreamWriter,
445
+ r2: asyncio.StreamReader,
446
+ w2: asyncio.StreamWriter,
281
447
  ) -> None:
282
- reason = "OK" if 200 <= status < 300 else "Error"
283
- header_block = f"HTTP/1.1 {status} {reason}\r\n"
284
- for k, v in headers.items():
285
- header_block += f"{k}: {v}\r\n"
286
- header_block += "\r\n"
287
- writer.write(header_block.encode("utf-8") + content)
288
-
289
- def _pick_port(self) -> int:
290
- if self.port:
291
- return self.port
292
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
293
- s.bind(("127.0.0.1", 0))
294
- return s.getsockname()[1]
295
-
296
- async def _run(self) -> None:
297
- self.port = self._pick_port()
298
- self._http_client = httpx.AsyncClient(http2=False, follow_redirects=True)
299
- self._server = await asyncio.start_server(
300
- self._handle_request, self.host, self.port
301
- )
302
- logger.info("cursor_proxy: listening on %s:%s", self.host, self.port)
303
- async with self._server:
304
- await self._server.serve_forever()
305
-
306
- async def _shutdown(self) -> None:
307
- if self._server:
308
- self._server.close()
309
- await self._server.wait_closed()
310
- if self._http_client:
311
- await self._http_client.aclose()
312
-
313
- def start_background(self) -> None:
314
- """Start the proxy in a background daemon thread. Returns once listening."""
315
- ready = threading.Event()
316
-
317
- def _run_loop() -> None:
318
- loop = asyncio.new_event_loop()
319
- asyncio.set_event_loop(loop)
320
- self._loop = loop
321
-
322
- async def _start_and_signal() -> None:
323
- self.port = self._pick_port()
324
- self._http_client = httpx.AsyncClient(
325
- http2=False, follow_redirects=True
326
- )
327
- self._server = await asyncio.start_server(
328
- self._handle_request, self.host, self.port
329
- )
330
- ready.set()
331
- async with self._server:
332
- await self._server.serve_forever()
333
-
448
+ async def _copy(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
334
449
  try:
335
- loop.run_until_complete(_start_and_signal())
336
- except asyncio.CancelledError:
450
+ while True:
451
+ data = await src.read(65536)
452
+ if not data:
453
+ break
454
+ dst.write(data)
455
+ await dst.drain()
456
+ except Exception:
337
457
  pass
338
458
  finally:
339
- loop.run_until_complete(self._shutdown())
340
- loop.close()
341
-
342
- thread = threading.Thread(target=_run_loop, daemon=True)
343
- thread.start()
344
- ready.wait(timeout=10)
345
- if not ready.is_set():
346
- raise RuntimeError("Cursor proxy failed to start within 10 seconds")
347
-
348
- def stop(self) -> None:
349
- """Stop the background proxy."""
350
- if self._server and self._loop and self._loop.is_running():
351
- self._loop.call_soon_threadsafe(self._server.close)
459
+ try:
460
+ dst.close()
461
+ except Exception:
462
+ pass
463
+
464
+ await asyncio.gather(_copy(r1, w2), _copy(r2, w1))
465
+
466
+ # -----------------------------------------------------------------
467
+ # Telemetry
468
+ # -----------------------------------------------------------------
469
+
470
+ async def _ship_telemetry(
471
+ self,
472
+ *,
473
+ hostname: str,
474
+ method: str,
475
+ path: str,
476
+ status_code: int,
477
+ req_body_size: int,
478
+ resp_body_size: int,
479
+ elapsed_ms: int,
480
+ content_type: str,
481
+ is_streaming: bool,
482
+ ) -> None:
483
+ if self._telemetry_client is None:
484
+ return
485
+
486
+ payload = {
487
+ "timestamp": time.time(),
488
+ "hostname": hostname,
489
+ "method": method,
490
+ "path": path,
491
+ "status_code": status_code,
492
+ "req_bytes": req_body_size,
493
+ "resp_bytes": resp_body_size,
494
+ "elapsed_ms": elapsed_ms,
495
+ "content_type": content_type,
496
+ "streaming": is_streaming,
497
+ }
498
+ url = f"{self.matrx_base_url}/v1/telemetry/cursor"
499
+ try:
500
+ await self._telemetry_client.post(
501
+ url,
502
+ json=payload,
503
+ headers={"X-Matrx-Key": self.matrx_key},
504
+ )
505
+ except Exception:
506
+ logger.debug("telemetry ship failed", exc_info=True)
507
+
508
+
509
+ # ---------------------------------------------------------------------------
510
+ # Entry-point for running the proxy as a standalone process (daemon use)
511
+ # ---------------------------------------------------------------------------
512
+
513
+ def run_proxy(
514
+ *,
515
+ matrx_key: str,
516
+ matrx_base_url: str,
517
+ host: str = PROXY_HOST,
518
+ port: int = DEFAULT_PORT,
519
+ pid_file: Path | None = None,
520
+ ) -> None:
521
+ """Run the MITM proxy (blocking). Intended for daemon/service use."""
522
+ if pid_file:
523
+ pid_file.parent.mkdir(parents=True, exist_ok=True)
524
+ pid_file.write_text(str(os.getpid()), encoding="utf-8")
525
+
526
+ proxy = MITMProxy(
527
+ matrx_key=matrx_key,
528
+ matrx_base_url=matrx_base_url,
529
+ host=host,
530
+ port=port,
531
+ )
532
+
533
+ loop = asyncio.new_event_loop()
534
+ asyncio.set_event_loop(loop)
535
+
536
+ for sig in (signal.SIGTERM, signal.SIGINT):
537
+ try:
538
+ loop.add_signal_handler(sig, lambda: loop.create_task(proxy.stop()))
539
+ except NotImplementedError:
540
+ pass
541
+
542
+ try:
543
+ loop.run_until_complete(proxy.serve_forever())
544
+ except (KeyboardInterrupt, SystemExit):
545
+ loop.run_until_complete(proxy.stop())
546
+ finally:
547
+ if pid_file and pid_file.exists():
548
+ pid_file.unlink(missing_ok=True)
549
+ loop.close()