coding-agent-wrapper 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caw/models.py ADDED
@@ -0,0 +1,385 @@
1
+ """Core data models for the coding agent wrapper."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Union
8
+
9
+
10
class ModelTier(enum.Enum):
    """Provider-agnostic model selection tiers.

    A tier names a capability level rather than a concrete model; each
    provider maps it to one of its own model identifiers::

        agent = Agent(model=ModelTier.STRONGEST)  # provider picks its best model
        agent = Agent(model=ModelTier.FAST)       # provider picks its fast model
        agent = Agent(model="claude-opus-4-6")    # explicit model string still works
    """

    # The values are the lowercase tier names.
    STRONGEST = "strongest"
    FAST = "fast"
22
+
23
+
24
class ToolGroup(enum.Flag):
    """Abstract tool permission groups.

    Members are bit flags, so permission sets are built with ``|`` (union)
    and ``-`` (removal)::

        ToolGroup.READER | ToolGroup.EXEC      # read + execute only
        ToolGroup.ALL - ToolGroup.WRITER       # everything except writes
        ToolGroup.ALL - ToolGroup.INTERACTION  # default for automated pipelines
    """

    # Individual capabilities, one bit each.
    READER = enum.auto()
    WRITER = enum.auto()
    EXEC = enum.auto()
    WEB = enum.auto()
    PARALLEL = enum.auto()
    INTERACTION = enum.auto()

    # Convenience combinations.
    ALL = READER | WRITER | EXEC | WEB | PARALLEL | INTERACTION
    NO_INTERACTION = READER | WRITER | EXEC | WEB | PARALLEL

    def __sub__(self, other):
        """Return the groups in ``self`` that are not set in ``other``."""
        if isinstance(other, ToolGroup):
            return self & ~other
        # Let Python try the reflected operation / raise TypeError itself.
        return NotImplemented
48
+
49
+
50
@dataclass
class MCPServer:
    """Connection settings for a single MCP server.

    Two transports are described by the same record:

    * stdio -- fill in ``command``, ``args`` and ``env``;
    * HTTP  -- fill in ``url`` (``command``/``args``/``env`` are then ignored).
    """

    name: str
    # stdio transport
    command: str = ""
    args: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)
    # HTTP transport
    url: str = ""
63
+
64
+
65
@dataclass
class AgentSpec:
    """Declarative description of a subagent.

    Every field has a default, so a spec can be created empty and filled in
    incrementally. Container fields get a fresh object per instance.
    """

    # Identity and prompt
    name: str = ""
    description: str = ""
    system_prompt: str = ""
    # Model selection
    model: str = ""
    reasoning: str = ""
    # Tooling; ``None`` means no explicit tool restriction was given here
    tools: ToolGroup | None = None
    tool_servers: list[Any] = field(default_factory=list)
    mcp_servers: list[MCPServer] = field(default_factory=list)
    # Nested subagents and free-form extras
    subagents: list["AgentSpec"] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
79
+
80
+
81
@dataclass
class MCPTool:
    """Description of one tool exposed by an MCP server."""

    name: str
    description: str = ""
    # Name of the server offering the tool
    server: str = ""
    # Schema for the tool's input; defaults to a fresh empty dict per instance
    input_schema: dict[str, Any] = field(default_factory=dict)
89
+
90
+
91
@dataclass
class TextBlock:
    """Plain text emitted by the agent."""

    text: str
96
+
97
+
98
@dataclass
class ThinkingBlock:
    """Thinking/reasoning text emitted by the agent."""

    text: str
103
+
104
+
105
@dataclass
class ToolUse:
    """One tool call together with the result it produced."""

    # The call
    id: str
    name: str
    arguments: dict[str, Any] = field(default_factory=dict)
    # The result
    output: str = ""
    is_error: bool = False
    # Set when the call has a nested subagent trajectory attached
    subagent_trajectory: Trajectory | None = None
115
+
116
+
117
+ ContentBlock = Union[TextBlock, ThinkingBlock, ToolUse]
118
+
119
+
120
@dataclass
class UsageStats:
    """Token usage and cost statistics.

    Instances combine field-by-field with ``+``. ``sum()`` over an iterable
    of ``UsageStats`` also works, because adding the integer ``0`` (the
    implicit start value of ``sum``) is treated as a no-op.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0
    cost_usd: float = 0.0

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed (input + output; cache counters are not added)."""
        return self.input_tokens + self.output_tokens

    def __add__(self, other: UsageStats) -> UsageStats:
        """Return a new ``UsageStats`` with every field summed.

        Returns ``NotImplemented`` for operands that are not ``UsageStats``
        so Python raises a regular ``TypeError`` (or tries the reflected
        operation) instead of failing on a missing attribute.
        """
        if not isinstance(other, UsageStats):
            return NotImplemented
        return UsageStats(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
            cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
            cost_usd=self.cost_usd + other.cost_usd,
        )

    def __radd__(self, other: Any) -> UsageStats:
        """Support ``sum(stats)``, which starts by evaluating ``0 + stats[0]``."""
        if isinstance(other, int) and other == 0:
            # Add an empty record so the caller gets a copy, never ``self``.
            return self + UsageStats()
        return NotImplemented

    def to_dict(self) -> dict[str, Any]:
        """Serialize all five counters to a plain dict."""
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cache_read_tokens": self.cache_read_tokens,
            "cache_write_tokens": self.cache_write_tokens,
            "cost_usd": self.cost_usd,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> UsageStats:
        """Build a ``UsageStats`` from a dict produced by :meth:`to_dict`.

        Missing keys and ``None`` values (e.g. JSON ``null``) both fall back
        to zero, so later arithmetic on the counters cannot hit ``None``.
        """
        return cls(
            input_tokens=d.get("input_tokens") or 0,
            output_tokens=d.get("output_tokens") or 0,
            cache_read_tokens=d.get("cache_read_tokens") or 0,
            cache_write_tokens=d.get("cache_write_tokens") or 0,
            cost_usd=d.get("cost_usd") or 0.0,
        )
162
+
163
+
164
@dataclass
class Turn:
    """One exchange: the user's message and everything the agent produced."""

    input: str
    output: list[ContentBlock] = field(default_factory=list)
    usage: UsageStats = field(default_factory=UsageStats)
    duration_ms: int = 0

    @property
    def result(self) -> str:
        """Text of the last non-empty text block, or ``""`` if there is none."""
        non_empty_texts = (
            item.text
            for item in reversed(self.output)
            if isinstance(item, TextBlock) and item.text
        )
        return next(non_empty_texts, "")

    @property
    def tool_calls(self) -> list[ToolUse]:
        """Every ``ToolUse`` block of this turn, in output order."""
        calls: list[ToolUse] = []
        for item in self.output:
            if isinstance(item, ToolUse):
                calls.append(item)
        return calls

    def to_dict(self) -> dict[str, Any]:
        """Serialize the turn, including its content blocks and usage."""
        serialized_blocks = [_block_to_dict(item) for item in self.output]
        return {
            "input": self.input,
            "output": serialized_blocks,
            "usage": self.usage.to_dict(),
            "duration_ms": self.duration_ms,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Turn:
        """Rebuild a turn from :meth:`to_dict` output; missing keys use defaults."""
        blocks = [_block_from_dict(raw) for raw in d.get("output", [])]
        stats = UsageStats.from_dict(d.get("usage", {}))
        return cls(
            input=d.get("input", ""),
            output=blocks,
            usage=stats,
            duration_ms=d.get("duration_ms", 0),
        )
202
+
203
+
204
@dataclass
class Trajectory:
    """Complete record of a session.

    ``usage`` holds only this agent's own token usage; ``total_usage`` adds
    the usage of every nested subagent trajectory on top of it.
    """

    agent: str
    model: str = ""
    session_id: str = ""
    created_at: str = ""
    completed_at: str = ""
    usage_limited: bool = False
    system_prompt: str = ""
    reasoning: str = ""
    mcp_servers: list[MCPServer] = field(default_factory=list)
    turns: list[Turn] = field(default_factory=list)
    usage: UsageStats = field(default_factory=UsageStats)
    duration_ms: int = 0
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def num_turns(self) -> int:
        """Number of recorded turns."""
        return len(self.turns)

    @property
    def result(self) -> str:
        """Result of the last turn, or ``""`` when there are no turns."""
        return self.turns[-1].result if self.turns else ""

    @property
    def total_tool_calls(self) -> int:
        """Tool calls summed over all turns (nested trajectories not included)."""
        count = 0
        for turn in self.turns:
            count += len(turn.tool_calls)
        return count

    @property
    def total_usage(self) -> UsageStats:
        """Accumulated usage: own + all nested subagent trajectories (recursive)."""
        accumulated = self.usage
        for child in self.subagent_trajectories:
            # ``child.total_usage`` recurses into the child's own subagents.
            accumulated = accumulated + child.total_usage
        return accumulated

    @property
    def subagent_trajectories(self) -> list[Trajectory]:
        """Subagent trajectories attached to tool calls, across all turns."""
        return [
            block.subagent_trajectory
            for turn in self.turns
            for block in turn.output
            if isinstance(block, ToolUse) and block.subagent_trajectory
        ]

    @property
    def is_usage_limited(self) -> bool:
        """Whether the session ended due to a usage limit.

        Mirrors the ``usage_limited`` field, which is set by
        ``Session.end()`` using the provider's ``detect_usage_limit``.
        """
        return self.usage_limited

    @property
    def is_complete(self) -> bool:
        """Whether the session completed normally.

        True only when ``completed_at`` is non-empty and the session was not
        usage-limited. ``completed_at`` is set by ``Session.end()``;
        mid-session snapshots written by ``append_turn`` leave it empty and
        therefore do not count as complete.
        """
        if not self.completed_at:
            return False
        return not self.is_usage_limited

    def to_dict(self) -> dict[str, Any]:
        """Serialize the trajectory; also emits the derived ``total_usage``."""
        servers = []
        for server in self.mcp_servers:
            servers.append(
                {
                    "name": server.name,
                    "command": server.command,
                    "args": server.args,
                    "env": server.env,
                    "url": server.url,
                }
            )
        return {
            "agent": self.agent,
            "model": self.model,
            "session_id": self.session_id,
            "created_at": self.created_at,
            "completed_at": self.completed_at,
            "usage_limited": self.usage_limited,
            "system_prompt": self.system_prompt,
            "reasoning": self.reasoning,
            "mcp_servers": servers,
            "turns": [turn.to_dict() for turn in self.turns],
            "usage": self.usage.to_dict(),
            "total_usage": self.total_usage.to_dict(),
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Trajectory:
        """Rebuild a trajectory from :meth:`to_dict` output.

        Missing keys fall back to the field defaults. The ``total_usage``
        key is not read, since it is recomputed from the nested data.
        """
        servers = [
            MCPServer(
                name=raw.get("name", ""),
                command=raw.get("command", ""),
                args=raw.get("args", []),
                env=raw.get("env", {}),
                url=raw.get("url", ""),
            )
            for raw in d.get("mcp_servers", [])
        ]
        turns = [Turn.from_dict(raw) for raw in d.get("turns", [])]
        return cls(
            agent=d.get("agent", ""),
            model=d.get("model", ""),
            session_id=d.get("session_id", ""),
            created_at=d.get("created_at", ""),
            completed_at=d.get("completed_at", ""),
            usage_limited=d.get("usage_limited", False),
            system_prompt=d.get("system_prompt", ""),
            reasoning=d.get("reasoning", ""),
            mcp_servers=servers,
            turns=turns,
            usage=UsageStats.from_dict(d.get("usage", {})),
            duration_ms=d.get("duration_ms", 0),
            metadata=d.get("metadata", {}),
        )
327
+
328
+
329
+ # -- Serialization helpers for content blocks --------------------------------
330
+
331
+
332
@dataclass
class InteractiveResult:
    """Result from an interactive agent session."""

    exit_code: int
    output: str = ""  # raw terminal output (may include ANSI escape sequences)

    @property
    def session_id(self) -> str | None:
        """Extract the session ID from Claude Code's exit output, if present.

        Looks for ``--resume <id>`` in ``output`` and returns ``<id>``, or
        ``None`` when no such hint is found. ANSI CSI escape sequences
        (colors, cursor control) are removed before matching, so styling
        around the hint neither ends up in the returned ID nor hides it.
        """
        import re

        # CSI sequence: ESC [ parameter bytes, intermediate bytes, final byte.
        plain = re.sub(r"\x1b\[[0-?]*[ -/]*[@-~]", "", self.output)
        m = re.search(r"--resume\s+(\S+)", plain)
        return m.group(1) if m else None
346
+
347
+
348
def _block_to_dict(block: ContentBlock) -> dict[str, Any]:
    """Serialize a content block to a dict tagged with a ``"type"`` key.

    Text and thinking blocks carry only their text. Any other block is
    treated as a ``ToolUse``; its ``is_error`` and ``subagent_trajectory``
    keys are written only when they are set.
    """
    if isinstance(block, TextBlock):
        return {"type": "text", "text": block.text}
    if isinstance(block, ThinkingBlock):
        return {"type": "thinking", "text": block.text}

    # Fall-through case: ToolUse
    payload: dict[str, Any] = {
        "type": "tool_use",
        "id": block.id,
        "name": block.name,
        "arguments": block.arguments,
        "output": block.output,
    }
    if block.is_error:
        payload["is_error"] = True
    if block.subagent_trajectory:
        payload["subagent_trajectory"] = block.subagent_trajectory.to_dict()
    return payload
366
+
367
+
368
def _block_from_dict(d: dict[str, Any]) -> ContentBlock:
    """Rebuild a content block from its serialized dict form.

    ``"type"`` selects the block class: ``"text"`` and ``"thinking"`` give
    the matching text blocks, and every other value (including a missing
    type) is read as a tool use.
    """
    kind = d.get("type", "")
    if kind == "text":
        return TextBlock(text=d.get("text", ""))
    if kind == "thinking":
        return ThinkingBlock(text=d.get("text", ""))

    # Fall-through case: tool_use. A nested trajectory is optional.
    nested = (
        Trajectory.from_dict(d["subagent_trajectory"])
        if d.get("subagent_trajectory")
        else None
    )
    return ToolUse(
        id=d.get("id", ""),
        name=d.get("name", ""),
        arguments=d.get("arguments", {}),
        output=d.get("output", ""),
        is_error=d.get("is_error", False),
        subagent_trajectory=nested,
    )
caw/pricing.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "_comment": "Pricing in USD per 1 million tokens, keyed by agent then model",
3
+ "codex": {
4
+ "gpt-5.2-codex": {
5
+ "input": 1.75,
6
+ "cached_input": 0.175,
7
+ "output": 14.0
8
+ },
9
+ "gpt-5.3-codex": {
10
+ "input": 1.75,
11
+ "cached_input": 0.175,
12
+ "output": 14.0
13
+ }
14
+ }
15
+ }
caw/pricing.py ADDED
@@ -0,0 +1,33 @@
1
+ """Token-based cost computation from pricing config."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from caw.models import UsageStats
10
+
11
+ _pricing_cache: dict[str, Any] | None = None
12
+
13
+
14
def _load_pricing() -> dict[str, Any]:
    """Return the pricing table from ``pricing.json`` next to this module.

    The file is parsed at most once; the result is kept in the module-level
    ``_pricing_cache`` and returned on later calls. A missing file gives an
    empty table.
    """
    global _pricing_cache
    if _pricing_cache is None:
        path = Path(__file__).parent / "pricing.json"
        try:
            # Read as UTF-8 explicitly rather than with the platform's
            # locale default encoding.
            _pricing_cache = json.loads(path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            _pricing_cache = {}
    return _pricing_cache
23
+
24
+
25
def compute_cost(agent: str, model: str, usage: UsageStats) -> float:
    """Compute cost in USD from token counts and pricing config.

    Rates are looked up by ``agent`` and then ``model`` and are applied per
    one million tokens. An unknown agent or model, or a missing rate, counts
    as a rate of zero. Only input, cache-read and output tokens are priced.
    """
    rates = _load_pricing().get(agent, {}).get(model, {})
    input_cost = usage.input_tokens * rates.get("input", 0.0)
    cached_cost = usage.cache_read_tokens * rates.get("cached_input", 0.0)
    output_cost = usage.output_tokens * rates.get("output", 0.0)
    # Rates are quoted per million tokens.
    return (input_cost + cached_cost + output_cost) / 1_000_000
caw/provider.py ADDED
@@ -0,0 +1,135 @@
1
+ """Abstract base classes for provider implementations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from collections.abc import Callable
7
+ from typing import Any
8
+
9
+ from caw.models import InteractiveResult, MCPServer, ModelTier, ToolGroup, Trajectory, Turn
10
+
11
+
12
class ProviderSession(ABC):
    """Interface a provider implements to drive one live session.

    ``send``, ``end`` and ``trajectory`` are abstract. The remaining members
    have neutral defaults that a concrete provider may override.
    """

    @abstractmethod
    def send(self, message: str) -> Turn:
        """Send *message* to the agent and return the resulting turn."""
        ...

    @abstractmethod
    def end(self) -> Trajectory:
        """Finalize the session and return the complete trajectory."""
        ...

    @property
    @abstractmethod
    def trajectory(self) -> Trajectory:
        """The trajectory accumulated so far."""
        ...

    def detect_usage_limit(self, turn: Turn) -> int | None:
        """Report whether *turn* shows that the provider's usage limit was hit.

        Returns the number of minutes to wait before retrying, or ``None``
        when no limit was detected. This base version always returns
        ``None``; providers override it with their own detection logic.
        """
        return None

    @property
    def session_id(self) -> str | None:
        """Provider-assigned session ID, or ``None`` (the default) if none."""
        return None

    @property
    def last_raw_output(self) -> str | None:
        """Raw CLI stdout of the latest ``send()`` call, or ``None`` (the default)."""
        return None

    def set_step_callback(self, callback: Callable[[list], None] | None) -> None:
        """Set the callback invoked after each step within ``send()``.

        The base implementation does nothing; concrete providers override it.
        """
53
+
54
+
55
class Provider(ABC):
    """Interface implemented by each coding agent backend.

    ``name`` and ``start_session`` are abstract. The other methods provide
    defaults that a backend can override as needed.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider identifier (e.g. 'claude_code', 'codex')."""
        ...

    def resolve_model(self, tier: ModelTier) -> str:
        """Translate a :class:`ModelTier` into a concrete model identifier.

        Providers override this to map abstract tiers (e.g.
        ``ModelTier.STRONGEST``) to a model string they support.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        message = (
            f"{self.name} provider does not implement resolve_model(); "
            f"pass an explicit model string instead of ModelTier.{tier.name}"
        )
        raise NotImplementedError(message)

    def resolve_tool_restrictions(self, tools: ToolGroup) -> dict[str, Any]:
        """Translate a ToolGroup into provider-specific session kwargs.

        The Agent layer applies the default before calling this, so *tools*
        is a concrete ToolGroup value and never None. The base
        implementation adds no kwargs.
        """
        return {}

    def check_limit(self, model: str | None = None) -> int | None:
        """Probe whether the provider's usage limit is currently active.

        Starts a throwaway session, sends a minimal prompt and asks the
        session whether the reply indicates a usage limit. Returns the
        estimated number of minutes until the limit resets, or ``None`` if
        no limit is detected. The probe request incurs a small token cost.

        The global display is set to ``None`` for the duration of the probe
        and restored afterwards, even if the probe raises.
        """
        from caw.display import get_global_display, set_global_display

        previous_display = get_global_display()
        set_global_display(None)
        try:
            probe = self.start_session(
                mcp_servers=[],
                model=model,
                system_prompt="Reply with the single word 'ok'.",
                **self._limit_probe_kwargs(),
            )
            try:
                reply = probe.send("hi")
                return probe.detect_usage_limit(reply)
            finally:
                # Always close the probe session once it has been started.
                probe.end()
        finally:
            set_global_display(previous_display)

    def _limit_probe_kwargs(self) -> dict[str, Any]:
        """Extra session kwargs for the limit-check probe.

        Subclasses override this to disable tools and minimise side-effects;
        the base implementation adds nothing.
        """
        return {}

    def start_interactive(
        self, initial_prompt: str, mcp_servers: list[MCPServer], capture_bytes: int = 0, **kwargs: Any
    ) -> InteractiveResult:
        """Launch the provider binary interactively with an initial prompt.

        Intended contract for overriding providers: control is handed to the
        user's terminal (stdin/stdout/stderr are inherited so the user talks
        to the agent directly), a copy of stdout is captured via a pty, and
        an :class:`InteractiveResult` with the exit code and captured output
        is returned.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(f"{self.name} provider does not support interactive mode.")

    @abstractmethod
    def start_session(self, mcp_servers: list[MCPServer], **kwargs: Any) -> ProviderSession:
        """Create and return a new provider session."""
        ...
File without changes