coding-agent-wrapper 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caw/__init__.py +88 -0
- caw/agent.py +578 -0
- caw/auth/README.md +118 -0
- caw/auth/__init__.py +23 -0
- caw/auth/cli.py +68 -0
- caw/auth/collector.py +324 -0
- caw/auth/linker.py +174 -0
- caw/auth/manifest.py +77 -0
- caw/auth/providers.py +433 -0
- caw/auth/status.py +241 -0
- caw/cli.py +50 -0
- caw/display.py +223 -0
- caw/faststats.py +298 -0
- caw/mcp.py +602 -0
- caw/models.py +385 -0
- caw/pricing.json +15 -0
- caw/pricing.py +33 -0
- caw/provider.py +135 -0
- caw/providers/__init__.py +0 -0
- caw/providers/claude_code.py +648 -0
- caw/providers/codex.py +564 -0
- caw/py.typed +0 -0
- caw/storage.py +184 -0
- caw/toolkit.py +198 -0
- caw/viewer/__init__.py +149 -0
- caw/viewer/static/index.html +847 -0
- coding_agent_wrapper-0.1.0.dist-info/METADATA +213 -0
- coding_agent_wrapper-0.1.0.dist-info/RECORD +31 -0
- coding_agent_wrapper-0.1.0.dist-info/WHEEL +4 -0
- coding_agent_wrapper-0.1.0.dist-info/entry_points.txt +2 -0
- coding_agent_wrapper-0.1.0.dist-info/licenses/LICENSE +202 -0
caw/__init__.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""caw - Coding Agent Wrapper."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from caw.agent import Agent, Session, register_provider
|
|
6
|
+
from caw.display import Display, DisplayMode, get_global_display, set_global_display
|
|
7
|
+
from caw.faststats import FastStats
|
|
8
|
+
from caw.storage import JsonlWriter, SessionStore
|
|
9
|
+
from caw.models import (
|
|
10
|
+
AgentSpec,
|
|
11
|
+
ContentBlock,
|
|
12
|
+
InteractiveResult,
|
|
13
|
+
MCPServer,
|
|
14
|
+
MCPTool,
|
|
15
|
+
ModelTier,
|
|
16
|
+
TextBlock,
|
|
17
|
+
ThinkingBlock,
|
|
18
|
+
ToolGroup,
|
|
19
|
+
ToolUse,
|
|
20
|
+
Trajectory,
|
|
21
|
+
Turn,
|
|
22
|
+
UsageStats,
|
|
23
|
+
)
|
|
24
|
+
from caw.provider import Provider, ProviderSession
|
|
25
|
+
from caw.providers.claude_code import ClaudeCodeProvider
|
|
26
|
+
from caw.providers.codex import CodexProvider
|
|
27
|
+
from caw.mcp import (
|
|
28
|
+
MCPServerHandle,
|
|
29
|
+
create_mcp_http_server_bundle,
|
|
30
|
+
create_stateless_tool_server,
|
|
31
|
+
create_subagent_tool_server,
|
|
32
|
+
get_state_from_context,
|
|
33
|
+
mcp_tool,
|
|
34
|
+
register_tool,
|
|
35
|
+
)
|
|
36
|
+
from caw.toolkit import ToolKit, tool
|
|
37
|
+
from caw.viewer import ViewerServer, start_viewer_server
|
|
38
|
+
from caw.auth import setup as auth_setup, get_status as auth_get_status, get_docker_flags as auth_get_docker_flags
|
|
39
|
+
|
|
40
|
+
# Register the built-in providers. Claude Code is reachable under its full
# name and two shorter aliases; registration order is unchanged.
for _alias, _provider_cls in (
    ("claude_code", ClaudeCodeProvider),
    ("claude", ClaudeCodeProvider),
    ("cc", ClaudeCodeProvider),
    ("codex", CodexProvider),
):
    register_provider(_alias, _provider_cls)
del _alias, _provider_cls  # keep the loop variables out of the package namespace

# Names exported by ``from caw import *``.
__all__ = [
    # Agents, providers and storage
    "Agent",
    "AgentSpec",
    "ClaudeCodeProvider",
    "CodexProvider",
    "JsonlWriter",
    "ContentBlock",
    "InteractiveResult",
    # Display and statistics
    "Display",
    "DisplayMode",
    "FastStats",
    "get_global_display",
    "set_global_display",
    # Models and provider interfaces
    "MCPServer",
    "MCPServerHandle",
    "MCPTool",
    "ModelTier",
    "Provider",
    "ProviderSession",
    "Session",
    "SessionStore",
    "TextBlock",
    "ThinkingBlock",
    "ToolGroup",
    "ToolUse",
    "Trajectory",
    "Turn",
    "UsageStats",
    # MCP helpers and tool registration
    "create_mcp_http_server_bundle",
    "create_stateless_tool_server",
    "create_subagent_tool_server",
    "get_state_from_context",
    "mcp_tool",
    "register_provider",
    "register_tool",
    "start_viewer_server",
    "tool",
    "ToolKit",
    "ViewerServer",
    # Auth helpers (re-exported under an ``auth_`` prefix)
    "auth_setup",
    "auth_get_status",
    "auth_get_docker_flags",
]
|
caw/agent.py
ADDED
|
@@ -0,0 +1,578 @@
|
|
|
1
|
+
"""Concrete Agent and Session — the main user-facing API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import datetime
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import tempfile
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
import uuid as uuid_mod
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from caw.display import get_global_display
|
|
19
|
+
from caw.models import AgentSpec, InteractiveResult, MCPServer, ModelTier, ToolGroup, ToolUse, Trajectory, Turn
|
|
20
|
+
from caw.provider import Provider, ProviderSession
|
|
21
|
+
from caw.storage import SessionStore
|
|
22
|
+
from caw.toolkit import ToolKit
|
|
23
|
+
from caw.mcp import create_stateless_tool_server
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
# Environment variable overrides
|
|
29
|
+
#
|
|
30
|
+
# These let you configure caw globally without changing code. Each one is
|
|
31
|
+
# used as a fallback when the corresponding value is not set explicitly via
|
|
32
|
+
# the Agent() constructor or method calls.
|
|
33
|
+
#
|
|
34
|
+
# CAW_PROVIDER — Provider backend ("claude_code", "codex", …)
|
|
35
|
+
# CAW_MODEL — Model name passed to the provider (e.g. "gpt-5.2-codex")
|
|
36
|
+
# CAW_EFFORT — Reasoning effort level (e.g. "high", "medium", "low")
|
|
37
|
+
# CAW_AUTOWAIT — Auto-wait on usage limit ("0"/"false"/"no"/"off" = off, case-insensitive; any other value, or unset, leaves it on)
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
# Provider name used when neither the caller nor the environment picks one.
DEFAULT_PROVIDER: str = "claude_code"

# Names of the environment variables consulted as fallbacks.
CAW_PROVIDER: str = "CAW_PROVIDER"
CAW_MODEL: str = "CAW_MODEL"
CAW_EFFORT: str = "CAW_EFFORT"
CAW_AUTOWAIT: str = "CAW_AUTOWAIT"

# Registered provider classes keyed by the name they were registered under.
_PROVIDER_REGISTRY: dict[str, type[Provider]] = {}

# Message sent to the provider to resume work after an auto-wait sleep.
_AUTO_WAIT_RESUME_MESSAGE: str = "Usage limit reached earlier, now you may continue the work."

# Must match caw.mcp._TRAJ_MARKER_PREFIX / _SUFFIX
# Group 1 captures the trajectory id from a marker that ends a tool output.
_TRAJ_MARKER_RE: re.Pattern[str] = re.compile(r"\n<!-- caw_traj:([\w-]+) -->$")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def register_provider(name: str, cls: type[Provider]) -> None:
    """Make the provider class *cls* available for lookup under *name*.

    Registering a name that is already present replaces the earlier entry.
    """
    _PROVIDER_REGISTRY.update({name: cls})
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _resolve_provider(name: str | None) -> Provider:
    """Instantiate the provider selected by argument, environment, or default.

    The first non-empty value wins, in this order: the explicit *name*, the
    ``CAW_PROVIDER`` environment variable, then ``DEFAULT_PROVIDER``.

    Raises:
        ValueError: If the selected name has not been registered.
    """
    chosen = name
    if not chosen:
        chosen = os.environ.get(CAW_PROVIDER)
    if not chosen:
        chosen = DEFAULT_PROVIDER

    if chosen in _PROVIDER_REGISTRY:
        provider_cls = _PROVIDER_REGISTRY[chosen]
        return provider_cls()

    known = list(_PROVIDER_REGISTRY)
    raise ValueError(f"Unknown provider {chosen!r}. Available: {known}")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _attach_subagent_trajectories(turn: Turn, traj_dir: str | None) -> None:
    """Scan a turn's tool outputs for trajectory markers and attach them.

    For each ToolUse whose output ends with ``<!-- caw_traj:<uuid> -->``:
    1. Load the trajectory JSON from ``traj_dir/<uuid>.json``
    2. Attach it as ``tool_use.subagent_trajectory``
    3. Strip the marker from ``tool_use.output``

    Loading is best-effort: a missing or malformed trajectory file is logged
    at debug level and the marker is still removed from the output.

    Args:
        turn: The turn whose ``output`` blocks are inspected and mutated in place.
        traj_dir: Directory holding the trajectory files; nothing is done when
            it is ``None`` or empty.
    """
    if not traj_dir:
        return
    for block in turn.output:
        if not isinstance(block, ToolUse):
            continue
        m = _TRAJ_MARKER_RE.search(block.output)
        if not m:
            continue
        traj_id = m.group(1)
        traj_path = os.path.join(traj_dir, f"{traj_id}.json")
        try:
            with open(traj_path) as f:
                traj_dict = json.load(f)
            block.subagent_trajectory = Trajectory.from_dict(traj_dict)
        except (OSError, json.JSONDecodeError, KeyError):
            # Still best-effort, but leave a trace so a missing subagent
            # trajectory can be diagnosed instead of vanishing silently.
            logger.debug("Could not load subagent trajectory from %s", traj_path, exc_info=True)
        # Strip marker regardless
        block.output = block.output[: m.start()]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class Session:
    """A live interaction session with a coding agent.

    Wraps a provider session and layers on optional persistence, attachment of
    subagent trajectories, automatic waiting when a usage limit is hit, and
    shutdown of the tool-server handles handed to it.

    Sessions built by :meth:`load_trajectory` are read-only: they expose the
    loaded trajectory but cannot send messages or be ended.
    """

    def __init__(
        self,
        provider_session: ProviderSession,
        store: SessionStore | None = None,
        subagent_traj_dir: str | None = None,
        tool_handles: list[Any] | None = None,
        auto_wait: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Create a session around an already-started provider session.

        Args:
            provider_session: The provider-level session that does the work.
            store: Optional store that turns and the trajectory are written to.
            subagent_traj_dir: Directory scanned for subagent trajectory files.
            tool_handles: Tool-server handles stopped when the session ends.
            auto_wait: Whether :meth:`send` sleeps through usage limits.
            metadata: Extra metadata merged on top of the provider's
                trajectory metadata; the dict is copied.
        """
        self._session = provider_session
        self._store = store
        self._subagent_traj_dir = subagent_traj_dir
        self._tool_handles = tool_handles or []
        self._auto_wait = auto_wait
        self._metadata: dict[str, Any] = dict(metadata) if metadata else {}
        self._readonly = False
        self._send_lock = threading.Lock()
        # Created lazily in send_async().
        self._async_send_lock: asyncio.Lock | None = None
        self._traj_path: str | Path | None = None
        # Only populated for read-only sessions built by _from_trajectory().
        self._loaded_trajectory: Trajectory | None = None

    async def send_async(self, message: str) -> Turn:
        """Async version of :meth:`send` — runs in a thread.

        Messages are processed in FIFO order: if multiple ``send_async``
        calls overlap, each waits for the previous one to finish before
        starting. This lets you fire-and-forget multiple messages::

            tasks = [asyncio.create_task(session.send_async(m)) for m in msgs]
            turns = await asyncio.gather(*tasks)  # executed in order

        You can also do async work while a send is in progress::

            task = asyncio.create_task(session.send_async(prompt))
            while not task.done():
                source = await asyncio.wait_for(queue.get(), timeout=0.5)
                yield source
            turn = await task
        """
        if self._async_send_lock is None:
            self._async_send_lock = asyncio.Lock()
        async with self._async_send_lock:
            return await asyncio.to_thread(self.send, message)

    @staticmethod
    def _write_trajectory(traj: Trajectory, path: str | Path) -> None:
        """Write *traj* as indented JSON to *path*, creating parent directories."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w") as f:
            json.dump(traj.to_dict(), f, indent=2)

    def _persist_step(self, message: str, blocks: Any) -> None:
        """Save an in-progress turn so on-disk trajectories update in real time."""
        traj = self.trajectory
        # NOTE(review): this appends to whatever the provider's ``trajectory``
        # returns. If that is a live object rather than a fresh snapshot, the
        # partial turns accumulate in it — confirm against the provider.
        traj.turns.append(Turn(input=message, output=list(blocks)))
        if self._traj_path:
            self._write_trajectory(traj, self._traj_path)
        if self._store:
            self._store._save_trajectory(traj)

    def _stop_tool_handles(self) -> None:
        """Stop every tool-server handle; a failing handle never blocks the rest."""
        for handle in self._tool_handles:
            try:
                handle.stop_sync()
            except Exception:
                logger.debug("Failed to stop tool server handle %r", handle, exc_info=True)

    def send(self, message: str) -> Turn:
        """Send a message and get the agent's response turn.

        When auto-wait is enabled and the provider reports a usage limit,
        this method sleeps until the limit resets and then automatically
        resumes the conversation — transparently to the caller.

        Raises:
            RuntimeError: If the session was loaded from a file (read-only).
        """
        if self._readonly:
            raise RuntimeError("Cannot send messages on a loaded session")
        with self._send_lock:
            current_message = message

            while True:
                # Set up per-step callback so traj_path is updated in real time
                def _save_step(blocks, _msg=current_message):
                    self._persist_step(_msg, blocks)

                self._session.set_step_callback(_save_step)
                try:
                    turn = self._session.send(current_message)
                finally:
                    # Always detach the callback, even when the provider
                    # raises, so a stale closure is not left installed.
                    self._session.set_step_callback(None)

                # Attach subagent trajectories from marker files
                _attach_subagent_trajectories(turn, self._subagent_traj_dir)

                if self._store is not None:
                    self._store.append_turn(turn, self.trajectory, raw_output=self._session.last_raw_output)

                # Ask the provider whether this turn hit a usage limit
                if self._auto_wait:
                    wait_minutes = self._session.detect_usage_limit(turn)
                    if wait_minutes is not None:
                        logger.warning(
                            "Usage limit reached. Auto-waiting %s min before resuming.",
                            wait_minutes,
                        )
                        display = get_global_display()
                        if display:
                            display.on_metadata(
                                auto_wait=f"sleeping {wait_minutes}min until limit resets",
                            )
                        # The send lock stays held while sleeping, so other
                        # senders queue up behind the resumed conversation.
                        time.sleep(wait_minutes * 60)
                        current_message = _AUTO_WAIT_RESUME_MESSAGE
                        continue

                return turn

    def end(self) -> Trajectory:
        """End the session and return the complete trajectory.

        Ends the provider session, stamps the completion time, flags the
        trajectory when its last turn hit a usage limit, finalizes the store,
        stops all tool-server handles and, if a trajectory path is configured,
        writes the trajectory there (a failed write is logged, not raised).

        Raises:
            RuntimeError: If the session was loaded from a file (read-only).
        """
        if self._readonly:
            raise RuntimeError("Cannot end a loaded session")
        try:
            self._session.end()
            traj = self.trajectory
            traj.completed_at = datetime.datetime.now(datetime.timezone.utc).isoformat()
            if traj.turns and self._session.detect_usage_limit(traj.turns[-1]) is not None:
                traj.usage_limited = True
            if self._store is not None:
                self._store.finalize(traj)
        finally:
            # Stop tool servers even if ending or finalizing failed, so they
            # are not left running.
            self._stop_tool_handles()
        # Auto-save trajectory if configured
        if self._traj_path is not None:
            try:
                self._write_trajectory(traj, self._traj_path)
            except Exception:
                logger.warning("Failed to save trajectory to %s", self._traj_path, exc_info=True)
        return traj

    @property
    def trajectory(self) -> Trajectory:
        """Accumulated trajectory (available during and after the session)."""
        if self._readonly:
            return self._loaded_trajectory
        traj = self._session.trajectory
        if self._metadata:
            # Session metadata merges on top of provider metadata
            traj.metadata = {**traj.metadata, **self._metadata}
        return traj

    def save_trajectory(self, path: str | Path) -> None:
        """Save the trajectory to a JSON file at the given path."""
        self._write_trajectory(self.trajectory, path)

    @classmethod
    def load_trajectory(cls, path: str | Path) -> Session:
        """Load a trajectory from a JSON file. The returned session is read-only."""
        with open(path) as f:
            data = json.load(f)
        traj = Trajectory.from_dict(data)
        return cls._from_trajectory(traj)

    @classmethod
    def _from_trajectory(cls, traj: Trajectory) -> Session:
        """Build a read-only session around *traj* without a provider session."""
        session = object.__new__(cls)
        session._session = None
        session._store = None
        session._subagent_traj_dir = None
        session._tool_handles = []
        session._auto_wait = False
        session._metadata = {}
        session._readonly = True
        session._send_lock = threading.Lock()
        session._async_send_lock = None
        # Mirror __init__ so every attribute exists on loaded sessions too.
        session._traj_path = None
        session._loaded_trajectory = traj
        return session

    @property
    def session_dir(self) -> Path | None:
        """Path to the session's data directory, or None if persistence is disabled."""
        if self._store is not None:
            return self._store.session_dir
        return None

    def __enter__(self) -> Session:
        return self

    def __exit__(self, *args: Any) -> None:
        # A read-only session has nothing to end, and end() would raise.
        if not self._readonly:
            self.end()
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class Agent:
    """Coding agent wrapper — unified interface across providers.

    Provider resolution order:
    1. Explicit ``provider`` argument
    2. ``CAW_PROVIDER`` environment variable
    3. Default provider
    """

    def __init__(
        self,
        provider: str | None = None,
        data_dir: str | None = None,
        system_prompt: str | None = None,
        model: str | ModelTier | None = None,
        reasoning: str | None = None,
        tools: ToolGroup | None = None,
        tool_servers: list[Any] | None = None,
        stateless_tools: list[Any] | None = None,
        name: str = "",
        description: str = "",
        **kwargs: Any,
    ) -> None:
        """Configure an agent; no provider or server is started here.

        Args:
            provider: Provider name; resolved lazily on first use.
            data_dir: If set, each session gets a ``SessionStore`` under it.
            system_prompt: System prompt forwarded to the provider.
            model: Model name or tier; falls back to ``CAW_MODEL``.
            reasoning: Reasoning effort; falls back to ``CAW_EFFORT``.
            tools: Tool permission groups for sessions.
            tool_servers: Handles (or ToolKits) registered via
                :meth:`add_tool_server`.
            stateless_tools: Tools wrapped into one stateless tool server.
            name: Agent name reported by :meth:`get_spec`.
            description: Agent description reported by :meth:`get_spec`.
            **kwargs: Extra options kept and forwarded to the provider when a
                session starts; ``auto_wait`` and ``metadata`` are consumed by
                the session instead.
        """
        self._provider_name = provider
        self._provider: Provider | None = None
        self._mcp_servers: list[MCPServer] = []
        self._subagents: list[AgentSpec] = []
        self._tool_servers: list[Any] = []  # list[MCPServerHandle], lazy import
        if tool_servers:
            for ts in tool_servers:
                self.add_tool_server(ts)
        if stateless_tools:
            self._tool_servers.append(create_stateless_tool_server(stateless_tools))
        self._data_dir = data_dir
        self._name = name
        self._description = description
        self._metadata: dict[str, Any] = {}
        if tools is not None:
            kwargs["tools"] = tools
        if system_prompt is not None:
            kwargs["system_prompt"] = system_prompt
        if model is not None:
            kwargs["model"] = model
        elif os.environ.get(CAW_MODEL):
            kwargs["model"] = os.environ[CAW_MODEL]
        if reasoning is not None:
            kwargs["reasoning"] = reasoning
        elif os.environ.get(CAW_EFFORT):
            kwargs["reasoning"] = os.environ[CAW_EFFORT]
        if "auto_wait" not in kwargs:
            env_val = os.environ.get(CAW_AUTOWAIT, "").strip().lower()
            if env_val in ("0", "false", "no", "off"):
                kwargs["auto_wait"] = False
            # Otherwise leave unset so provider default (True) applies
        self._kwargs = kwargs

    def set_provider(self, provider: str) -> None:
        """Set or change the provider before starting a session."""
        self._provider_name = provider
        # Drop any cached instance so the new name is resolved on next access.
        self._provider = None

    @property
    def provider(self) -> Provider:
        """The resolved provider instance (lazily created)."""
        if self._provider is None:
            self._provider = _resolve_provider(self._provider_name)
        return self._provider

    @property
    def mcp_servers(self) -> list[MCPServer]:
        """Currently configured MCP servers (a copy of the internal list)."""
        return list(self._mcp_servers)

    @property
    def metadata(self) -> dict[str, Any]:
        """Mutable metadata dict carried onto every session's trajectory."""
        return self._metadata

    def add_mcp_server(self, server: MCPServer) -> None:
        """Register an MCP server for tool access."""
        self._mcp_servers.append(server)

    def add_tool_server(self, handle: Any) -> None:
        """Register a custom HTTP tool server (MCPServerHandle or ToolKit).

        If *handle* is a :class:`~caw.toolkit.ToolKit` instance, ``as_server()``
        is called automatically. The handle's lifecycle (start/stop) is managed
        by the session.
        """
        if isinstance(handle, ToolKit):
            handle = handle.as_server()
        self._tool_servers.append(handle)

    def set_model(self, model: str | ModelTier) -> None:
        """Set the model to use for sessions."""
        self._kwargs["model"] = model

    def set_reasoning(self, reasoning: str) -> None:
        """Set the reasoning effort level (e.g. ``'medium'``)."""
        self._kwargs["reasoning"] = reasoning

    def set_system_prompt(self, system_prompt: str) -> None:
        """Set a system prompt that guides the agent's behavior for the session."""
        self._kwargs["system_prompt"] = system_prompt

    def set_tools(self, tools: ToolGroup) -> None:
        """Set the tool permission groups for sessions."""
        self._kwargs["tools"] = tools

    def add_subagent(self, spec: AgentSpec) -> None:
        """Register a subagent that will be exposed as a tool."""
        self._subagents.append(spec)

    def get_spec(self) -> AgentSpec:
        """Return an AgentSpec snapshot of this agent's current configuration."""
        return AgentSpec(
            name=self._name,
            description=self._description,
            system_prompt=self._kwargs.get("system_prompt", ""),
            model=self._kwargs.get("model", ""),
            reasoning=self._kwargs.get("reasoning", ""),
            tools=self._kwargs.get("tools"),
            tool_servers=list(self._tool_servers),
            mcp_servers=list(self._mcp_servers),
            subagents=list(self._subagents),
            metadata=dict(self._metadata),
        )

    def _subagent_tool_servers(self, traj_dir: str, jsonl_path: str | None = None) -> list[Any]:
        """Convert registered subagents into HTTP tool server handles."""
        from caw.mcp import create_subagent_tool_server

        return [create_subagent_tool_server(spec, traj_dir, jsonl_path) for spec in self._subagents]

    @staticmethod
    def _stop_handles(handles: list[Any]) -> None:
        """Stop each handle; a failing handle is logged and never blocks the rest."""
        for handle in handles:
            try:
                handle.stop_sync()
            except Exception:
                logger.debug("Failed to stop tool server handle %r", handle, exc_info=True)

    def check_limit(self) -> int | None:
        """Check if the provider's usage limit is currently active.

        Sends a minimal test prompt to detect whether the configured
        provider and model are currently rate-limited. Returns the
        estimated number of minutes until the limit resets, or ``None``
        if no limit is detected.

        This incurs a small token cost for the probe request.
        """
        model = self._kwargs.get("model")
        if isinstance(model, ModelTier):
            model = self.provider.resolve_model(model)
        return self.provider.check_limit(model=model)

    def interactive(self, initial_prompt: str, capture_bytes: int = 0, **kwargs: Any) -> InteractiveResult:
        """Launch the provider binary interactively with an initial prompt.

        The user interacts with the agent directly in their terminal.
        A copy of stdout is captured via a pty. MCP tool servers are
        started before launch and stopped after the process exits (or
        after a failure to start or launch).

        Parameters
        ----------
        initial_prompt:
            The first message sent to the agent.
        capture_bytes:
            Maximum bytes of terminal output to keep (tail).
            ``0`` (default) means capture everything.

        Returns an :class:`InteractiveResult` with the exit code and
        captured terminal output.
        """
        merged = {**self._kwargs, **kwargs}

        # Remove session-only concerns
        merged.pop("auto_wait", None)
        merged.pop("metadata", None)

        # Resolve model tier
        model = merged.get("model")
        if isinstance(model, ModelTier):
            merged["model"] = self.provider.resolve_model(model)

        # Resolve tool restrictions — none are applied when tools is unset
        # (user is present)
        tools = merged.pop("tools", None)
        if tools is not None:
            restrictions = self.provider.resolve_tool_restrictions(tools)
            merged.update(restrictions)

        all_handles: list[Any] = list(self._tool_servers)
        # Handles whose start was attempted; exactly these get stopped, so a
        # failure part-way through start-up does not leak running servers.
        attempted: list[Any] = []
        try:
            # Start MCP tool server handles
            for handle in all_handles:
                attempted.append(handle)
                handle.start_sync()

            all_mcp = list(self._mcp_servers)
            for handle in all_handles:
                all_mcp.append(MCPServer(name=handle.server_id, url=handle.url))

            return self.provider.start_interactive(
                initial_prompt, mcp_servers=all_mcp, capture_bytes=capture_bytes, **merged
            )
        finally:
            self._stop_handles(attempted)

    def completion(self, message: str, **kwargs: Any) -> Trajectory:
        """Send a single message and return the complete trajectory.

        Convenience wrapper for simple use cases where you don't need
        to maintain a multi-turn session::

            traj = agent.completion("Explain this code")
            print(traj.result)

        If sending fails, the session is still ended (best-effort) so its
        tool servers are shut down, and the original error is re-raised.
        """
        session = self.start_session(**kwargs)
        try:
            session.send(message)
        except BaseException:
            try:
                session.end()
            except Exception:
                logger.debug("Failed to end session after send error", exc_info=True)
            raise
        return session.end()

    def start_session(self, traj_path: str | Path | None = None, **kwargs: Any) -> Session:
        """Start a new interactive session with the agent.

        Parameters
        ----------
        traj_path:
            If set, the trajectory is saved to this path after each
            step and when :meth:`Session.end` is called.

        Keyword arguments override the agent-level options for this session.
        If a tool server or the provider session fails to start, the tool
        servers started so far are stopped and the temporary subagent
        trajectory directory is removed before the error propagates.
        """
        merged = {**self._kwargs, **kwargs}

        # Pop auto_wait and metadata — these are Session concerns, not provider kwargs
        auto_wait = merged.pop("auto_wait", True)
        # An explicit ``metadata=None`` is treated like "no metadata".
        session_metadata: dict[str, Any] = merged.pop("metadata", None) or {}
        # Agent-level metadata as base, session kwargs override
        if self._metadata:
            session_metadata = {**self._metadata, **session_metadata}

        # Resolve model tier to concrete model string
        model = merged.get("model")
        if isinstance(model, ModelTier):
            merged["model"] = self.provider.resolve_model(model)

        # Resolve tool restrictions: default to ALL - INTERACTION for automated pipelines
        tools = merged.pop("tools", None)
        if tools is None:
            tools = ToolGroup.ALL - ToolGroup.INTERACTION
        restrictions = self.provider.resolve_tool_restrictions(tools)
        merged.update(restrictions)

        # Generate session_id early so the JSONL path is known before MCP configs
        session_id: str | None = None
        store: SessionStore | None = None
        if self._data_dir:
            session_id = str(uuid_mod.uuid4())
            store = SessionStore(self._data_dir, session_id)

        # Create temp dir for subagent trajectory files (if subagents exist)
        subagent_traj_dir: str | None = None
        if self._subagents:
            subagent_traj_dir = tempfile.mkdtemp(prefix="caw_subagent_traj_")

        # Collect all tool server handles (user-registered + subagent)
        all_handles: list[Any] = list(self._tool_servers)
        attempted: list[Any] = []
        try:
            if self._subagents:
                all_handles += self._subagent_tool_servers(
                    subagent_traj_dir,  # type: ignore[arg-type]
                    jsonl_path=str(store.jsonl_path) if store else None,
                )

            # Start all HTTP tool servers and collect their MCPServer configs
            for handle in all_handles:
                attempted.append(handle)
                handle.start_sync()

            all_mcp = list(self._mcp_servers)
            for handle in all_handles:
                all_mcp.append(MCPServer(name=handle.server_id, url=handle.url))

            # Pass our session_id so the provider uses it (instead of generating its own)
            if session_id:
                merged["session_id"] = session_id

            provider_session = self.provider.start_session(mcp_servers=all_mcp, **merged)
        except BaseException:
            # No Session exists yet to own these resources, so release them here.
            self._stop_handles(attempted)
            if subagent_traj_dir is not None:
                import shutil

                shutil.rmtree(subagent_traj_dir, ignore_errors=True)
            raise

        session = Session(
            provider_session,
            store=store,
            subagent_traj_dir=subagent_traj_dir,
            tool_handles=all_handles,
            auto_wait=auto_wait,
            metadata=session_metadata,
        )

        if traj_path is not None:
            session._traj_path = traj_path

        if store:
            store.write_metadata(session.trajectory)

        return session
|