baserun-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
baserun_cli/channel.py ADDED
@@ -0,0 +1,350 @@
1
+ """Channel client: connects to nchan over WebSocket with HMAC auth + auto-reconnect.
2
+
3
+ Modeled on the fieldshortcut openclaw channel plugin (client.ts/gateway.ts):
4
+ - signed message = "{app_id}:{timestamp}", HMAC-SHA256 hex
5
+ - 10-minute timestamp window enforced server-side
6
+ - reconnect: 5s delay, infinite attempts
7
+ - receives `type: "task"` frames; dispatches to a handler callback
8
+
9
+ The client subscribes to its agent channel `agent:{app_id}`. To publish run
10
+ events back, it opens a dedicated WebSocket publisher per run on
11
+ `run:{run_id}`; an HTTP publisher is kept only as a legacy fallback.
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import asyncio
16
+ import hashlib
17
+ import hmac
18
+ import json
19
+ import logging
20
+ import time
21
+ from typing import Any, Awaitable, Callable
22
+ from urllib.parse import urlparse
23
+
24
+ import httpx
25
+ import websockets
26
+ from websockets.asyncio.client import connect
27
+
28
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Seconds the WS loop sleeps between reconnect attempts.
RECONNECT_DELAY = 5.0
# Cap on consecutive reconnect attempts before the WS loop gives up.
MAX_RECONNECT_ATTEMPTS = 0  # 0 = infinite
32
+
33
+
34
def sign(app_secret: str, app_id: str, timestamp: int) -> str:
    """Return the hex HMAC-SHA256 of ``"{app_id}:{timestamp}"``.

    The digest is keyed with the UTF-8 encoding of *app_secret*.
    """
    key = app_secret.encode()
    message = f"{app_id}:{timestamp}".encode()
    mac = hmac.new(key, msg=message, digestmod=hashlib.sha256)
    return mac.hexdigest()
38
+
39
+
40
# Async callback type: receives one decoded task envelope (a dict) and returns nothing.
TaskHandler = Callable[[dict[str, Any]], Awaitable[None]]
41
+
42
+
43
class ChannelClient:
    """Connect to nchan as a subscriber on `agent:{app_id}` and dispatch tasks.

    Two delivery paths feed the same ``on_task`` handler:

    - push: a signed WebSocket subscription (``_ws_loop``)
    - pull: an HTTP ``claim-task`` poll that runs only while the WebSocket is
      down, plus one sweep right after every (re)connect

    Run events are published back over one dedicated WebSocket per run so the
    order of events within a run is preserved.
    """

    def __init__(
        self,
        nchan_url: str,  # e.g. ws://localhost:9390 (or https→wss)
        app_id: str,
        app_secret: str,
        on_task: TaskHandler | None = None,
        claim_interval: float = 3.0,  # HTTP poll interval for missed tasks
    ) -> None:
        """Store credentials and create the shared HTTP client.

        Args:
            nchan_url: Base URL of the nchan bus (ws/wss or http/https).
            app_id: Channel credential; also part of the subscribe path.
            app_secret: HMAC key used to sign the subscribe handshake.
            on_task: Async handler for task envelopes; may be bound later.
            claim_interval: Seconds between HTTP claim polls.
        """
        self.nchan_url = nchan_url
        self.app_id = app_id
        self.app_secret = app_secret
        self.on_task = on_task
        self._stopped = False
        self._claim_interval = claim_interval
        self._ws_connected = False  # tracks WS state for claim loop
        # HTTP(S) twin of the bus URL; only the scheme prefix is rewritten.
        self._publish_base = self._to_http_base(nchan_url)
        self._http = httpx.AsyncClient(timeout=10.0)
        # WS publishers for run:{run_id} events (one connection per run, ordered)
        self._run_pubs: dict[str, Any] = {}
        # Strong references to fire-and-forget tasks; the event loop itself
        # only keeps weak ones, so an unreferenced task may be collected.
        self._tasks: set[asyncio.Task[None]] = set()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_ws_base(url: str) -> str:
        """Map an http(s) URL prefix to ws(s); other URLs pass through."""
        if url.startswith("https://"):
            return "wss://" + url[len("https://"):]
        if url.startswith("http://"):
            return "ws://" + url[len("http://"):]
        return url

    @staticmethod
    def _to_http_base(url: str) -> str:
        """Map a ws(s) URL prefix to http(s); other URLs pass through."""
        if url.startswith("wss://"):
            return "https://" + url[len("wss://"):]
        if url.startswith("ws://"):
            return "http://" + url[len("ws://"):]
        return url

    @staticmethod
    def _event_data(event: Any) -> dict[str, Any]:
        """Return ``event["data"]`` when it is a dict, else an empty dict."""
        data = event.get("data") if isinstance(event, dict) else None
        return data if isinstance(data, dict) else {}

    @classmethod
    def _event_seq(cls, event: Any) -> Any:
        """Return the event's ``data.seq``, or 0 when it is absent."""
        return cls._event_data(event).get("seq") or 0

    def _spawn(self, coro: Any) -> None:
        """Schedule *coro* as a background task and keep it referenced."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _close_quietly(self, ws: Any) -> None:
        """Close *ws*, logging (not raising) any failure."""
        try:
            await ws.close()
        except Exception as e:
            log.debug("closing ws failed: %s", e)

    def _ws_url(self) -> str:
        """Return the subscriber URL for this agent."""
        # nchan channel path: /agent/{app_id}
        base = self._to_ws_base(self.nchan_url)
        return f"{base.rstrip('/')}/agent/{self.app_id}"

    def _auth_headers(self) -> dict[str, str]:
        """Build the signed handshake headers using the current Unix time."""
        ts = int(time.time())
        return {
            "X-Agent-App-Id": self.app_id,
            "X-Agent-Timestamp": str(ts),
            "X-Agent-Signature": sign(self.app_secret, self.app_id, ts),
        }

    # --------------------------------------------------------------- publishing

    async def _get_run_pub(self, run_id: str) -> Any:
        """Return the cached publisher for *run_id*, opening one if needed."""
        ws = self._run_pubs.get(run_id)
        if ws is not None:
            return ws
        ws = await self._open_run_pub(run_id)
        winner = self._run_pubs.setdefault(run_id, ws)
        if winner is not ws:
            # Another coroutine connected while we awaited the handshake; keep
            # a single socket per run so event order is preserved.
            await self._close_quietly(ws)
        return winner

    async def ensure_run_pub(self, run_id: str) -> None:
        """Pre-connect the WS publisher for a run (fire-and-forget).

        Called at run start so the WS handshake completes before the first
        event arrives. If this fails, publish_event will retry on demand.
        """
        try:
            await self._get_run_pub(run_id)
        except Exception as e:
            log.debug("pre-connect ws pub for run %s failed (will retry on first event): %s", run_id, e)

    async def publish_event(self, run_id: str, payload: dict[str, Any]) -> bool:
        """Publish a run event via WebSocket (guarantees ordering within a run).

        All events (including terminal) go through the same WS connection to
        preserve event order in the nchan channel buffer. Mixing WS and HTTP
        can cause ordering issues because HTTP requests may be processed by
        nchan at different times relative to buffered WS messages.

        Terminal events (``data.finished is True``) get 4 attempts, all
        others 2, with a one-second pause between attempts.

        Returns True if published successfully.
        """
        data = json.dumps(payload, ensure_ascii=False)
        is_terminal = self._event_data(payload).get("finished") is True
        max_retries = 4 if is_terminal else 2

        for attempt in range(max_retries):
            last_attempt = attempt == max_retries - 1
            try:
                ws = await self._get_run_pub(run_id)
            except Exception as e:
                log.warning("ws pub connect for run %s failed (attempt %d): %s", run_id, attempt + 1, e)
                if not last_attempt:
                    await asyncio.sleep(1.0)
                continue
            try:
                await ws.send(data)
                return True
            except Exception as e:
                log.warning("ws publish to run %s failed (attempt %d): %s", run_id, attempt + 1, e)
                # Forget only the socket that just failed, not a replacement
                # that another coroutine may have installed meanwhile.
                if self._run_pubs.get(run_id) is ws:
                    self._run_pubs.pop(run_id, None)
                if not last_attempt:
                    await asyncio.sleep(1.0)

        log.error("publish failed for run %s seq=%s after %d attempts",
                  run_id, self._event_data(payload).get("seq"), max_retries)
        return False

    async def _open_run_pub(self, run_id: str) -> Any:
        """Open a WS publisher connection for run:{run_id}."""
        base = self._to_ws_base(self.nchan_url)
        url = f"{base.rstrip('/')}/internal/run/{run_id}/publish"
        log.info("opening ws publisher for run %s", run_id)
        return await connect(url, ping_interval=30, open_timeout=30)

    async def _http_publish(self, run_id: str, payload: dict[str, Any]) -> None:
        """HTTP fallback for publishing events (legacy, not used in normal flow).

        Raises whatever ``raise_for_status`` raises on a non-success response.
        """
        url = f"{self._publish_base.rstrip('/')}/internal/run/{run_id}/publish"
        resp = await self._http.post(url, content=json.dumps(payload, ensure_ascii=False))
        resp.raise_for_status()

    async def close_run_pub(self, run_id: str) -> None:
        """Close the WS publisher for a run, if one is cached."""
        ws = self._run_pubs.pop(run_id, None)
        if ws is not None:
            await self._close_quietly(ws)

    async def verify_and_replay(self, run_id: str, local_events: list[dict]) -> None:
        """Compare the nchan channel with local events and resend what is missing.

        1. Subscribe to the run channel and collect the ``seq`` of every
           message already there.
        2. Compare against the ``seq`` values of *local_events*.
        3. Re-publish the missing events over WS in ``seq`` order.

        This is more reliable than blindly resending failed events: it also
        catches sends where ``ws.send()`` returned successfully but the data
        never reached nchan (TCP buffer issue).

        Events without a ``seq`` cannot be matched against the channel and
        are therefore always replayed.
        """
        if not local_events:
            return

        # Read the current state of the nchan channel.
        # NOTE(review): if this endpoint is a long-lived event stream, a
        # missing terminal would surface below as a read timeout and trigger
        # a full replay — confirm against the server.
        url = f"{self._publish_base.rstrip('/')}/internal/run/{run_id}"
        nchan_seqs: set[int] = set()
        has_terminal = False
        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(10.0, connect=5.0)) as c:
                async with c.stream("GET", url, headers={"Accept": "text/event-stream"}) as resp:
                    async for line in resp.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        raw = line[5:].strip()
                        if not raw:
                            continue
                        try:
                            msg = json.loads(raw)
                        except json.JSONDecodeError:
                            continue
                        data = self._event_data(msg)
                        seq = data.get("seq", 0)
                        if seq:
                            nchan_seqs.add(seq)
                        if data.get("finished") and data.get("kind") in ("final", "error"):
                            has_terminal = True
                            break  # terminal found = all events present
        except Exception as e:
            log.warning("verify: failed to read nchan for run %s: %s", run_id, e)
            # Reading nchan failed: be conservative and resend everything.
            nchan_seqs = set()

        # Terminal already in nchan = everything arrived.
        if has_terminal:
            log.info("verify: run %s terminal present in nchan, all good", run_id)
            return

        # Work out which local events nchan has not seen.
        local_seqs = {self._event_seq(ev) for ev in local_events}
        missing_seqs = local_seqs - nchan_seqs
        if not missing_seqs:
            log.info("verify: run %s no missing events", run_id)
            return

        # Replay in seq order.
        missing_events = sorted(
            (ev for ev in local_events if self._event_seq(ev) in missing_seqs),
            key=self._event_seq,
        )
        log.info(
            "verify: run %s replaying %d missing events (seqs: %s, nchan has %d/%d)",
            run_id, len(missing_events),
            [self._event_seq(ev) for ev in missing_events],
            len(nchan_seqs), len(local_seqs),
        )

        for ev in missing_events:
            ok = await self.publish_event(run_id, ev)
            if not ok:
                log.error("verify: replay still failed for run %s seq=%s", run_id, self._event_seq(ev))

    # ---------------------------------------------------------------- main loop

    async def run(self) -> None:
        """Main loop: WS subscriber + HTTP claim-task poller, running in parallel.

        - WS subscriber: instant task delivery when connected (push path)
        - HTTP poller: claims missed tasks every N seconds (pull path)
        Both dispatch to the same handler; dedup is handled by TaskRunner.
        """
        await asyncio.gather(
            self._ws_loop(),
            self._claim_loop(),
        )

    async def _ws_loop(self) -> None:
        """WS connect-dispatch-reconnect loop (push path)."""
        attempts = 0
        while not self._stopped:
            try:
                await self._connect_and_serve()
                attempts = 0
            except Exception as e:
                log.error("channel connection error: %s", e)
            if self._stopped:
                break
            if MAX_RECONNECT_ATTEMPTS > 0 and attempts >= MAX_RECONNECT_ATTEMPTS:
                break
            attempts += 1
            log.info("reconnecting in %.0fs (attempt %d)...", RECONNECT_DELAY, attempts)
            await asyncio.sleep(RECONNECT_DELAY)

    async def _claim_next(self) -> dict[str, Any] | None:
        """POST ``claim-task`` once; return the task envelope or None."""
        url = f"{self._publish_base.rstrip('/')}/api/agent/{self.app_id}/claim-task"
        resp = await self._http.post(url)
        if resp.status_code != 200:
            return None
        msg = resp.json()
        if isinstance(msg, dict) and msg.get("type") == "task":
            return msg
        return None

    async def _claim_once(self) -> None:
        """Single claim sweep — used after WS reconnect to catch missed tasks.

        Keeps claiming until the server has nothing more to hand out, the
        client is stopped, or a request fails.
        """
        try:
            # Iterate rather than recurse so a long queue cannot exhaust the stack.
            while not self._stopped:
                msg = await self._claim_next()
                if msg is None:
                    return
                log.info("claimed missed task after reconnect: run %s",
                         self._event_data(msg).get("run_id"))
                await self._safe_dispatch(msg)
        except Exception as e:
            log.debug("claim-once failed: %s", e)

    async def _claim_loop(self) -> None:
        """HTTP poll for missed tasks (pull path).

        Only polls while the WS is disconnected; while it is connected the
        push path handles delivery and this loop just sleeps.
        """
        await asyncio.sleep(1.0)  # initial delay (let WS connect first)
        while not self._stopped:
            if not self._ws_connected:
                # WS disconnected — poll for queued tasks
                try:
                    msg = await self._claim_next()
                except Exception as e:
                    log.debug("claim-task poll failed: %s", e)
                    msg = None
                if msg is not None:
                    log.info("claimed task via HTTP: run %s",
                             self._event_data(msg).get("run_id"))
                    await self._safe_dispatch(msg)
                    continue  # immediately try again (might be more queued)
            await asyncio.sleep(self._claim_interval)

    async def _connect_and_serve(self) -> None:
        """Hold one subscriber connection open and dispatch its task frames."""
        url = self._ws_url()
        headers = self._auth_headers()
        log.info("connecting to %s as %s", url, self.app_id)
        async with connect(url, additional_headers=headers, ping_interval=30) as ws:
            self._ws_connected = True
            try:
                log.info("connected; waiting for tasks")
                # sweep for tasks queued during disconnection
                self._spawn(self._claim_once())
                async for raw in ws:
                    try:
                        msg = json.loads(raw)
                    except (ValueError, TypeError):
                        continue
                    # Frames that are valid JSON but not an object are ignored.
                    if not isinstance(msg, dict) or msg.get("type") != "task":
                        continue
                    if self.on_task is None:
                        log.warning("received task but no handler bound; dropping")
                        continue
                    # dispatch without blocking the recv loop
                    self._spawn(self._safe_dispatch(msg))
            finally:
                # WS dropped (cleanly or with an error) — reset the flag so
                # the claim loop starts polling.
                self._ws_connected = False

    async def _safe_dispatch(self, msg: dict[str, Any]) -> None:
        """Run the task handler on *msg*, logging instead of raising."""
        handler = self.on_task
        if handler is None:
            log.warning("received task but no handler bound; dropping")
            return
        try:
            await handler(msg)
        except Exception:
            log.exception("task handler error for %s", self._event_data(msg).get("run_id"))

    def stop(self) -> None:
        """Ask the loops to exit at their next check of the stop flag."""
        self._stopped = True

    async def aclose(self) -> None:
        """Stop the loops and close the run publishers and the HTTP client."""
        self.stop()
        for run_id in list(self._run_pubs):
            await self.close_run_pub(run_id)
        await self._http.aclose()
baserun_cli/main.py ADDED
@@ -0,0 +1,98 @@
1
+ """Agent client entry point.
2
+
3
+ Connects to the nchan bus as the agent's channel subscriber, receives task
4
+ envelopes, spawns the CLI agent (claude/codex/...), parses streaming output,
5
+ and publishes run events back to run:{run_id}.
6
+
7
+ Config via env:
8
+ NCHAN_URL e.g. ws://localhost:9390
9
+ AGENT_APP_ID channel credential (issued by the server)
10
+ AGENT_APP_SECRET HMAC secret
11
+ CONNECTOR_TYPE claude_code | codex | <custom>
12
+ CLI_CONFIG JSON override for the CLISpec (bin, workdir, env, ...)
13
+ CONCURRENCY how many runs to run in parallel (default 1)
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import json
19
+ import logging
20
+ import os
21
+ import sys
22
+
23
+ from .channel import ChannelClient
24
+ from .runner import TaskRunner
25
+
26
# Configure the root logger (INFO level, timestamped format) when this module
# is imported.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
# Logger used by the entry-point code in this module.
log = logging.getLogger("baserun-cli")
31
+
32
+
33
def _load_cli_config() -> dict:
    """Read CLISpec overrides from the CLI_CONFIG env var (JSON) if present.

    Returns:
        The decoded JSON object, or ``{}`` when the variable is unset, empty,
        not valid JSON, or valid JSON that is not an object. Callers use
        ``.get`` on the result, so non-object values are rejected here.
    """
    raw = os.environ.get("CLI_CONFIG")
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        log.warning("CLI_CONFIG invalid JSON, ignoring: %s", e)
        return {}
    if not isinstance(parsed, dict):
        log.warning("CLI_CONFIG must be a JSON object, got %s; ignoring", type(parsed).__name__)
        return {}
    return parsed
43
+
44
+
45
def _enrich_envelope(envelope: dict, connector_type: str, cli_config: dict) -> dict:
    """Add connector type and CLI overrides to a task envelope, in place.

    Values already present in the envelope win over *connector_type*,
    ``workdir`` and ``bin``; ``cli_spec`` entries from *cli_config* are merged
    over any existing ``config.cli_spec``.

    Returns:
        The same *envelope* object, with ``envelope["data"]`` populated.
    """
    data = envelope.get("data")
    if not isinstance(data, dict) or not data:
        # Flat envelope: use its own fields as the payload. Copy them so that
        # storing the payload under "data" cannot make the envelope contain
        # itself (a circular structure that cannot be serialized).
        data = {k: v for k, v in envelope.items() if k != "data"}
    data.setdefault("connector_type", connector_type)
    # merge cli_config under config.cli_spec so resolve_spec picks it up
    cfg = data.get("config") or {}
    if cli_config:
        cfg.setdefault("cli_spec", {}).update(cli_config.get("cli_spec", cli_config))
        if workdir := cli_config.get("workdir"):
            cfg.setdefault("workdir", workdir)
        if bin_ := cli_config.get("bin"):
            cfg.setdefault("bin", bin_)
    data["config"] = cfg
    envelope["data"] = data
    return envelope


async def main() -> None:
    """Read configuration from the environment and run the channel client.

    Exits with status 1 when the credentials are missing or CONCURRENCY is
    not an integer.
    """
    nchan_url = os.environ.get("NCHAN_URL", "ws://localhost:9390")
    app_id = os.environ.get("AGENT_APP_ID", "")
    app_secret = os.environ.get("AGENT_APP_SECRET", "")
    connector_type = os.environ.get("CONNECTOR_TYPE", "claude_code")
    raw_concurrency = os.environ.get("CONCURRENCY", "1")
    try:
        concurrency = int(raw_concurrency)
    except ValueError:
        log.error("CONCURRENCY must be an integer, got %r", raw_concurrency)
        sys.exit(1)

    if not app_id or not app_secret:
        log.error("AGENT_APP_ID / AGENT_APP_SECRET must be set")
        sys.exit(1)

    cli_config = _load_cli_config()

    # channel created first (handler bound after runner exists), breaking the
    # chicken-and-egg (channel needs on_task; on_task needs runner; runner needs channel).
    channel = ChannelClient(nchan_url, app_id, app_secret)
    runner = TaskRunner(channel, concurrency=concurrency, app_id=app_id)

    async def on_task(envelope: dict) -> None:
        # enrich the data with our connector_type + cli_config so the runner
        # builds the right CLISpec
        await runner.handle(_enrich_envelope(envelope, connector_type, cli_config))

    channel.on_task = on_task

    log.info(
        "agent client starting: app_id=%s connector=%s concurrency=%d nchan=%s",
        app_id, connector_type, concurrency, nchan_url,
    )
    await channel.run()
87
+
88
+
89
def main_sync() -> None:
    """Console-script entry point: drive ``main()`` until it ends or Ctrl-C."""
    entry = main()
    try:
        asyncio.run(entry)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the agent; exit without a traceback.
        log.info("shutting down")


if __name__ == "__main__":
    main_sync()