coding-agent-wrapper 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caw/__init__.py +88 -0
- caw/agent.py +578 -0
- caw/auth/README.md +118 -0
- caw/auth/__init__.py +23 -0
- caw/auth/cli.py +68 -0
- caw/auth/collector.py +324 -0
- caw/auth/linker.py +174 -0
- caw/auth/manifest.py +77 -0
- caw/auth/providers.py +433 -0
- caw/auth/status.py +241 -0
- caw/cli.py +50 -0
- caw/display.py +223 -0
- caw/faststats.py +298 -0
- caw/mcp.py +602 -0
- caw/models.py +385 -0
- caw/pricing.json +15 -0
- caw/pricing.py +33 -0
- caw/provider.py +135 -0
- caw/providers/__init__.py +0 -0
- caw/providers/claude_code.py +648 -0
- caw/providers/codex.py +564 -0
- caw/py.typed +0 -0
- caw/storage.py +184 -0
- caw/toolkit.py +198 -0
- caw/viewer/__init__.py +149 -0
- caw/viewer/static/index.html +847 -0
- coding_agent_wrapper-0.1.0.dist-info/METADATA +213 -0
- coding_agent_wrapper-0.1.0.dist-info/RECORD +31 -0
- coding_agent_wrapper-0.1.0.dist-info/WHEEL +4 -0
- coding_agent_wrapper-0.1.0.dist-info/entry_points.txt +2 -0
- coding_agent_wrapper-0.1.0.dist-info/licenses/LICENSE +202 -0
caw/models.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""Core data models for the coding agent wrapper."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import enum
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Union
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelTier(enum.Enum):
    """Provider-agnostic model tiers.

    A provider translates each tier into one of its own concrete model
    identifiers, so callers can pick a tier instead of naming a model::

        agent = Agent(model=ModelTier.STRONGEST)   # provider picks its best model
        agent = Agent(model=ModelTier.FAST)        # provider picks its fast model
        agent = Agent(model="claude-opus-4-6")     # an explicit model string still works
    """

    # The provider's best model.
    STRONGEST = "strongest"
    # The provider's fast model.
    FAST = "fast"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ToolGroup(enum.Flag):
    """Provider-agnostic groups of tool permissions.

    Members are bit flags. Union them with ``|`` and remove groups with
    ``-`` (which is implemented as ``self & ~other``)::

        ToolGroup.READER | ToolGroup.EXEC         # read + execute only
        ToolGroup.ALL - ToolGroup.WRITER          # everything except writes
        ToolGroup.ALL - ToolGroup.INTERACTION     # default for automated pipelines
    """

    READER = enum.auto()
    WRITER = enum.auto()
    EXEC = enum.auto()
    WEB = enum.auto()
    PARALLEL = enum.auto()
    INTERACTION = enum.auto()

    # Pre-combined convenience sets.
    ALL = READER | WRITER | EXEC | WEB | PARALLEL | INTERACTION
    NO_INTERACTION = READER | WRITER | EXEC | WEB | PARALLEL

    def __sub__(self, other):
        """Return the flags of ``self`` with every flag set in ``other`` cleared."""
        if isinstance(other, ToolGroup):
            return self & ~other
        return NotImplemented
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class MCPServer:
    """Connection settings for one MCP server.

    Two transports are described by the same record:

    * stdio: fill in ``command`` (plus optional ``args`` / ``env``);
    * HTTP: fill in ``url``; ``command``, ``args`` and ``env`` are ignored.
    """

    name: str
    # stdio transport settings
    command: str = ""
    args: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)
    # HTTP transport setting
    url: str = ""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class AgentSpec:
    """Declarative configuration for a subagent; every field has a default."""

    name: str = ""
    description: str = ""
    system_prompt: str = ""
    model: str = ""
    reasoning: str = ""  # NOTE(review): presumably a reasoning-effort setting; confirm
    tools: ToolGroup | None = None  # None means no explicit tool restriction was given
    tool_servers: list[Any] = field(default_factory=list)
    mcp_servers: list[MCPServer] = field(default_factory=list)
    subagents: list["AgentSpec"] = field(default_factory=list)  # nested subagent specs
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class MCPTool:
    """Description of a single tool exposed by an MCP server."""

    name: str
    description: str = ""
    server: str = ""  # presumably the owning server's name; confirm with callers
    input_schema: dict[str, Any] = field(default_factory=dict)  # NOTE(review): likely a JSON Schema object
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class TextBlock:
    """Plain text emitted by the agent."""

    text: str  # the emitted text, verbatim
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class ThinkingBlock:
    """Thinking / reasoning text emitted by the agent (as opposed to its reply)."""

    text: str  # the reasoning text, verbatim
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass
class ToolUse:
    """One tool invocation together with the result it produced."""

    id: str
    name: str
    arguments: dict[str, Any] = field(default_factory=dict)
    output: str = ""
    is_error: bool = False
    # Set when this tool call ran a subagent; holds that subagent's own record.
    subagent_trajectory: Trajectory | None = None
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Any item that can appear in a turn's output.
ContentBlock = Union[
    TextBlock,
    ThinkingBlock,
    ToolUse,
]
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclass
class UsageStats:
    """Token usage and cost statistics.

    Instances combine field-wise with ``+``. ``sum()`` over an iterable of
    ``UsageStats`` also works, because ``sum`` starts from ``0`` and
    ``0 + usage`` is handled by :meth:`__radd__`.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0
    cost_usd: float = 0.0

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed (input + output).

        Cache read/write counts are not included.
        """
        return self.input_tokens + self.output_tokens

    def __add__(self, other: UsageStats) -> UsageStats:
        """Return a new ``UsageStats`` whose every field is the sum of both operands.

        Returns ``NotImplemented`` for non-``UsageStats`` operands, so Python
        raises a proper ``TypeError`` instead of an ``AttributeError``.
        """
        if not isinstance(other, UsageStats):
            return NotImplemented
        return UsageStats(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
            cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
            cost_usd=self.cost_usd + other.cost_usd,
        )

    def __radd__(self, other: object) -> UsageStats:
        """Support ``sum(usages)``, whose implicit start value is the integer ``0``.

        Returns a fresh copy (never ``self``) so the caller cannot alias the
        input when it mutates the result.
        """
        if isinstance(other, int) and other == 0:
            return self + UsageStats()
        return NotImplemented

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict (inverse of :meth:`from_dict`)."""
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cache_read_tokens": self.cache_read_tokens,
            "cache_write_tokens": self.cache_write_tokens,
            "cost_usd": self.cost_usd,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> UsageStats:
        """Build from a dict produced by :meth:`to_dict`; missing keys default to zero."""
        return cls(
            input_tokens=d.get("input_tokens", 0),
            output_tokens=d.get("output_tokens", 0),
            cache_read_tokens=d.get("cache_read_tokens", 0),
            cache_write_tokens=d.get("cache_write_tokens", 0),
            cost_usd=d.get("cost_usd", 0.0),
        )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@dataclass
class Turn:
    """One exchange: the user's message and everything the agent produced in reply."""

    input: str
    output: list[ContentBlock] = field(default_factory=list)
    usage: UsageStats = field(default_factory=UsageStats)
    duration_ms: int = 0

    @property
    def result(self) -> str:
        """Text of the last non-empty ``TextBlock`` in ``output``, or ``""`` if none."""
        candidates = (
            block.text
            for block in reversed(self.output)
            if isinstance(block, TextBlock) and block.text
        )
        return next(candidates, "")

    @property
    def tool_calls(self) -> list[ToolUse]:
        """Every ``ToolUse`` block from this turn, in output order."""
        calls: list[ToolUse] = []
        for block in self.output:
            if isinstance(block, ToolUse):
                calls.append(block)
        return calls

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict (inverse of :meth:`from_dict`)."""
        serialized_blocks = [_block_to_dict(block) for block in self.output]
        return {
            "input": self.input,
            "output": serialized_blocks,
            "usage": self.usage.to_dict(),
            "duration_ms": self.duration_ms,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Turn:
        """Rebuild a turn from :meth:`to_dict` output; missing keys fall back to defaults."""
        user_input = d.get("input", "")
        blocks = [_block_from_dict(raw) for raw in d.get("output", [])]
        usage = UsageStats.from_dict(d.get("usage", {}))
        return cls(
            input=user_input,
            output=blocks,
            usage=usage,
            duration_ms=d.get("duration_ms", 0),
        )
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@dataclass
class Trajectory:
    """Full record of one agent session.

    ``usage`` counts only the tokens spent by this agent itself.
    ``total_usage`` additionally folds in every subagent trajectory attached
    (recursively) to this session's tool calls.
    """

    agent: str
    model: str = ""
    session_id: str = ""
    created_at: str = ""
    completed_at: str = ""
    usage_limited: bool = False
    system_prompt: str = ""
    reasoning: str = ""
    mcp_servers: list[MCPServer] = field(default_factory=list)
    turns: list[Turn] = field(default_factory=list)
    usage: UsageStats = field(default_factory=UsageStats)
    duration_ms: int = 0
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def num_turns(self) -> int:
        """Number of recorded turns."""
        return len(self.turns)

    @property
    def result(self) -> str:
        """Final text result, taken from the last turn (``""`` when there are no turns)."""
        return self.turns[-1].result if self.turns else ""

    @property
    def total_tool_calls(self) -> int:
        """Number of tool invocations across all turns."""
        count = 0
        for turn in self.turns:
            count += len(turn.tool_calls)
        return count

    @property
    def total_usage(self) -> UsageStats:
        """Own usage plus the recursive total usage of every nested subagent trajectory."""
        combined = self.usage
        for child in self.subagent_trajectories:
            combined = combined + child.total_usage
        return combined

    @property
    def subagent_trajectories(self) -> list[Trajectory]:
        """Subagent trajectories attached to tool calls, in turn/block order."""
        return [
            block.subagent_trajectory
            for turn in self.turns
            for block in turn.output
            if isinstance(block, ToolUse) and block.subagent_trajectory
        ]

    @property
    def is_usage_limited(self) -> bool:
        """Whether the session ended because a usage limit was hit.

        Mirrors the ``usage_limited`` field, which is set by ``Session.end()``
        using the provider's ``detect_usage_limit``.
        """
        return self.usage_limited

    @property
    def is_complete(self) -> bool:
        """Whether the session finished normally.

        True only when ``completed_at`` is non-empty (``Session.end()`` sets it)
        and the session was not usage-limited. Mid-session snapshots written
        by ``append_turn`` leave ``completed_at`` empty and so count as
        incomplete.
        """
        if not self.completed_at:
            return False
        return not self.is_usage_limited

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict.

        Includes the derived ``total_usage`` for convenience. :meth:`from_dict`
        ignores it and recomputes it from the nested trajectories.
        """
        servers = [
            {
                "name": server.name,
                "command": server.command,
                "args": server.args,
                "env": server.env,
                "url": server.url,
            }
            for server in self.mcp_servers
        ]
        return {
            "agent": self.agent,
            "model": self.model,
            "session_id": self.session_id,
            "created_at": self.created_at,
            "completed_at": self.completed_at,
            "usage_limited": self.usage_limited,
            "system_prompt": self.system_prompt,
            "reasoning": self.reasoning,
            "mcp_servers": servers,
            "turns": [turn.to_dict() for turn in self.turns],
            "usage": self.usage.to_dict(),
            "total_usage": self.total_usage.to_dict(),
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> Trajectory:
        """Rebuild a trajectory from :meth:`to_dict` output; missing keys use defaults."""
        servers = [
            MCPServer(
                name=raw.get("name", ""),
                command=raw.get("command", ""),
                args=raw.get("args", []),
                env=raw.get("env", {}),
                url=raw.get("url", ""),
            )
            for raw in d.get("mcp_servers", [])
        ]
        return cls(
            agent=d.get("agent", ""),
            model=d.get("model", ""),
            session_id=d.get("session_id", ""),
            created_at=d.get("created_at", ""),
            completed_at=d.get("completed_at", ""),
            usage_limited=d.get("usage_limited", False),
            system_prompt=d.get("system_prompt", ""),
            reasoning=d.get("reasoning", ""),
            mcp_servers=servers,
            turns=[Turn.from_dict(raw_turn) for raw_turn in d.get("turns", [])],
            usage=UsageStats.from_dict(d.get("usage", {})),
            duration_ms=d.get("duration_ms", 0),
            metadata=d.get("metadata", {}),
        )
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
# -- Interactive-session result and content-block serialization helpers ------
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
@dataclass
class InteractiveResult:
    """Result from an interactive agent session.

    Attributes:
        exit_code: Exit status of the interactive agent process.
        output: Raw terminal output captured from the session; it may contain
            ANSI escape sequences.
    """

    exit_code: int
    output: str = ""  # raw terminal output (may include ANSI escape sequences)

    @property
    def session_id(self) -> str | None:
        """Extract the session ID from Claude Code's exit output, if present.

        Searches :attr:`output` for a ``--resume <id>`` hint. ANSI escape
        sequences are stripped first, so terminal styling wrapped around the ID
        is not captured as part of it. When several hints appear, the last
        one is used. This assumes the resume hint is printed as the session
        exits, so earlier occurrences are likely echoed conversation text.

        Returns:
            The session ID, or ``None`` if no ``--resume`` hint is found.
        """
        import re

        # Remove CSI sequences (colours, cursor movement), OSC sequences
        # (terminated by BEL or ST) and other two-byte ESC sequences.
        plain = re.sub(
            r"\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07\x1b]*(?:\x07|\x1b\\)|[@-Z\\-_])",
            "",
            self.output,
        )
        matches = re.findall(r"--resume\s+(\S+)", plain)
        return matches[-1] if matches else None
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _block_to_dict(block: ContentBlock) -> dict[str, Any]:
    """Serialize one content block into a JSON-compatible dict tagged by ``"type"``.

    For tool uses, ``is_error`` is emitted only when true, and
    ``subagent_trajectory`` only when one is attached.
    """
    if isinstance(block, TextBlock):
        return {"type": "text", "text": block.text}
    if isinstance(block, ThinkingBlock):
        return {"type": "thinking", "text": block.text}

    # Anything that is neither text nor thinking is treated as a ToolUse.
    payload: dict[str, Any] = {
        "type": "tool_use",
        "id": block.id,
        "name": block.name,
        "arguments": block.arguments,
        "output": block.output,
    }
    if block.is_error:
        payload["is_error"] = True
    if block.subagent_trajectory:
        payload["subagent_trajectory"] = block.subagent_trajectory.to_dict()
    return payload
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def _block_from_dict(d: dict[str, Any]) -> ContentBlock:
    """Rebuild a content block from the dict form produced by ``_block_to_dict``.

    Dispatches on ``d["type"]``. ``"text"`` and ``"thinking"`` map to their
    block classes; any other value, including a missing type, is decoded as a
    ``ToolUse``.
    """
    kind = d.get("type", "")
    if kind == "text":
        return TextBlock(text=d.get("text", ""))
    if kind == "thinking":
        return ThinkingBlock(text=d.get("text", ""))

    # Only a truthy payload is decoded; an empty dict or None yields no subagent.
    raw_nested = d.get("subagent_trajectory")
    nested = Trajectory.from_dict(raw_nested) if raw_nested else None
    return ToolUse(
        id=d.get("id", ""),
        name=d.get("name", ""),
        arguments=d.get("arguments", {}),
        output=d.get("output", ""),
        is_error=d.get("is_error", False),
        subagent_trajectory=nested,
    )
|
caw/pricing.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
{
|
|
2
|
+
"_comment": "Pricing in USD per 1 million tokens, keyed by agent then model",
|
|
3
|
+
"codex": {
|
|
4
|
+
"gpt-5.2-codex": {
|
|
5
|
+
"input": 1.75,
|
|
6
|
+
"cached_input": 0.175,
|
|
7
|
+
"output": 14.0
|
|
8
|
+
},
|
|
9
|
+
"gpt-5.3-codex": {
|
|
10
|
+
"input": 1.75,
|
|
11
|
+
"cached_input": 0.175,
|
|
12
|
+
"output": 14.0
|
|
13
|
+
}
|
|
14
|
+
}
|
|
15
|
+
}
|
caw/pricing.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Token-based cost computation from pricing config."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from caw.models import UsageStats
|
|
10
|
+
|
|
11
|
+
_pricing_cache: dict[str, Any] | None = None  # populated lazily by _load_pricing()


def _load_pricing() -> dict[str, Any]:
    """Return the bundled pricing table, loading it on first use.

    Reads ``pricing.json`` from this module's directory and caches the parsed
    result in ``_pricing_cache``; later calls return the same dict object.
    If the file does not exist, an empty table is cached instead.
    """
    global _pricing_cache
    if _pricing_cache is None:
        path = Path(__file__).parent / "pricing.json"
        if path.exists():
            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            _pricing_cache = json.loads(path.read_text(encoding="utf-8"))
        else:
            _pricing_cache = {}
    return _pricing_cache
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def compute_cost(
    agent: str,
    model: str,
    usage: UsageStats,
    *,
    pricing_table: dict[str, Any] | None = None,
) -> float:
    """Compute the USD cost of *usage* from per-million-token prices.

    Prices are looked up as ``table[agent][model]``, where ``table`` is
    *pricing_table* or, by default, the bundled ``pricing.json``. Recognised
    price keys, all in USD per 1M tokens:

    * ``input``: applied to ``usage.input_tokens``
    * ``cached_input``: applied to ``usage.cache_read_tokens``
    * ``output``: applied to ``usage.output_tokens``
    * ``cache_write``: applied to ``usage.cache_write_tokens`` (optional)

    Unknown agents or models, and missing price keys, are priced at 0.

    Args:
        agent: Top-level key in the pricing table (e.g. ``"codex"``).
        model: Model key under *agent*.
        usage: Token counts to price.
        pricing_table: Optional table that replaces the bundled pricing.

    Returns:
        The cost in USD.
    """
    table = _load_pricing() if pricing_table is None else pricing_table
    prices = table.get(agent, {}).get(model, {})
    # NOTE(review): input and cached-read tokens are priced independently. If a
    # provider reports input_tokens *including* cached tokens, the cached portion
    # is charged twice; confirm against the provider's usage accounting.
    weighted = (
        usage.input_tokens * prices.get("input", 0.0)
        + usage.cache_read_tokens * prices.get("cached_input", 0.0)
        + usage.output_tokens * prices.get("output", 0.0)
        + usage.cache_write_tokens * prices.get("cache_write", 0.0)
    )
    return weighted / 1_000_000
|
caw/provider.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Abstract base classes for provider implementations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from caw.models import InteractiveResult, MCPServer, ModelTier, ToolGroup, Trajectory, Turn
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ProviderSession(ABC):
    """Interface for a live, provider-specific agent session.

    Subclasses must implement :meth:`send`, :meth:`end` and the
    :attr:`trajectory` property. The remaining hooks have inert defaults
    that concrete providers may override.
    """

    @abstractmethod
    def send(self, message: str) -> Turn:
        """Deliver *message* to the agent and return the resulting turn."""

    @abstractmethod
    def end(self) -> Trajectory:
        """Close the session and return its complete trajectory."""

    @property
    @abstractmethod
    def trajectory(self) -> Trajectory:
        """Trajectory accumulated up to this point."""

    def detect_usage_limit(self, turn: Turn) -> int | None:
        """Report whether *turn* shows that the provider's usage limit was reached.

        Returns:
            Minutes to wait before retrying, or ``None`` when no limit is
            detected. This base implementation never detects a limit;
            providers override it with their own detection logic.
        """
        return None

    @property
    def session_id(self) -> str | None:
        """Session identifier assigned by the provider, or ``None`` by default."""
        return None

    @property
    def last_raw_output(self) -> str | None:
        """Raw CLI stdout of the most recent ``send()`` call, or ``None`` by default."""
        return None

    def set_step_callback(self, callback: Callable[[list], None] | None) -> None:
        """Register a callback to run after each step inside ``send()``.

        The base implementation ignores it; concrete providers override this.
        """
        return None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Provider(ABC):
    """Interface implemented by every coding-agent backend.

    Subclasses must provide :attr:`name` and :meth:`start_session`. Model-tier
    resolution and interactive mode raise ``NotImplementedError`` unless
    overridden; the other hooks have permissive defaults.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short provider identifier, such as ``'claude_code'`` or ``'codex'``."""

    def resolve_model(self, tier: ModelTier) -> str:
        """Map an abstract :class:`ModelTier` to this provider's concrete model string.

        Providers must override this to support tiers such as
        ``ModelTier.STRONGEST``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        message = (
            f"{self.name} provider does not implement resolve_model(); "
            f"pass an explicit model string instead of ModelTier.{tier.name}"
        )
        raise NotImplementedError(message)

    def resolve_tool_restrictions(self, tools: ToolGroup) -> dict[str, Any]:
        """Convert a ToolGroup into extra provider-specific session kwargs.

        *tools* is always a concrete ToolGroup, never ``None``; the Agent layer
        applies the default before calling. The base implementation adds no
        kwargs.
        """
        return {}

    def check_limit(self, model: str | None = None) -> int | None:
        """Probe whether the provider's usage limit is currently in effect.

        Opens a throwaway session, sends a tiny prompt, and asks the session
        whether the reply indicates a usage limit. The probe costs a small
        number of tokens.

        Returns:
            Estimated minutes until the limit resets, or ``None`` when no limit
            is detected.
        """
        from caw.display import get_global_display, set_global_display

        saved_display = get_global_display()
        # Clear the global display for the probe; it is restored in the outer
        # ``finally`` even if the probe fails.
        set_global_display(None)
        try:
            probe = self.start_session(
                mcp_servers=[],
                model=model,
                system_prompt="Reply with the single word 'ok'.",
                **self._limit_probe_kwargs(),
            )
            try:
                return probe.detect_usage_limit(probe.send("hi"))
            finally:
                probe.end()
        finally:
            set_global_display(saved_display)

    def _limit_probe_kwargs(self) -> dict[str, Any]:
        """Extra session kwargs used only by :meth:`check_limit`'s probe.

        Subclasses override this to disable tools and minimise side effects.
        The base implementation adds nothing.
        """
        return {}

    def start_interactive(
        self, initial_prompt: str, mcp_servers: list[MCPServer], capture_bytes: int = 0, **kwargs: Any
    ) -> InteractiveResult:
        """Run the provider binary interactively, seeded with *initial_prompt*.

        The user's terminal takes over: stdin/stdout/stderr are inherited so
        the user talks to the agent directly, while a copy of stdout is
        captured through a pty.

        Returns:
            An :class:`InteractiveResult` holding the exit code and captured output.

        Raises:
            NotImplementedError: In this base implementation.
        """
        raise NotImplementedError(f"{self.name} provider does not support interactive mode.")

    @abstractmethod
    def start_session(self, mcp_servers: list[MCPServer], **kwargs: Any) -> ProviderSession:
        """Open a new provider session and return it."""
|
|
File without changes
|