operator-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- operator_agent/__init__.py +8 -0
- operator_agent/cli.py +594 -0
- operator_agent/config.py +64 -0
- operator_agent/core.py +376 -0
- operator_agent/providers/__init__.py +86 -0
- operator_agent/providers/claude.py +108 -0
- operator_agent/providers/codex.py +115 -0
- operator_agent/providers/gemini.py +95 -0
- operator_agent/transports/__init__.py +25 -0
- operator_agent/transports/telegram.py +305 -0
- operator_agent-0.1.0.dist-info/METADATA +191 -0
- operator_agent-0.1.0.dist-info/RECORD +16 -0
- operator_agent-0.1.0.dist-info/WHEEL +5 -0
- operator_agent-0.1.0.dist-info/entry_points.txt +2 -0
- operator_agent-0.1.0.dist-info/licenses/LICENSE +21 -0
- operator_agent-0.1.0.dist-info/top_level.txt +1 -0
operator_agent/core.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
"""Core runtime: state management, process spawning, and request handling."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextlib
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import tempfile
|
|
11
|
+
from asyncio.subprocess import Process
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import TYPE_CHECKING, Any
|
|
14
|
+
|
|
15
|
+
from .config import CONFIG_DIR, STATE_FILE
|
|
16
|
+
from .providers import BaseProvider, get_provider
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .transports import TransportContext
|
|
20
|
+
|
|
21
|
+
log = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def split_text(text: str, limit: int = 4096) -> list[str]:
    """Split *text* into chunks of at most *limit* characters.

    Chunks break at newline boundaries where possible; a single line longer
    than *limit* is hard-cut into ``limit``-sized pieces. Text that already
    fits is returned unchanged as a one-element list (even when empty).
    Empty lines that fall exactly at a chunk boundary are not preserved.

    Raises:
        ValueError: if *limit* is less than 1 (the hard-cut loop could never
            make progress and would spin forever).
    """
    if limit < 1:
        raise ValueError(f"limit must be a positive integer, got {limit!r}")
    if len(text) <= limit:
        return [text]

    chunks: list[str] = []
    current = ""
    for line in text.split("\n"):
        # +1 accounts for the newline that would join `line` onto `current`.
        if len(current) + len(line) + 1 > limit:
            if current:
                chunks.append(current)
            remainder = line
            # A single over-long line is cut into fixed-size pieces.
            while len(remainder) > limit:
                chunks.append(remainder[:limit])
                remainder = remainder[limit:]
            current = remainder
        else:
            current = current + "\n" + line if current else line
    if current:
        chunks.append(current)
    return chunks
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Runtime:
    """Central runtime holding all per-chat state and provider process management.

    Maps are keyed by integer ``chat_id`` (plus provider name where the value is
    provider-specific). Provider/model selections and session ids are persisted
    by :meth:`save_state` and restored by :meth:`load_state`; locks, running
    processes and tasks live only in memory.
    """

    def __init__(self, config: dict):
        self.config = config
        self.working_dir: str = config.get("working_dir", str(Path.cwd()))
        # Sections may be present but empty (None) in the config mapping; treat
        # them as empty instead of failing with AttributeError/TypeError.
        providers_cfg = config.get("providers") or {}
        self.models: dict[str, list[str]] = {
            name: (pcfg or {}).get("models") or []
            for name, pcfg in providers_cfg.items()
        }
        self.active_provider_by_chat: dict[int, str] = {}
        self.active_model_by_chat_provider: dict[tuple[int, str], str] = {}
        self.session_by_chat_provider: dict[tuple[int, str], str] = {}
        self.running_process_by_chat: dict[int, Process] = {}
        self.running_task_by_chat: dict[int, asyncio.Task] = {}
        self.chat_lock_by_chat: dict[int, asyncio.Lock] = {}

    def init_config_dir(self):
        """Ensure CONFIG_DIR exists (save_state creates its temp file there)."""
        os.makedirs(CONFIG_DIR, exist_ok=True)

    # --- State persistence ---

    def save_state(self):
        """Atomically persist provider/model selections and session ids.

        The JSON is written to a temp file in CONFIG_DIR, fsynced, and then
        ``os.replace``d over STATE_FILE. Failures are logged, never raised, and
        the temp file is removed so failed saves do not accumulate debris.
        """
        data = {
            "active_provider_by_chat": {
                str(k): v for k, v in self.active_provider_by_chat.items()
            },
            "active_model_by_chat_provider": {
                f"{chat_id}:{provider}": model
                for (chat_id, provider), model in self.active_model_by_chat_provider.items()
            },
            "session_by_chat_provider": {
                f"{chat_id}:{provider}": sid
                for (chat_id, provider), sid in self.session_by_chat_provider.items()
            },
        }
        tmp: str | None = None
        try:
            fd, tmp = tempfile.mkstemp(dir=CONFIG_DIR, suffix=".tmp")
            with os.fdopen(fd, "w") as f:
                json.dump(data, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, STATE_FILE)
            tmp = None  # now owned by STATE_FILE; nothing to clean up
            log.debug(
                "State saved: %d providers, %d sessions",
                len(self.active_provider_by_chat),
                len(self.session_by_chat_provider),
            )
        except Exception:
            log.exception("Failed to save state")
        finally:
            if tmp is not None:
                with contextlib.suppress(OSError):
                    os.unlink(tmp)

    def load_state(self):
        """Load persisted state from STATE_FILE into the in-memory maps.

        A missing file is a no-op. Any error (unreadable file, bad JSON,
        malformed key) is logged and loading stops there; entries read before
        the error remain loaded.
        """
        if not os.path.exists(STATE_FILE):
            return
        try:
            with open(STATE_FILE) as f:
                data = json.load(f)
            for k, v in data.get("active_provider_by_chat", {}).items():
                self.active_provider_by_chat[int(k)] = v
            for k, v in data.get("active_model_by_chat_provider", {}).items():
                chat_id_str, provider = k.split(":", 1)
                self.active_model_by_chat_provider[(int(chat_id_str), provider)] = v
            for k, v in data.get("session_by_chat_provider", {}).items():
                chat_id_str, provider = k.split(":", 1)
                self.session_by_chat_provider[(int(chat_id_str), provider)] = v
            log.info(
                "Loaded state: %d providers, %d sessions, %d model overrides",
                len(self.active_provider_by_chat),
                len(self.session_by_chat_provider),
                len(self.active_model_by_chat_provider),
            )
        except Exception:
            log.exception("Failed to load state, starting fresh")

    # --- Accessors ---

    def get_active_provider(self, chat_id: int) -> str:
        """Return the active provider for *chat_id*, defaulting to "claude"."""
        return self.active_provider_by_chat.get(chat_id, "claude")

    def get_active_model(self, chat_id: int, provider: str) -> str:
        """Return the active model for chat+provider.

        A stored choice is honoured only while it is still configured for the
        provider; otherwise the first configured model is used, or "default"
        when none are configured.
        """
        models = self.models.get(provider, [])
        stored = self.active_model_by_chat_provider.get((chat_id, provider))
        if stored and stored in models:
            return stored
        return models[0] if models else "default"

    def get_chat_lock(self, chat_id: int) -> asyncio.Lock:
        """Get or create the per-chat lock used to serialize requests."""
        lock = self.chat_lock_by_chat.get(chat_id)
        if lock is None:
            lock = asyncio.Lock()
            self.chat_lock_by_chat[chat_id] = lock
        return lock

    # --- Provider management ---

    def _get_provider_path(self, provider_name: str) -> str:
        """Return the configured executable path, falling back to the provider name."""
        providers_cfg = self.config.get("providers") or {}
        provider_cfg = providers_cfg.get(provider_name) or {}
        return provider_cfg.get("path") or provider_name

    def make_provider(self, provider_name: str) -> BaseProvider:
        """Create a fresh provider instance."""
        path = self._get_provider_path(provider_name)
        return get_provider(provider_name, path)

    # --- Process control ---

    @staticmethod
    async def _reap(process: Process, *, terminate: bool, grace: float) -> None:
        """Best-effort wait for *process* to exit, escalating to terminate()/kill().

        With ``terminate=True`` the child is asked to stop first. If it has not
        exited within *grace* seconds it is killed, and the call waits at most
        another *grace* seconds. A grandchild holding the pipes open therefore
        cannot hang the caller.
        """
        if process.returncode is not None:
            return
        if terminate:
            # The child may have exited without being reaped yet.
            with contextlib.suppress(ProcessLookupError):
                process.terminate()
        try:
            await asyncio.wait_for(process.wait(), timeout=grace)
            return
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is only an alias of the builtin from 3.11 on.
            pass
        if process.returncode is None:
            with contextlib.suppress(ProcessLookupError):
                process.kill()
            with contextlib.suppress(asyncio.TimeoutError):
                await asyncio.wait_for(process.wait(), timeout=grace)

    async def stop_chat(self, chat_id: int) -> tuple[bool, str | None]:
        """Stop the running process/task for a chat.

        Returns:
            ``(had_something, error_msg)``: whether anything was registered for
            the chat, and the error text if stopping raised.
        """
        process = self.running_process_by_chat.get(chat_id)
        task = self.running_task_by_chat.get(chat_id)

        if process is None and task is None:
            return False, None

        try:
            if process is not None:
                await self._reap(process, terminate=True, grace=0.5)

            if task is not None and not task.done():
                task.cancel()

            return True, None
        except Exception as exc:
            return True, str(exc)

    # --- Core streaming ---

    @staticmethod
    async def _collect_stderr(task: asyncio.Task | None, timeout: float = 5.0) -> str:
        """Return the drained stderr (decoded, stripped, max 2000 chars) or ""."""
        if task is None:
            return ""
        try:
            data = await asyncio.wait_for(task, timeout=timeout)
        except asyncio.TimeoutError:
            # The pipe is still held open; wait_for has already cancelled the read.
            return ""
        except Exception:
            log.debug("Failed to drain stderr", exc_info=True)
            return ""
        return data.decode(errors="replace").strip()[:2000]

    async def run_provider(
        self, provider: BaseProvider, prompt: str, chat_id: int, model: str
    ):
        """Spawn a provider subprocess and yield its StreamEvents.

        Session ids reported by the provider are stored per (chat, provider) and
        persisted immediately. Before the generator finishes, the child is waited
        for; if the consumer stopped early (cancellation, exception, or closing
        the generator) while it was still running, it is terminated.
        """
        session_id = self.session_by_chat_provider.get((chat_id, provider.name))
        cmd = provider.build_command(prompt, model, session_id)

        log.info("[%s] Spawning: %s", provider.name, " ".join(cmd[:6]) + " ...")

        merge_stderr = provider.stderr_to_stdout()
        kwargs: dict[str, Any] = {
            "stdout": asyncio.subprocess.PIPE,
            "stderr": asyncio.subprocess.STDOUT if merge_stderr else asyncio.subprocess.PIPE,
            "cwd": self.working_dir,
        }
        limit = provider.stdout_limit()
        if limit:
            kwargs["limit"] = limit

        process = await asyncio.create_subprocess_exec(*cmd, **kwargs)
        log.info("[%s] Process started, pid=%s", provider.name, process.pid)
        self.running_process_by_chat[chat_id] = process

        # Drain stderr concurrently. Reading it only after exit can deadlock:
        # the child blocks on a full stderr pipe while we wait for stdout EOF.
        stderr_task: asyncio.Task | None = None
        if not merge_stderr and process.stderr is not None:
            stderr_task = asyncio.create_task(process.stderr.read())

        event_count = 0
        stdout_exhausted = False
        try:
            assert process.stdout is not None
            async for line in process.stdout:
                decoded = line.decode(errors="replace").strip()
                if not decoded:
                    continue

                raw = provider.parse_line(decoded)
                if raw is None:
                    log.debug("[%s] Skipped line: %s", provider.name, decoded[:200])
                    continue

                event_count += 1
                log.debug(
                    "[%s] Event #%d type=%s",
                    provider.name,
                    event_count,
                    raw.get("type", "?"),
                )

                for stream_event in provider.parse_event(raw):
                    if stream_event.kind == "session" and stream_event.session_id:
                        log.info(
                            "[%s] New session: %s",
                            provider.name,
                            stream_event.session_id,
                        )
                        self.session_by_chat_provider[
                            (chat_id, provider.name)
                        ] = stream_event.session_id
                        self.save_state()
                    yield stream_event
            stdout_exhausted = True

        finally:
            try:
                # returncode is only set once the child is reaped, which can
                # happen after stdout EOF, so wait for it before reading it.
                await self._reap(
                    process,
                    terminate=not stdout_exhausted,
                    grace=5.0 if stdout_exhausted else 0.5,
                )
                stderr_text = await self._collect_stderr(stderr_task)
            finally:
                if stderr_task is not None and not stderr_task.done():
                    stderr_task.cancel()
                if self.running_process_by_chat.get(chat_id) is process:
                    self.running_process_by_chat.pop(chat_id, None)
            rc = process.returncode
            log.info(
                "[%s] Process pid=%s finished, returncode=%s, events=%d",
                provider.name,
                process.pid,
                rc,
                event_count,
            )
            if rc and stderr_text:
                log.error("[%s] stderr: %s", provider.name, stderr_text)

    # --- Request handling ---

    @staticmethod
    async def _stop_ticker(ticker: asyncio.Task, stop: asyncio.Event) -> None:
        """Signal the status ticker to stop and wait until it has finished.

        Awaiting completion ensures no in-flight status edit can land after the
        caller deletes or finalizes the status message.
        """
        stop.set()
        ticker.cancel()
        # Unlike `await ticker`, asyncio.wait does not re-raise the ticker's
        # CancelledError, but still propagates cancellation of the current task.
        await asyncio.wait({ticker})

    async def process_request(
        self,
        provider_name: str,
        prompt: str,
        chat_id: int,
        ctx: TransportContext,
    ):
        """Run one provider request end to end for *chat_id*.

        Posts a status message that a background ticker keeps updated and streams
        events from :meth:`run_provider`. It then deletes the status message and
        replies with one of:

        - the response, chunked by ``split_text``;
        - the last reported error;
        - "(No response)".

        On cancellation the status message is edited to "Stopped." and
        CancelledError is re-raised. Other exceptions are logged and reported
        to the chat.
        """
        provider = self.make_provider(provider_name)
        model = self.get_active_model(chat_id, provider_name)
        display_name = provider_name.capitalize()

        log.info(
            "[%s] Request from chat_id=%s model=%s: %.80s",
            provider_name,
            chat_id,
            model,
            prompt,
        )

        loop = asyncio.get_running_loop()
        prefix_base = f"{display_name}/{model}"
        status_msg = await ctx.reply_status(f"[{prefix_base} 0s] Working...")
        state = {
            "status": "Working...",
            "elapsed": 0,
            "provider": prefix_base,
            # Wall-clock reference for elapsed seconds (shared with _run_ticker).
            "started": loop.time(),
        }
        stop = asyncio.Event()
        ticker = asyncio.create_task(self._run_ticker(ctx, status_msg, state, stop))

        response_parts: list[str] = []
        error_text = ""

        try:
            async for event in self.run_provider(provider, prompt, chat_id, model):
                if event.kind == "status":
                    state["status"] = event.text
                elif event.kind == "response":
                    response_parts.append(event.text)
                elif event.kind == "error":
                    error_text = event.text

            await self._stop_ticker(ticker, stop)
            await ctx.delete_status(status_msg)

            response_text = provider.format_response(response_parts)
            state["elapsed"] = int(loop.time() - state["started"])
            prefix = f"[{prefix_base} {state['elapsed']}s]"

            if response_text:
                log.info(
                    "[%s] Response (%d chars, %ds)",
                    provider_name,
                    len(response_text),
                    state["elapsed"],
                )
                # Prefix before splitting so the first chunk also respects the limit.
                for chunk in split_text(f"{prefix} {response_text}"):
                    await ctx.reply(chunk)
            elif error_text:
                log.warning(
                    "[%s] Failed after %ds: %s",
                    provider_name,
                    state["elapsed"],
                    error_text[:200],
                )
                await ctx.reply(f"{prefix} Error: {error_text[:4000]}")
            else:
                log.warning(
                    "[%s] Empty response after %ds", provider_name, state["elapsed"]
                )
                await ctx.reply(f"{prefix} (No response)")

        except asyncio.CancelledError:
            await self._stop_ticker(ticker, stop)
            with contextlib.suppress(Exception):
                await ctx.edit_status(status_msg, f"[{prefix_base}] Stopped.")
            raise

        except Exception as exc:
            await self._stop_ticker(ticker, stop)
            log.error(
                "Error processing %s request: %s", provider_name, exc, exc_info=True
            )
            error_msg = str(exc)
            state["elapsed"] = int(loop.time() - state["started"])
            prefix = f"[{prefix_base} {state['elapsed']}s]"
            try:
                await ctx.edit_status(
                    status_msg, f"{prefix} Error: {error_msg[:4000]}"
                )
            except Exception:
                for chunk in split_text(f"{prefix} Error: {error_msg}"):
                    await ctx.reply(chunk)

    async def _run_ticker(
        self,
        ctx: TransportContext,
        status_msg: Any,
        state: dict,
        stop: asyncio.Event,
    ):
        """Background task that refreshes the "[Provider Xs] status" line.

        The counter is derived from the event-loop clock rather than loop
        iterations: each status edit may take up to 5s, so counting ticks lags.
        Edit failures and timeouts are ignored.
        """
        loop = asyncio.get_running_loop()
        started = state.setdefault("started", loop.time())
        while not stop.is_set():
            await asyncio.sleep(1)
            if stop.is_set():
                break
            state["elapsed"] = int(loop.time() - started)
            text = f"[{state['provider']} {state['elapsed']}s] {state['status']}"
            with contextlib.suppress(Exception):
                await asyncio.wait_for(ctx.edit_status(status_msg, text), timeout=5.0)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Provider abstraction for CLI agents."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Literal
|
|
10
|
+
|
|
11
|
+
log = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass()
class StreamEvent:
    """One normalized event emitted while streaming a provider's output.

    ``kind`` decides which other field matters: ``text`` holds the status line,
    response text or error message, while ``session_id`` is filled in for
    ``"session"`` events.
    """

    # Discriminator: "status", "response", "session" or "error".
    kind: Literal["status", "response", "session", "error"]
    # Human-readable payload; left empty when the event carries no text.
    text: str = ""
    # Provider session/thread id; set by the providers here only on "session" events.
    session_id: str | None = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseProvider(ABC):
    """Base class for CLI agent providers.

    Subclasses set ``name`` and implement :meth:`build_command` and
    :meth:`parse_event`. The remaining hooks have defaults that subclasses may
    override.
    """

    name: str

    def __init__(self, path: str):
        # Executable used as the first element of the built command.
        self.path = path

    @abstractmethod
    def build_command(
        self, prompt: str, model: str, session_id: str | None = None
    ) -> list[str]:
        """Build the CLI command (argv list) to execute."""

    @abstractmethod
    def parse_event(self, event: dict) -> list[StreamEvent]:
        """Parse a raw JSON event into StreamEvents."""

    def parse_line(self, line: str) -> dict | None:
        """Parse a raw stdout line into a JSON object.

        Returns None to skip the line: blank lines, invalid JSON, and valid JSON
        that is not an object (e.g. ``42`` or ``[...]``). Callers use ``.get()``
        on the result, so a non-dict value would crash them.
        """
        stripped = line.strip()
        if not stripped:
            return None
        try:
            parsed = json.loads(stripped)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def stderr_to_stdout(self) -> bool:
        """If True, merge stderr into stdout."""
        return False

    def stdout_limit(self) -> int | None:
        """Buffer limit for stdout, or None for default."""
        return None

    def format_response(self, parts: list[str]) -> str:
        """Combine response parts into final text. Default: last part wins."""
        return parts[-1] if parts else ""

    def clear_session(self, session_id: str | None, working_dir: str) -> str:
        """Clear session data. Returns human-readable summary."""
        return "session cleared" if session_id else "no session"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Provider names accepted by get_provider(); keep in sync with its class mapping.
PROVIDER_NAMES = ["claude", "codex", "gemini"]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_provider(name: str, path: str) -> BaseProvider:
    """Instantiate the provider registered under *name*, using *path* as its executable.

    Raises:
        ValueError: if *name* is not one of the built-in providers.
    """
    # Deferred imports: each provider module imports BaseProvider/StreamEvent
    # from this package, so importing them at module level would be circular.
    from .claude import ClaudeProvider
    from .codex import CodexProvider
    from .gemini import GeminiProvider

    registry: dict[str, type[BaseProvider]] = {
        "claude": ClaudeProvider,
        "codex": CodexProvider,
        "gemini": GeminiProvider,
    }
    provider_cls = registry.get(name)
    if provider_cls is not None:
        return provider_cls(path)
    raise ValueError(f"Unknown provider: {name}")
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Claude CLI provider."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from . import BaseProvider, StreamEvent
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _format_tool_status(tool_name: str, tool_input: dict) -> str:
|
|
12
|
+
"""Format tool usage into human-readable status."""
|
|
13
|
+
match tool_name:
|
|
14
|
+
case "Read":
|
|
15
|
+
path = tool_input.get("file_path", "file")
|
|
16
|
+
return f"Reading {os.path.basename(path)}..."
|
|
17
|
+
case "Write":
|
|
18
|
+
path = tool_input.get("file_path", "file")
|
|
19
|
+
return f"Writing {os.path.basename(path)}..."
|
|
20
|
+
case "Edit":
|
|
21
|
+
path = tool_input.get("file_path", "file")
|
|
22
|
+
return f"Editing {os.path.basename(path)}..."
|
|
23
|
+
case "Bash":
|
|
24
|
+
cmd = tool_input.get("command", "")
|
|
25
|
+
short = cmd[:50] + "..." if len(cmd) > 50 else cmd
|
|
26
|
+
return f"Running {short}"
|
|
27
|
+
case "Glob":
|
|
28
|
+
pattern = tool_input.get("pattern", "")
|
|
29
|
+
return f"Finding {pattern}..."
|
|
30
|
+
case "Grep":
|
|
31
|
+
pattern = tool_input.get("pattern", "")
|
|
32
|
+
return f"Searching for {pattern}..."
|
|
33
|
+
case "WebFetch":
|
|
34
|
+
url = tool_input.get("url", "")
|
|
35
|
+
return f"Fetching {url[:40]}..."
|
|
36
|
+
case "WebSearch":
|
|
37
|
+
query = tool_input.get("query", "")
|
|
38
|
+
return f"Searching: {query}..."
|
|
39
|
+
case "Task":
|
|
40
|
+
return "Running subagent..."
|
|
41
|
+
case _:
|
|
42
|
+
return f"Using {tool_name}..."
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _get_project_dir(working_dir: str) -> str:
|
|
46
|
+
"""Derive Claude's project session directory from working_dir."""
|
|
47
|
+
resolved = str(Path(working_dir).resolve())
|
|
48
|
+
mangled = resolved.replace("/", "-")
|
|
49
|
+
return os.path.expanduser(f"~/.claude/projects/{mangled}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ClaudeProvider(BaseProvider):
    """Provider driving the ``claude`` CLI in print mode with stream-json output."""

    name = "claude"

    def build_command(self, prompt, model, session_id=None):
        """Build the ``claude -p`` invocation for *prompt* on *model*.

        *session_id* is ignored: conversation continuity comes from
        ``--continue``, which resumes the most recent conversation in the
        working directory.
        """
        return [
            self.path,
            "-p",
            "--continue",
            "--output-format",
            "stream-json",
            "--verbose",
            "--dangerously-skip-permissions",
            "--model",
            model,
            # End option parsing so a prompt beginning with "-" is not read as a flag.
            "--",
            prompt,
        ]

    def parse_event(self, event):
        """Map one stream-json event to status/response StreamEvents.

        ``assistant`` events yield one status per ``tool_use`` block and one
        response per non-empty ``text`` block. A ``result`` event yields its
        result string as a response; with the default ``format_response`` the
        last response wins.
        """
        events: list[StreamEvent] = []
        event_type = event.get("type", "")

        if event_type == "assistant":
            message = event.get("message")
            content = message.get("content") if isinstance(message, dict) else None
            # Content is expected to be a list of block dicts; skip anything else
            # rather than crashing the whole request.
            if not isinstance(content, list):
                content = []
            for block in content:
                if not isinstance(block, dict):
                    continue
                block_type = block.get("type")
                if block_type == "tool_use":
                    tool_input = block.get("input")
                    events.append(
                        StreamEvent(
                            kind="status",
                            text=_format_tool_status(
                                block.get("name", ""),
                                tool_input if isinstance(tool_input, dict) else {},
                            ),
                        )
                    )
                elif block_type == "text":
                    block_text = block.get("text", "")
                    if block_text:
                        events.append(StreamEvent(kind="response", text=block_text))

        elif event_type == "result":
            result_text = event.get("result", "")
            if isinstance(result_text, str) and result_text:
                events.append(StreamEvent(kind="response", text=result_text))

        return events

    def stdout_limit(self):
        """Allow stream-json lines up to 10 MiB (tool output can be large)."""
        return 10 * 1024 * 1024

    def clear_session(self, session_id, working_dir):
        """Delete every Claude session file (``*.jsonl``) for *working_dir*.

        *session_id* is ignored because all of the project's sessions are
        removed. Returns a summary such as "3 session files deleted".
        """
        project_dir = Path(_get_project_dir(working_dir))
        removed = 0
        if project_dir.is_dir():
            for f in project_dir.glob("*.jsonl"):
                try:
                    f.unlink()
                except FileNotFoundError:
                    # Already gone between glob() and unlink(); not an error.
                    continue
                removed += 1
        return f"{removed} session file{'s' if removed != 1 else ''} deleted"
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Codex CLI provider."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
from . import BaseProvider, StreamEvent
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _format_status(event: dict) -> str | None:
|
|
11
|
+
"""Extract human-readable status from a Codex JSON event."""
|
|
12
|
+
event_type = event.get("type", "")
|
|
13
|
+
item = event.get("item", {})
|
|
14
|
+
item_type = item.get("type", "")
|
|
15
|
+
|
|
16
|
+
if event_type == "item.started":
|
|
17
|
+
if item_type == "command_execution":
|
|
18
|
+
cmd = item.get("command", "")
|
|
19
|
+
if "-lc " in cmd:
|
|
20
|
+
cmd = cmd.split("-lc ", 1)[1].strip("'\"")
|
|
21
|
+
short = cmd[:50] + "..." if len(cmd) > 50 else cmd
|
|
22
|
+
return f"Running {short}"
|
|
23
|
+
if item_type == "reasoning":
|
|
24
|
+
return "Thinking..."
|
|
25
|
+
if item_type == "file_changes":
|
|
26
|
+
return "Editing files..."
|
|
27
|
+
if item_type == "web_searches":
|
|
28
|
+
return "Searching the web..."
|
|
29
|
+
if item_type == "mcp_tool_calls":
|
|
30
|
+
return "Using tool..."
|
|
31
|
+
|
|
32
|
+
if event_type == "item.completed" and item_type == "reasoning":
|
|
33
|
+
text = item.get("text", "")
|
|
34
|
+
if text:
|
|
35
|
+
clean = text.strip("*").strip()
|
|
36
|
+
short = clean[:40] + "..." if len(clean) > 40 else clean
|
|
37
|
+
return short
|
|
38
|
+
|
|
39
|
+
return None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CodexProvider(BaseProvider):
    """Provider driving the ``codex`` CLI via ``codex exec --json``."""

    name = "codex"

    def build_command(self, prompt, model, session_id=None):
        """Build a ``codex exec`` command, using ``exec resume <id>`` when a session exists.

        Positional arguments (session id, prompt) follow ``--`` so a prompt
        beginning with "-" is not parsed as an option.
        """
        cmd = [self.path, "exec"]
        if session_id:
            cmd.append("resume")
        cmd += [
            "--dangerously-bypass-approvals-and-sandbox",
            "--json",
            "-m",
            model,
            "--",
        ]
        if session_id:
            cmd.append(session_id)
        cmd.append(prompt)
        return cmd

    def parse_line(self, line):
        """Parse JSON object lines only; other output (stderr is merged in) is skipped."""
        stripped = line.strip()
        if not stripped or not stripped.startswith("{"):
            return None
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            return None

    def parse_event(self, event):
        """Map one Codex JSON event to StreamEvents.

        Emits status lines (via ``_format_status`` and ``turn.started``), errors
        (``error`` / ``turn.failed``), responses (completed ``agent_message``
        items) and the session id (``thread.started``).
        """
        events: list[StreamEvent] = []
        event_type = event.get("type", "")

        status = _format_status(event)
        if status:
            events.append(StreamEvent(kind="status", text=status))

        if event_type == "turn.started":
            events.append(StreamEvent(kind="status", text="Working..."))

        elif event_type == "error":
            msg = event.get("message")
            if isinstance(msg, str) and msg:
                events.append(StreamEvent(kind="error", text=msg))
                if "reconnect" in msg.lower():
                    events.append(StreamEvent(kind="status", text="Reconnecting..."))

        elif event_type == "item.completed":
            item = event.get("item")
            if isinstance(item, dict) and item.get("type") == "agent_message":
                text_value = item.get("text")
                if isinstance(text_value, str) and text_value:
                    events.append(StreamEvent(kind="response", text=text_value))

        elif event_type == "turn.failed":
            # The failure text may be top-level or nested as {"error": {"message": ...}}.
            msg = event.get("message")
            if not (isinstance(msg, str) and msg):
                error = event.get("error")
                msg = error.get("message") if isinstance(error, dict) else error
            if isinstance(msg, str) and msg:
                events.append(StreamEvent(kind="error", text=msg))

        elif event_type == "thread.started":
            thread_id = event.get("thread_id")
            if isinstance(thread_id, str) and thread_id:
                events.append(StreamEvent(kind="session", session_id=thread_id))

        return events

    def stderr_to_stdout(self):
        """Merge stderr into stdout; non-JSON lines are dropped by parse_line."""
        return True
|