coding-agent-wrapper 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caw/providers/codex.py ADDED
@@ -0,0 +1,564 @@
1
+ """Codex provider — wraps the ``codex`` CLI in JSON mode."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import atexit
6
+ import json
7
+ import logging
8
+ import re
9
+ import subprocess
10
+ import threading
11
+ import uuid
12
+ from datetime import datetime, timedelta, timezone
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ from caw.display import Display, get_global_display
17
+ from caw.models import (
18
+ ContentBlock,
19
+ MCPServer,
20
+ ModelTier,
21
+ TextBlock,
22
+ ThinkingBlock,
23
+ ToolGroup,
24
+ ToolUse,
25
+ Trajectory,
26
+ Turn,
27
+ UsageStats,
28
+ )
29
+ from caw.pricing import compute_cost
30
+ from caw.provider import Provider, ProviderSession
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # -- Usage-limit detection ----------------------------------------------------
35
+
36
# Matches the reset time in messages such as "try again at 3:47 PM".
_CODEX_LIMIT_RE = re.compile(
    r"try again at\s+(\d{1,2}):(\d{2})\s*(AM|PM)",
    re.IGNORECASE,
)

# Wait used when a limit message carries no parseable reset time.
_DEFAULT_WAIT_MINUTES = 60


def _parse_codex_reset_minutes(text: str) -> int | None:
    """Parse a Codex limit message and return minutes until reset (+ 5 min buffer).

    Expected format: ``"try again at 3:47 PM"``.
    No timezone is provided so local time is assumed.

    Returns:
        Minutes to wait (at least 1), or ``None`` if the pattern is not found
        or the matched digits do not form a valid clock time.
    """
    match = _CODEX_LIMIT_RE.search(text)
    if not match:
        return None

    hour = int(match.group(1))
    minute = int(match.group(2))
    ampm = match.group(3).lower()

    # Convert 12-hour to 24-hour
    if ampm == "am" and hour == 12:
        hour = 0
    elif ampm == "pm" and hour != 12:
        hour += 12

    # The regex accepts any digits (e.g. "13:47 PM", "9:75 AM"); without this
    # guard ``datetime.replace`` raises ValueError on such input.
    if hour > 23 or minute > 59:
        return None

    now = datetime.now()  # local time (Codex doesn't include timezone)
    reset_time = now.replace(hour=hour, minute=minute, second=0, microsecond=0)

    # A time that is not in the future refers to tomorrow.
    if reset_time <= now:
        reset_time += timedelta(days=1)

    delta = reset_time - now
    wait_minutes = int(delta.total_seconds() / 60) + 5  # 5-minute buffer
    return max(1, wait_minutes)


def detect_codex_usage_limit(text: str) -> int | None:
    """Check whether *text* indicates a Codex usage limit.

    A limit is recognised by the case-insensitive phrase ``"usage limit"``.

    Returns:
        The number of minutes to wait before retrying (the default wait when
        no reset time can be parsed), or ``None`` if no limit was detected.
    """
    if "usage limit" not in text.lower():
        return None
    return _parse_codex_reset_minutes(text) or _DEFAULT_WAIT_MINUTES
86
+
87
+
88
+ # -- Subprocess registry + atexit cleanup -------------------------------------
89
+
90
# Child processes that are currently running; guarded by ``_process_lock``.
_active_processes: set[subprocess.Popen] = set()
_process_lock = threading.Lock()


def _register_process(proc: subprocess.Popen) -> None:
    """Start tracking *proc* so the atexit hook can kill it."""
    with _process_lock:
        _active_processes.add(proc)


def _unregister_process(proc: subprocess.Popen) -> None:
    """Stop tracking *proc*; does nothing if it is not registered."""
    with _process_lock:
        _active_processes.discard(proc)


def _cleanup_processes() -> None:
    """Kill and reap all tracked subprocesses at interpreter exit."""
    # Snapshot under the lock, then act outside it so waiting on a child
    # never blocks register/unregister calls from other threads.
    with _process_lock:
        procs = list(_active_processes)
    for proc in procs:
        if proc.poll() is not None:
            continue  # already exited and reaped
        try:
            proc.kill()
            # Reap the child so it does not linger as a zombie; bounded so a
            # stuck process cannot hang interpreter shutdown.
            proc.wait(timeout=2)
        except (OSError, subprocess.TimeoutExpired):
            pass  # best effort: nothing more can be done during shutdown


atexit.register(_cleanup_processes)
116
+
117
+
118
def _read_codex_config_model(config_path: Path | None = None) -> str | None:
    """Read the default model from the Codex config file, if available.

    Args:
        config_path: TOML file to read. Defaults to ``~/.codex/config.toml``.

    Returns:
        The top-level ``model`` string, or ``None`` when the file is missing,
        unreadable, malformed, or has no string-valued ``model`` key.
    """
    if config_path is None:
        config_path = Path.home() / ".codex" / "config.toml"
    if not config_path.is_file():
        return None
    try:
        # TOML documents are UTF-8 by specification; do not rely on the locale.
        text = config_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return None

    try:
        import tomllib
    except ModuleNotFoundError:  # Python < 3.11
        try:
            import tomli as tomllib  # type: ignore[no-redef]
        except ModuleNotFoundError:
            # Fall back to a simple regex parse for the top-level model key.
            # Top-level keys can only appear before the first table header,
            # so drop everything from the first ``[table]`` line onwards.
            top_level = re.split(
                r"^[ \t]*\[\[?[^\[\]\n]+\]\]?[ \t]*(?:#.*)?$",
                text,
                maxsplit=1,
                flags=re.MULTILINE,
            )[0]
            # ``re.search`` is required: ``re.match`` only tries position 0,
            # even with MULTILINE, and would miss any key after line one.
            m = re.search(r'^[ \t]*model\s*=\s*"([^"]+)"', top_level, re.MULTILINE)
            return m.group(1) if m else None

    try:
        model = tomllib.loads(text).get("model")
    except (ValueError, RecursionError):  # TOMLDecodeError subclasses ValueError
        return None
    return model if isinstance(model, str) else None
138
+
139
+
140
class CodexSession(ProviderSession):
    """Live session backed by the ``codex`` CLI.

    Every :meth:`send` spawns one ``codex exec`` process (``codex exec resume``
    once a thread id has been captured) in ``--json`` mode and folds the JSONL
    events printed on stdout into a :class:`Turn`.
    """

    def __init__(
        self,
        mcp_servers: list[MCPServer],
        model: str | None = None,
        system_prompt: str | None = None,
        session_id: str | None = None,
        reasoning: str | None = None,
        sandbox: str | None = None,
    ) -> None:
        """Store the session configuration; no process is started here.

        Args:
            mcp_servers: MCP servers passed to the CLI as ``-c`` overrides.
            model: Model name; falls back to the Codex config file default.
            system_prompt: Text prepended to the first prompt of a thread.
            session_id: Identifier for this session; a UUID4 is generated
                when omitted.
            reasoning: Value for the ``model_reasoning_effort`` override.
            sandbox: Sandbox mode; ``None`` or ``"danger-full-access"``
                bypasses approvals and sandboxing entirely.
        """
        self._session_id = session_id or str(uuid.uuid4())
        self._model = model or _read_codex_config_model()
        self._mcp_servers = mcp_servers
        self._system_prompt = system_prompt
        self._reasoning = reasoning
        self._sandbox = sandbox
        self._created_at = datetime.now(timezone.utc).isoformat()
        self._has_sent = False
        # Thread id reported by the CLI's ``thread.started`` event.
        self._thread_id: str | None = None
        self._turns: list[Turn] = []
        self._total_usage = UsageStats()
        self._total_duration_ms = 0
        self._last_raw_output: str = ""
        self._step_callback = None

    def set_step_callback(self, callback):
        """Register *callback*, invoked with a copy of the blocks after each event."""
        self._step_callback = callback

    # ------------------------------------------------------------------
    # Command-line helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _toml_string(value: Any) -> str:
        """Render *value* as a quoted, escaped TOML basic string.

        JSON string escaping is used because it is compatible with TOML basic
        strings; ``ensure_ascii=False`` avoids surrogate-pair escapes, which
        TOML does not accept.
        """
        return json.dumps(str(value), ensure_ascii=False)

    def _mcp_config_args(self) -> list[str]:
        """Build ``-c`` config override flags for MCP servers."""
        args: list[str] = []
        for srv in self._mcp_servers:
            prefix = f"mcp_servers.{srv.name}"
            if srv.url:
                args += ["-c", f"{prefix}.url={self._toml_string(srv.url)}"]
            else:
                args += ["-c", f"{prefix}.command={self._toml_string(srv.command)}"]
                if srv.args:
                    args += ["-c", f"{prefix}.args={json.dumps(srv.args, ensure_ascii=False)}"]
        return args

    def _build_command(self, message: str) -> list[str]:
        """Assemble the ``codex exec`` argv for sending *message*."""
        # Resuming needs the thread id captured from an earlier run.  Without
        # one, ``resume ""`` cannot work, so start a fresh thread instead.
        thread_id = self._thread_id if self._has_sent else None
        if self._has_sent and not thread_id:
            logger.warning(
                "Codex reported no thread id for session %s; starting a new thread instead of resuming.",
                self._session_id,
            )

        # A new thread has no context yet, so it gets the system prompt.
        prompt = message
        if not thread_id and self._system_prompt:
            prompt = f"{self._system_prompt}\n\n{message}"

        if self._sandbox is None or self._sandbox == "danger-full-access":
            sandbox_flags = ["--dangerously-bypass-approvals-and-sandbox"]
        else:
            sandbox_flags = ["--full-auto", "--sandbox", self._sandbox]

        cmd = ["codex", "exec"]
        if thread_id:
            cmd += ["resume", thread_id]
        cmd += sandbox_flags + ["--skip-git-repo-check", "--json"]

        if self._model:
            cmd += ["-m", self._model]
        if self._reasoning:
            cmd += ["-c", f"model_reasoning_effort={self._toml_string(self._reasoning)}"]
        cmd += self._mcp_config_args()

        # Prompt as positional arg (last)
        cmd.append(prompt)
        return cmd

    # ------------------------------------------------------------------
    # Core send (streaming Popen)
    # ------------------------------------------------------------------

    def send(self, message: str) -> Turn:
        """Run one Codex turn for *message* and return it.

        Streams the CLI's JSONL stdout, feeding each event to
        :meth:`_process_event` and notifying the global display and the step
        callback as blocks accumulate.

        Raises:
            RuntimeError: If the ``codex`` executable is missing, if it exits
                non-zero without producing any output, or if the CLI reports
                a non-recoverable turn failure.
        """
        import time  # local import: only needed for turn timing

        display = get_global_display()

        if display:
            if not self._has_sent:
                display.on_metadata(
                    agent="codex",
                    model=self._model or "",
                    session=self._session_id,
                )
            display.on_user_message(message)

        cmd = self._build_command(message)

        # Accumulated state for event processing
        blocks: list[ContentBlock] = []
        tool_blocks: dict[str, ToolUse] = {}
        usage = UsageStats()
        raw_lines: list[str] = []

        try:
            proc = subprocess.Popen(
                cmd,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                # Decode as UTF-8 regardless of locale; JSON output is UTF-8.
                encoding="utf-8",
                errors="replace",
            )
        except FileNotFoundError as exc:
            raise RuntimeError("codex CLI not found. Install it with: npm install -g @openai/codex") from exc

        # Drain stderr concurrently.  Reading it only after stdout hits EOF
        # deadlocks once the child fills the stderr pipe buffer.
        stderr_chunks: list[str] = []

        def _drain_stderr() -> None:
            try:
                if proc.stderr is not None:
                    stderr_chunks.append(proc.stderr.read())
            except (OSError, ValueError):
                pass  # pipe went away while the process was being torn down

        stderr_thread = threading.Thread(target=_drain_stderr, daemon=True)

        _register_process(proc)
        started = time.monotonic()
        try:
            stderr_thread.start()

            # Stream stdout line by line
            for line in proc.stdout:  # type: ignore[union-attr]
                line = line.rstrip("\n")
                raw_lines.append(line)
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    event = json.loads(stripped)
                except json.JSONDecodeError:
                    continue  # non-JSON noise on stdout is kept in raw_lines only

                result = self._process_event(event, blocks, tool_blocks, display)
                if result is not None:
                    usage = result
                if self._step_callback and blocks:
                    self._step_callback(list(blocks))

            stderr_thread.join()
            proc.wait()

            self._last_raw_output = "\n".join(raw_lines)

            if proc.returncode != 0 and not raw_lines:
                stderr = "".join(stderr_chunks)
                raise RuntimeError(f"codex CLI exited with code {proc.returncode}: {stderr}")

        except (KeyboardInterrupt, Exception):
            proc.kill()
            proc.wait()
            raise
        finally:
            _unregister_process(proc)
            if stderr_thread.is_alive():
                stderr_thread.join(timeout=1)
            # Release the pipe file descriptors.  stderr is left alone if its
            # reader thread is still blocked on it.
            if proc.stdout is not None:
                proc.stdout.close()
            if proc.stderr is not None and not stderr_thread.is_alive():
                proc.stderr.close()

        duration_ms = int((time.monotonic() - started) * 1000)
        self._has_sent = True

        turn = Turn(input=message, output=blocks, usage=usage, duration_ms=duration_ms)

        if display:
            # NOTE(review): third argument assumed to be the duration in ms
            # (previously a literal 0) -- confirm against Display.on_turn_end.
            display.on_turn_end(turn.result, usage, duration_ms)

        self._turns.append(turn)
        self._total_usage = self._total_usage + turn.usage
        self._total_duration_ms += duration_ms
        return turn

    # ------------------------------------------------------------------
    # Usage-limit detection (called by core Session auto-wait loop)
    # ------------------------------------------------------------------

    def detect_usage_limit(self, turn: Turn) -> int | None:
        """Detect Codex usage-limit messages in the turn's result text."""
        return detect_codex_usage_limit(turn.result)

    # ------------------------------------------------------------------
    # Per-event processing
    # ------------------------------------------------------------------

    def _process_event(
        self,
        event: dict[str, Any],
        blocks: list[ContentBlock],
        tool_blocks: dict[str, ToolUse],
        display: Display | None,
    ) -> UsageStats | None:
        """Process a single JSONL event. Returns UsageStats on ``turn.completed``.

        Args:
            event: One decoded JSON object from the CLI's stdout.
            blocks: Turn output so far; new blocks are appended in place.
            tool_blocks: Tool blocks by item id, so later update/completion
                events can attach their output to the block created earlier.
            display: Optional display notified of calls, results and text.

        Raises:
            RuntimeError: On ``turn.failed`` / ``error`` events that are not
                usage-limit messages.
        """
        event_type = event.get("type")

        if event_type == "thread.started":
            self._thread_id = event.get("thread_id")

        elif event_type == "item.started":
            item = event.get("item") or {}
            item_type = item.get("type")
            tool_id = item.get("id") or str(uuid.uuid4())

            block: ToolUse | None = None
            if item_type == "command_execution":
                block = ToolUse(
                    id=tool_id,
                    name="command_execution",
                    arguments={"command": item.get("command", "")},
                )
            elif item_type == "mcp_tool_call":
                server = item.get("server", "")
                tool_name = item.get("tool", "")
                arguments = item.get("arguments", {})
                block = ToolUse(
                    id=tool_id,
                    name=f"{server}.{tool_name}" if server else tool_name,
                    arguments=arguments if isinstance(arguments, dict) else {"input": arguments},
                )
            elif item_type == "file_change":
                block = ToolUse(
                    id=tool_id,
                    name="file_change",
                    arguments={"file": item.get("file", ""), "action": item.get("action", "")},
                )

            if block is not None:
                blocks.append(block)
                tool_blocks[tool_id] = block
                if display:
                    display.on_tool_call(block)

        elif event_type in ("item.completed", "item.updated"):
            item = event.get("item") or {}
            item_type = item.get("type")
            is_final = event_type == "item.completed"

            if item_type in ("command_execution", "mcp_tool_call", "file_change"):
                # Updates for items never seen in ``item.started`` are ignored.
                tool = tool_blocks.get(item.get("id", ""))
                if tool is not None:
                    self._apply_tool_update(item_type, item, tool)
                    if display and is_final:
                        display.on_tool_result(tool)

            elif item_type == "reasoning" and is_final:
                text = item.get("text", "")
                if text:
                    thinking = ThinkingBlock(text=text)
                    blocks.append(thinking)
                    if display:
                        display.on_thinking(thinking)

            elif item_type == "agent_message" and is_final:
                text = item.get("text", "")
                if text:
                    text_block = TextBlock(text=text)
                    blocks.append(text_block)
                    if display:
                        display.on_text(text_block)

        elif event_type == "turn.completed":
            return self._parse_usage(event)

        elif event_type in ("turn.failed", "error"):
            raw = event.get("message", event.get("error", "Unknown error"))
            if isinstance(raw, dict):
                error_msg = str(raw.get("message", raw.get("error", str(raw))))
            else:
                error_msg = str(raw)
            # Usage-limit errors are recoverable — surface as text so the
            # auto-wait loop in Session.send() can detect and retry.
            if detect_codex_usage_limit(error_msg) is not None:
                limit_block = TextBlock(text=error_msg)
                blocks.append(limit_block)
                if display:
                    display.on_text(limit_block)
            else:
                raise RuntimeError(f"Codex turn failed: {error_msg}")

        return None

    @staticmethod
    def _apply_tool_update(item_type: str, item: dict[str, Any], tool: ToolUse) -> None:
        """Copy output and error state from an item event onto its tool block."""
        if item_type == "command_execution":
            output = item.get("output")
            if output is None:
                # NOTE(review): some CLI versions appear to report command
                # output as ``aggregated_output`` -- confirm against the
                # event schema of the installed codex version.
                output = item.get("aggregated_output", "")
            tool.output = output
            # ``exit_code`` may be absent or null while a command is running;
            # that is not a failure.
            tool.is_error = item.get("exit_code") not in (None, 0)

        elif item_type == "mcp_tool_call":
            result = item.get("result")
            error = item.get("error")
            if result:
                if isinstance(result, dict):
                    # Extract text from MCP content blocks
                    texts: list[str] = []
                    for c in result.get("content") or []:
                        if isinstance(c, dict) and c.get("type") == "text":
                            texts.append(c.get("text", ""))
                        elif isinstance(c, str):
                            texts.append(c)
                    tool.output = "\n".join(texts)
                else:
                    # Non-dict results are kept verbatim rather than crashing.
                    tool.output = str(result)
            if error:
                tool.is_error = True
                tool.output = error.get("message", str(error)) if isinstance(error, dict) else str(error)
            elif item.get("status") == "failed":
                tool.is_error = True

        elif item_type == "file_change":
            tool.output = item.get("patch", item.get("content", ""))

    # ------------------------------------------------------------------
    # Usage parsing
    # ------------------------------------------------------------------

    def _parse_usage(self, event: dict[str, Any]) -> UsageStats:
        """Convert the ``usage`` payload of a ``turn.completed`` event.

        Cached input tokens are subtracted from ``input_tokens`` and reported
        separately as ``cache_read_tokens``; cost is computed for the session
        model.  Missing or null fields count as zero.
        """
        u = event.get("usage") or {}
        raw_input = u.get("input_tokens") or 0
        cached = u.get("cached_input_tokens") or 0
        usage = UsageStats(
            # Clamp so an inconsistent payload can never yield a negative count.
            input_tokens=max(0, raw_input - cached),
            output_tokens=u.get("output_tokens") or 0,
            cache_read_tokens=cached,
            cache_write_tokens=0,
        )
        usage.cost_usd = compute_cost("codex", self._model or "", usage)
        return usage

    # ------------------------------------------------------------------
    # Trajectory / lifecycle
    # ------------------------------------------------------------------

    @property
    def session_id(self) -> str:
        """Identifier given at construction (or the generated UUID)."""
        return self._session_id

    @property
    def last_raw_output(self) -> str:
        """Raw stdout lines of the most recent completed run, newline-joined."""
        return self._last_raw_output

    @property
    def trajectory(self) -> Trajectory:
        """Snapshot of the session so far: configuration, turns and totals."""
        return Trajectory(
            agent="codex",
            model=self._model or "",
            session_id=self._session_id,
            created_at=self._created_at,
            system_prompt=self._system_prompt or "",
            reasoning=self._reasoning or "",
            mcp_servers=list(self._mcp_servers),
            turns=list(self._turns),
            usage=self._total_usage,
            duration_ms=self._total_duration_ms,
            metadata={},
        )

    def end(self) -> Trajectory:
        """Return the final trajectory; no process outlives :meth:`send`."""
        return self.trajectory
501
+
502
+
503
# Concrete Codex model name for each abstract model tier.
_MODEL_TIER_MAP: dict[ModelTier, str] = {
    ModelTier.STRONGEST: "gpt-5.3-codex",
    ModelTier.FAST: "gpt-5.3-codex-spark",
}


class CodexProvider(Provider):
    """Provider whose sessions run through the ``codex`` command-line tool."""

    @property
    def name(self) -> str:
        """Provider identifier: ``"codex"``."""
        return "codex"

    def resolve_model(self, tier: ModelTier) -> str:
        """Look up the Codex model name for *tier* in ``_MODEL_TIER_MAP``."""
        return _MODEL_TIER_MAP[tier]

    def resolve_tool_restrictions(self, tools: ToolGroup) -> dict[str, Any]:
        """Translate a tool-group selection into session keyword arguments.

        ``ToolGroup.ALL`` yields no restrictions.  Otherwise the selection is
        collapsed to a single ``sandbox`` value, picking the most permissive
        of EXEC, WRITER and READER that is present; a warning is logged for
        the PARALLEL / WEB / INTERACTION groups whose state differs from ALL,
        since the sandbox mapping cannot express them.

        Raises:
            ValueError: If *tools* is empty.
        """
        if tools == ToolGroup.ALL:
            return {}
        if not tools:
            raise ValueError("ToolGroup must not be empty — at least one group is required.")

        # Ordered from most to least permissive; the first match wins.
        sandbox_by_group = (
            (ToolGroup.EXEC, "danger-full-access"),
            (ToolGroup.WRITER, "workspace-write"),
            (ToolGroup.READER, "read-only"),
        )

        # Collect the groups that Codex has no way to tell apart.
        indistinguishable: list[str] = []
        for label in ("PARALLEL", "WEB", "INTERACTION"):
            member = ToolGroup[label]
            requested = bool(tools & member)
            available = bool(ToolGroup.ALL & member)
            if requested != available:
                indistinguishable.append(label)
        if indistinguishable:
            logger.warning(
                "Codex provider cannot enforce per-tool restrictions for %s; "
                "these distinctions are lost in sandbox-level mapping.",
                ", ".join(indistinguishable),
            )

        for member, sandbox in sandbox_by_group:
            if tools & member:
                return {"sandbox": sandbox}

        # Fallback: some groups set but none of READER/WRITER/EXEC
        logger.warning("Codex: no file/exec groups enabled; defaulting to read-only sandbox.")
        return {"sandbox": "read-only"}

    def _limit_probe_kwargs(self) -> dict[str, Any]:
        """Session keyword arguments for limit probing: a read-only sandbox."""
        return {"sandbox": "read-only"}

    def start_session(self, mcp_servers: list[MCPServer], **kwargs: Any) -> CodexSession:
        """Create a :class:`CodexSession` from *mcp_servers* and known options.

        Only ``model``, ``system_prompt``, ``session_id``, ``reasoning`` and
        ``sandbox`` are read from *kwargs*; any that are absent are passed as
        ``None`` and other keys are ignored.
        """
        option_names = ("model", "system_prompt", "session_id", "reasoning", "sandbox")
        options = {option: kwargs.get(option) for option in option_names}
        return CodexSession(mcp_servers=mcp_servers, **options)
caw/py.typed ADDED
File without changes