zwarm 0.1.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
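The central change in this release: the asyncio-based MCP client is replaced by a blocking subprocess.Popen client whose stdout and stderr are drained by dedicated daemon threads into a queue, so the codex mcp-server process survives repeated asyncio.run() calls and no output is lost on a read timeout. A minimal, self-contained sketch of that reader-thread/queue pattern (illustrative names only, not the package's API):

    import queue
    import subprocess
    import threading

    def spawn_reader(proc: subprocess.Popen, lines: "queue.Queue[str | None]") -> threading.Thread:
        """Drain the child's stdout on a daemon thread, pushing decoded lines into a queue."""
        def _loop() -> None:
            assert proc.stdout is not None
            while True:
                raw = proc.stdout.readline()
                if not raw:                 # EOF: child closed stdout or exited
                    lines.put(None)         # sentinel lets the consumer tell EOF apart from a timeout
                    break
                lines.put(raw.decode())
        thread = threading.Thread(target=_loop, daemon=True, name="stdout-reader")
        thread.start()
        return thread

    def read_line(lines: "queue.Queue[str | None]", timeout: float = 30.0) -> str:
        """Block up to `timeout` seconds; a timeout loses nothing, the line is still queued later."""
        try:
            line = lines.get(timeout=timeout)
        except queue.Empty:
            return ""                       # caller decides whether to keep waiting
        if line is None:
            raise RuntimeError("child process closed stdout")
        return line

The diff below shows the package's actual implementation of this idea.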
@@ -8,9 +8,12 @@ Uses codex mcp-server for true iterative conversations:
 
 from __future__ import annotations
 
- import asyncio
 import json
+ import logging
+ import queue
 import subprocess
+ import threading
+ import time
 from pathlib import Path
 from typing import Any, Literal
 
@@ -23,113 +26,270 @@ from zwarm.core.models import (
 SessionStatus,
 )
 
+ logger = logging.getLogger(__name__)
+
 
 class MCPClient:
- """Minimal MCP client for communicating with codex mcp-server."""
+ """
+ Robust MCP client for communicating with codex mcp-server.
+
+ Uses subprocess.Popen (NOT asyncio.subprocess) to avoid being tied to
+ any specific event loop. This allows the MCP server to stay alive across
+ multiple asyncio.run() calls, preserving conversation state.
+
+ Uses dedicated reader threads that queue lines, avoiding the race condition
+ of spawning new reader threads on timeout.
+ """
 
- def __init__(self, proc: subprocess.Popen):
- self.proc = proc
+ def __init__(self):
+ self._proc: subprocess.Popen | None = None
 self._request_id = 0
 self._initialized = False
+ self._stderr_thread: threading.Thread | None = None
+ self._stdout_thread: threading.Thread | None = None
+ self._stderr_lines: list[str] = []
+ self._stdout_queue: queue.Queue[str | None] = queue.Queue()
+ self._lock = threading.Lock() # Protect writes only
+
+ def start(self) -> None:
+ """Start the MCP server process."""
+ with self._lock:
+ if self._proc is not None and self._proc.poll() is None:
+ return # Already running
+
+ logger.info("Starting codex mcp-server...")
+ self._proc = subprocess.Popen(
+ ["codex", "mcp-server"],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=False, # Binary mode for explicit encoding control
+ )
+ self._initialized = False
+ self._stderr_lines = []
+ self._stdout_queue = queue.Queue() # Fresh queue
+
+ # Start background thread to read stderr
+ self._stderr_thread = threading.Thread(
+ target=self._read_stderr_loop,
+ daemon=True,
+ name="mcp-stderr-reader",
+ )
+ self._stderr_thread.start()
+
+ # Start background thread to read stdout into queue
+ self._stdout_thread = threading.Thread(
+ target=self._read_stdout_loop,
+ daemon=True,
+ name="mcp-stdout-reader",
+ )
+ self._stdout_thread.start()
+
+ logger.info(f"MCP server started (pid={self._proc.pid})")
+
+ def _read_stderr_loop(self) -> None:
+ """Background thread to read stderr and log errors."""
+ if not self._proc or not self._proc.stderr:
+ return
+ try:
+ while True:
+ line = self._proc.stderr.readline()
+ if not line:
+ break
+ decoded = line.decode().strip()
+ if decoded:
+ self._stderr_lines.append(decoded)
+ # Keep only last 100 lines
+ if len(self._stderr_lines) > 100:
+ self._stderr_lines = self._stderr_lines[-100:]
+ # Log errors prominently
+ if "error" in decoded.lower() or "ERROR" in decoded:
+ logger.error(f"[MCP stderr] {decoded}")
+ else:
+ logger.debug(f"[MCP stderr] {decoded}")
+ except Exception as e:
+ logger.warning(f"stderr reader stopped: {e}")
+
+ def _read_stdout_loop(self) -> None:
+ """Background thread to read stdout and queue lines."""
+ if not self._proc or not self._proc.stdout:
+ return
+ try:
+ while True:
+ line = self._proc.stdout.readline()
+ if not line:
+ # EOF - signal end
+ self._stdout_queue.put(None)
+ break
+ decoded = line.decode()
+ self._stdout_queue.put(decoded)
+ except Exception as e:
+ logger.warning(f"stdout reader stopped: {e}")
+ self._stdout_queue.put(None) # Signal error
 
 def _next_id(self) -> int:
 self._request_id += 1
 return self._request_id
 
- async def _read_line(self) -> str:
- """Read a line from stdout asynchronously."""
- loop = asyncio.get_event_loop()
- return await loop.run_in_executor(None, self.proc.stdout.readline)
+ def _write(self, data: str) -> None:
+ """Write to stdin with error handling."""
+ if not self._proc or not self._proc.stdin:
+ raise RuntimeError("MCP server not running")
+ if self._proc.poll() is not None:
+ raise RuntimeError(f"MCP server died (exit code {self._proc.returncode})")
 
- async def send_request(self, method: str, params: dict | None = None) -> dict:
- """Send JSON-RPC request and wait for response."""
- request: dict[str, Any] = {
- "jsonrpc": "2.0",
- "id": self._next_id(),
- "method": method,
- }
- if params:
- request["params"] = params
+ self._proc.stdin.write(data.encode())
+ self._proc.stdin.flush()
 
- request_line = json.dumps(request) + "\n"
+ def _read_line(self, timeout: float = 120.0) -> str:
+ """
+ Read a line from the stdout queue with timeout.
 
- # Write request
- self.proc.stdin.write(request_line)
- self.proc.stdin.flush()
+ Uses a dedicated reader thread that queues lines, so we never
+ lose data on timeout - we just haven't received it yet.
+ """
+ if not self._proc:
+ raise RuntimeError("MCP server not running")
+
+ try:
+ line = self._stdout_queue.get(timeout=timeout)
+ except queue.Empty:
+ # Timeout - check process health
+ if self._proc.poll() is not None:
+ stderr_context = "\n".join(self._stderr_lines[-10:]) if self._stderr_lines else "(no stderr)"
+ raise RuntimeError(
+ f"MCP server died (exit code {self._proc.returncode}).\n"
+ f"Recent stderr:\n{stderr_context}"
+ )
+ # Process still alive, just slow - return empty to let caller decide
+ return ""
+
+ if line is None:
+ # EOF or error from reader thread
+ stderr_context = "\n".join(self._stderr_lines[-10:]) if self._stderr_lines else "(no stderr)"
+ if self._proc.poll() is not None:
+ raise RuntimeError(
+ f"MCP server exited (code {self._proc.returncode}).\n"
+ f"Recent stderr:\n{stderr_context}"
+ )
+ raise RuntimeError(f"MCP stdout closed unexpectedly.\nRecent stderr:\n{stderr_context}")
+
+ return line
+
+ def _check_alive(self) -> None:
+ """Check if the MCP server is still alive, raise if not."""
+ if not self._proc:
+ raise RuntimeError("MCP server not started")
+ if self._proc.poll() is not None:
+ stderr_context = "\n".join(self._stderr_lines[-10:]) if self._stderr_lines else "(no stderr)"
+ raise RuntimeError(
+ f"MCP server died (exit code {self._proc.returncode}).\n"
+ f"Recent stderr:\n{stderr_context}"
+ )
 
- # Read response
- response_line = await self._read_line()
+ def initialize(self) -> dict:
+ """Initialize MCP connection."""
+ self._check_alive()
 
+ request = {
+ "jsonrpc": "2.0",
+ "id": self._next_id(),
+ "method": "initialize",
+ "params": {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {},
+ "clientInfo": {"name": "zwarm", "version": "0.1.0"},
+ },
+ }
+ with self._lock:
+ self._write(json.dumps(request) + "\n")
+
+ response_line = self._read_line(timeout=30.0)
 if not response_line:
- raise RuntimeError("No response from MCP server")
+ raise RuntimeError("No response from MCP server during init")
 
 response = json.loads(response_line)
-
- # Check for error
 if "error" in response:
- error = response["error"]
- raise RuntimeError(f"MCP error: {error.get('message', error)}")
-
- return response
-
- async def initialize(self) -> dict:
- """Initialize MCP connection."""
- result = await self.send_request("initialize", {
- "protocolVersion": "2024-11-05",
- "capabilities": {},
- "clientInfo": {"name": "zwarm", "version": "0.1.0"},
- })
+ raise RuntimeError(f"MCP init error: {response['error']}")
 
 # Send initialized notification
- notif = json.dumps({
- "jsonrpc": "2.0",
- "method": "notifications/initialized",
- }) + "\n"
- self.proc.stdin.write(notif)
- self.proc.stdin.flush()
+ notif = {"jsonrpc": "2.0", "method": "notifications/initialized"}
+ with self._lock:
+ self._write(json.dumps(notif) + "\n")
 
 self._initialized = True
- return result
+ logger.info("MCP connection initialized")
+ return response
 
- async def call_tool(self, name: str, arguments: dict) -> dict:
+ def call_tool(self, name: str, arguments: dict, timeout: float = 300.0) -> dict:
 """
 Call an MCP tool and collect streaming events.
 
- Codex MCP uses streaming events, so we read multiple responses
- until we get the final result.
+ Args:
+ name: Tool name (codex, codex-reply)
+ arguments: Tool arguments
+ timeout: Overall timeout for the call (default 5 min)
 """
+ self._check_alive()
+
 if not self._initialized:
- await self.initialize()
+ self.initialize()
 
 request_id = self._next_id()
 request = {
 "jsonrpc": "2.0",
 "id": request_id,
 "method": "tools/call",
- "params": {
- "name": name,
- "arguments": arguments,
- },
+ "params": {"name": name, "arguments": arguments},
 }
- self.proc.stdin.write(json.dumps(request) + "\n")
- self.proc.stdin.flush()
+
+ logger.debug(f"Calling MCP tool: {name} with args: {list(arguments.keys())}")
+ with self._lock:
+ self._write(json.dumps(request) + "\n")
 
 # Collect streaming events until final result
+ # Reader thread queues lines, we pull from queue with timeout
 session_id = None
 agent_messages: list[str] = []
+ streaming_text: list[str] = [] # Accumulate streaming delta text
 final_result = None
+ token_usage: dict[str, Any] = {} # Track token usage
+ start_time = time.time()
+
+ for event_count in range(1000): # Safety limit on events
+ self._check_alive()
+
+ # Check overall timeout
+ elapsed = time.time() - start_time
+ if elapsed > timeout:
+ raise RuntimeError(f"MCP call timed out after {timeout}s ({event_count} events received)")
+
+ # Read from queue with per-event timeout
+ # Empty string = timeout (process still alive, just waiting)
+ # None sentinel is handled inside _read_line (raises RuntimeError)
+ line = self._read_line(timeout=30.0)
 
- for _ in range(500): # Safety limit on events
- line = await self._read_line()
 if not line:
- break
+ # Timeout waiting for event - process is still alive, just slow
+ # This is normal during long codex operations
+ logger.debug(f"Waiting for MCP event... (elapsed: {elapsed:.0f}s, events: {event_count})")
+ continue
 
- event = json.loads(line)
+ try:
+ event = json.loads(line)
+ except json.JSONDecodeError as e:
+ logger.warning(f"Invalid JSON from MCP: {line[:100]}... - {e}")
+ continue
 
 # Check for final result (has matching id)
- if event.get("id") == request_id and "result" in event:
- final_result = event.get("result", {})
- break
+ if event.get("id") == request_id:
+ if "result" in event:
+ final_result = event["result"]
+ logger.debug(f"Got final result after {event_count} events")
+ break
+ elif "error" in event:
+ error = event["error"]
+ raise RuntimeError(f"MCP tool error: {error.get('message', error)}")
 
 # Process streaming events
 if event.get("method") == "codex/event":
@@ -137,35 +297,157 @@ class MCPClient:
 msg = params.get("msg", {})
 msg_type = msg.get("type")
 
+ # Log ALL event types to help debug missing messages
+ logger.debug(f"MCP event: type={msg_type}, keys={list(msg.keys())}")
+
 if msg_type == "session_configured":
 session_id = msg.get("session_id")
+ logger.debug(f"Session configured: {session_id}")
+
+ elif msg_type == "item_completed":
+ item = msg.get("item", {})
+ item_type = item.get("type")
+
+ # Agent text responses - codex uses "AgentMessage" type
+ if item_type == "AgentMessage":
+ content = item.get("content", [])
+ for block in content:
+ if isinstance(block, dict) and block.get("text"):
+ agent_messages.append(block["text"])
+ elif isinstance(block, str):
+ agent_messages.append(block)
+
+ # Legacy format check
+ elif item_type == "message" and item.get("role") == "assistant":
+ content = item.get("content", [])
+ for block in content:
+ if isinstance(block, dict) and block.get("text"):
+ agent_messages.append(block["text"])
+ elif isinstance(block, str):
+ agent_messages.append(block)
+
+ # Function call outputs (for context)
+ elif item_type == "function_call_output":
+ output = item.get("output", "")
+ if output and len(output) < 1000:
+ agent_messages.append(f"[Tool output]: {output[:500]}")
+
+ # Log other item types we're not handling
+ elif item_type not in ("function_call", "tool_call", "UserMessage"):
+ logger.debug(f"Unhandled item_completed type: {item_type}, keys: {list(item.keys())}")
+
 elif msg_type == "agent_message":
- agent_messages.append(msg.get("message", ""))
- elif msg_type == "task_completed":
- # Task is done, break
+ # Direct agent message event
+ message = msg.get("message", "")
+ if message:
+ agent_messages.append(message)
+
+ elif msg_type in ("task_complete", "task_completed"):
+ # Task is done - capture last_agent_message as fallback
+ last_msg = msg.get("last_agent_message")
+ if last_msg and last_msg not in agent_messages:
+ agent_messages.append(last_msg)
+ logger.debug(f"Task complete after {event_count} events")
 break
- elif msg_type == "error":
- raise RuntimeError(f"Codex error: {msg.get('error', msg)}")
 
- # Build result from collected events
+ elif msg_type == "token_count":
+ # Capture token usage for cost tracking
+ info = msg.get("info") or {}
+ if info:
+ usage = info.get("total_token_usage", {})
+ if usage:
+ token_usage = {
+ "input_tokens": usage.get("input_tokens", 0),
+ "output_tokens": usage.get("output_tokens", 0),
+ "cached_input_tokens": usage.get("cached_input_tokens", 0),
+ "reasoning_tokens": usage.get("reasoning_output_tokens", 0),
+ "total_tokens": usage.get("total_tokens", 0),
+ }
+ logger.debug(f"Token usage: {token_usage}")
+
+ elif msg_type == "error":
+ error_msg = msg.get("error", msg.get("message", str(msg)))
+ raise RuntimeError(f"Codex error: {error_msg}")
+
+ # Handle streaming text events (various formats)
+ elif msg_type in ("text_delta", "content_block_delta", "message_delta"):
+ delta = msg.get("delta", {})
+ text = delta.get("text", "") or msg.get("text", "")
+ if text:
+ streaming_text.append(text)
+
+ elif msg_type == "text":
+ text = msg.get("text", "")
+ if text:
+ streaming_text.append(text)
+
+ elif msg_type == "response":
+ # Some versions send the full response this way
+ response_text = msg.get("response", "") or msg.get("text", "")
+ if response_text:
+ agent_messages.append(response_text)
+
+ elif msg_type == "message":
+ # Direct message event
+ text = msg.get("text", "") or msg.get("content", "")
+ if text:
+ agent_messages.append(text)
+
+ else:
+ # Log unknown event types at debug level to help diagnose
+ if msg_type and msg_type not in ("session_started", "thinking", "tool_call", "function_call"):
+ logger.debug(f"Unhandled MCP event type: {msg_type}, msg keys: {list(msg.keys())}")
+
+ # Merge streaming text into messages if we got any
+ if streaming_text:
+ full_streaming = "".join(streaming_text)
+ if full_streaming.strip():
+ agent_messages.append(full_streaming)
+ logger.debug(f"Captured {len(streaming_text)} streaming chunks ({len(full_streaming)} chars)")
+
+ # Build result
 result = {
 "conversationId": session_id,
 "messages": agent_messages,
 "output": "\n".join(agent_messages) if agent_messages else "",
+ "usage": token_usage, # Token usage for cost tracking
 }
+
+ # Merge final result and try to extract content if no messages
 if final_result:
 result.update(final_result)
-
+ if not agent_messages and "content" in final_result:
+ content = final_result["content"]
+ if isinstance(content, list):
+ for block in content:
+ if isinstance(block, dict) and block.get("text"):
+ agent_messages.append(block["text"])
+ if agent_messages:
+ result["messages"] = agent_messages
+ result["output"] = "\n".join(agent_messages)
+
+ logger.debug(f"MCP call complete: {len(agent_messages)} messages, session={session_id}")
 return result
 
 def close(self) -> None:
- """Close the MCP connection."""
- if self.proc and self.proc.poll() is None:
- self.proc.terminate()
+ """Close the MCP connection gracefully."""
+ if self._proc and self._proc.poll() is None:
+ logger.info("Terminating MCP server...")
+ self._proc.terminate()
 try:
- self.proc.wait(timeout=5)
+ self._proc.wait(timeout=5)
 except subprocess.TimeoutExpired:
- self.proc.kill()
+ logger.warning("MCP server didn't terminate, killing...")
+ self._proc.kill()
+ self._proc.wait()
+
+ self._proc = None
+ self._initialized = False
+
+ @property
+ def is_alive(self) -> bool:
+ """Check if the MCP server is running."""
+ return self._proc is not None and self._proc.poll() is None
 
 
 class CodexMCPAdapter(ExecutorAdapter):
@@ -173,40 +455,50 @@ class CodexMCPAdapter(ExecutorAdapter):
 Codex adapter using MCP server for sync conversations.
 
 This is the recommended way to have iterative conversations with Codex.
+ The MCP client uses subprocess.Popen (not asyncio) so it persists across
+ multiple asyncio.run() calls, preserving conversation state.
 """
 
 name = "codex_mcp"
+ DEFAULT_MODEL = "gpt-5.1-codex-mini" # Default codex model
 
- def __init__(self):
+ def __init__(self, model: str | None = None):
+ self._model = model or self.DEFAULT_MODEL
 self._mcp_client: MCPClient | None = None
- self._mcp_proc: subprocess.Popen | None = None
 self._sessions: dict[str, str] = {} # session_id -> conversationId
+ # Cumulative token usage for cost tracking
+ self._total_usage: dict[str, int] = {
+ "input_tokens": 0,
+ "output_tokens": 0,
+ "cached_input_tokens": 0,
+ "reasoning_tokens": 0,
+ "total_tokens": 0,
+ }
+
+ def _accumulate_usage(self, usage: dict[str, Any]) -> None:
+ """Add usage to cumulative totals."""
+ if not usage:
+ return
+ for key in self._total_usage:
+ self._total_usage[key] += usage.get(key, 0)
+
+ @property
+ def total_usage(self) -> dict[str, int]:
+ """Get cumulative token usage across all calls."""
+ return self._total_usage.copy()
+
+ def _ensure_client(self) -> MCPClient:
+ """Ensure MCP client is running and return it."""
+ if self._mcp_client is None:
+ self._mcp_client = MCPClient()
+
+ if not self._mcp_client.is_alive:
+ self._mcp_client.start()
 
- async def _ensure_server(self) -> MCPClient:
- """Ensure MCP server is running and return client."""
- if self._mcp_client is not None:
- # Check if process is still alive
- if self._mcp_proc and self._mcp_proc.poll() is None:
- return self._mcp_client
- # Process died, restart
- self._mcp_client = None
- self._mcp_proc = None
-
- # Start codex mcp-server
- self._mcp_proc = subprocess.Popen(
- ["codex", "mcp-server"],
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- text=True,
- bufsize=1,
- )
- self._mcp_client = MCPClient(self._mcp_proc)
- await self._mcp_client.initialize()
 return self._mcp_client
 
 @weave.op()
- async def _call_codex(
+ def _call_codex(
 self,
 task: str,
 cwd: str,
@@ -216,10 +508,10 @@ class CodexMCPAdapter(ExecutorAdapter):
 """
 Call codex MCP tool - traced by Weave.
 
- This wraps the actual codex call so it appears in Weave traces
- with full input/output visibility.
+ This is synchronous (uses subprocess.Popen, not asyncio) so the MCP
+ server persists across calls.
 """
- client = await self._ensure_server()
+ client = self._ensure_client()
 
 args: dict[str, Any] = {
 "prompt": task,
@@ -229,17 +521,22 @@ class CodexMCPAdapter(ExecutorAdapter):
 if model:
 args["model"] = model
 
- result = await client.call_tool("codex", args)
+ result = client.call_tool("codex", args)
+
+ # Track usage
+ usage = result.get("usage", {})
+ self._accumulate_usage(usage)
 
- # Return structured result for Weave
 return {
 "conversation_id": result.get("conversationId"),
 "response": self._extract_response(result),
 "raw_messages": result.get("messages", []),
+ "usage": usage,
+ "total_usage": self.total_usage,
 }
 
 @weave.op()
- async def _call_codex_reply(
+ def _call_codex_reply(
 self,
 conversation_id: str,
 message: str,
@@ -247,19 +544,25 @@ class CodexMCPAdapter(ExecutorAdapter):
 """
 Call codex-reply MCP tool - traced by Weave.
 
- This wraps the reply call so it appears in Weave traces
- with full input/output visibility.
+ This is synchronous (uses subprocess.Popen, not asyncio) so the MCP
+ server persists across calls.
 """
- client = await self._ensure_server()
+ client = self._ensure_client()
 
- result = await client.call_tool("codex-reply", {
+ result = client.call_tool("codex-reply", {
 "conversationId": conversation_id,
 "prompt": message,
 })
 
+ # Track usage
+ usage = result.get("usage", {})
+ self._accumulate_usage(usage)
+
 return {
 "response": self._extract_response(result),
 "raw_messages": result.get("messages", []),
+ "usage": usage,
+ "total_usage": self.total_usage,
 }
 
 async def start_session(
@@ -272,30 +575,35 @@ class CodexMCPAdapter(ExecutorAdapter):
 **kwargs,
 ) -> ConversationSession:
 """Start a Codex session."""
+ effective_model = model or self._model
 session = ConversationSession(
 adapter=self.name,
 mode=SessionMode(mode),
 working_dir=working_dir,
 task_description=task,
- model=model,
+ model=effective_model,
 )
 
 if mode == "sync":
- # Use traced codex call
- result = await self._call_codex(
+ # Use traced codex call (synchronous - MCP client persists across calls)
+ result = self._call_codex(
 task=task,
 cwd=str(working_dir.absolute()),
 sandbox=sandbox,
- model=model,
+ model=effective_model,
 )
 
 # Extract conversation ID and response
 session.conversation_id = result["conversation_id"]
- self._sessions[session.id] = session.conversation_id
+ if session.conversation_id:
+ self._sessions[session.id] = session.conversation_id
 
 session.add_message("user", task)
 session.add_message("assistant", result["response"])
 
+ # Track token usage on the session
+ session.add_usage(result.get("usage", {}))
+
 else:
 # Async mode: use codex exec (fire-and-forget)
 # This runs in a subprocess without MCP
@@ -304,9 +612,8 @@ class CodexMCPAdapter(ExecutorAdapter):
 "--dangerously-bypass-approvals-and-sandbox",
 "--skip-git-repo-check",
 "--json",
+ "--model", effective_model,
 ]
- if model:
- cmd.extend(["--model", model])
 cmd.extend(["--", task])
 
 proc = subprocess.Popen(
@@ -334,8 +641,8 @@ class CodexMCPAdapter(ExecutorAdapter):
 if not session.conversation_id:
 raise ValueError("Session has no conversation ID")
 
- # Use traced codex-reply call
- result = await self._call_codex_reply(
+ # Use traced codex-reply call (synchronous - MCP client persists across calls)
+ result = self._call_codex_reply(
 conversation_id=session.conversation_id,
 message=message,
 )
@@ -344,6 +651,9 @@ class CodexMCPAdapter(ExecutorAdapter):
 session.add_message("user", message)
 session.add_message("assistant", response_text)
 
+ # Track token usage on the session
+ session.add_usage(result.get("usage", {}))
+
 return response_text
 
 async def check_status(
@@ -376,6 +686,8 @@ class CodexMCPAdapter(ExecutorAdapter):
 session: ConversationSession,
 ) -> None:
 """Stop a session."""
+ import subprocess
+
 if session.process and session.process.poll() is None:
 session.process.terminate()
 try:
@@ -394,30 +706,29 @@ class CodexMCPAdapter(ExecutorAdapter):
 if self._mcp_client:
 self._mcp_client.close()
 self._mcp_client = None
- self._mcp_proc = None
 
 def _extract_response(self, result: dict) -> str:
 """Extract response text from MCP result."""
 # First check for our collected output
- if "output" in result and result["output"]:
+ if result.get("output"):
 return result["output"]
 
 # Check for messages list
- if "messages" in result and result["messages"]:
+ if result.get("messages"):
 return "\n".join(result["messages"])
 
 # Result may have different structures depending on codex version
 if "content" in result:
 content = result["content"]
 if isinstance(content, list):
- # Extract text from content blocks
 texts = []
 for block in content:
 if isinstance(block, dict) and "text" in block:
 texts.append(block["text"])
 elif isinstance(block, str):
 texts.append(block)
- return "\n".join(texts)
+ if texts:
+ return "\n".join(texts)
 elif isinstance(content, str):
 return content
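
Taken together, the new call_tool reduces to a line-delimited JSON-RPC exchange: write one tools/call request, fold codex/event notifications into the collected messages, and stop when the response with the matching id arrives. A simplified sketch of that loop (write_line and read_line stand in for the client's _write and _read_line; timeouts, token accounting, and the extra event formats above are omitted):

    import json
    from typing import Any, Callable

    def call_tool_sketch(
        write_line: Callable[[str], None],
        read_line: Callable[[], str],
        request_id: int,
        name: str,
        arguments: dict[str, Any],
    ) -> dict[str, Any]:
        """One tools/call round trip over newline-delimited JSON-RPC."""
        write_line(json.dumps({
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "tools/call",
            "params": {"name": name, "arguments": arguments},
        }) + "\n")

        messages: list[str] = []
        while True:
            line = read_line()
            if not line:
                continue                          # read timed out; the server is still working
            event = json.loads(line)
            if event.get("id") == request_id:     # final response for this request
                if "error" in event:
                    raise RuntimeError(event["error"])
                return {"messages": messages, **event.get("result", {})}
            if event.get("method") == "codex/event":
                msg = event.get("params", {}).get("msg", {})
                if msg.get("type") == "agent_message" and msg.get("message"):
                    messages.append(msg["message"])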