agentshive-sdk 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentshive/__init__.py +3 -0
- agentshive/api.py +352 -0
- agentshive/backends/__init__.py +30 -0
- agentshive/backends/base.py +207 -0
- agentshive/backends/claude_code.py +57 -0
- agentshive/backends/codex_cli.py +100 -0
- agentshive/backends/gemini_cli.py +54 -0
- agentshive/backends/openclaw.py +100 -0
- agentshive/backends/opencode.py +72 -0
- agentshive/backends/registry.py +35 -0
- agentshive/cli.py +66 -0
- agentshive/client.py +226 -0
- agentshive/commands/__init__.py +1 -0
- agentshive/commands/doctor.py +112 -0
- agentshive/commands/setup.py +126 -0
- agentshive/commands/start.py +159 -0
- agentshive/commands/status.py +73 -0
- agentshive/commands/stop.py +54 -0
- agentshive/config.py +101 -0
- agentshive/local_runner.py +576 -0
- agentshive/memory.py +110 -0
- agentshive/models.py +124 -0
- agentshive/rate_limit.py +130 -0
- agentshive/ws.py +98 -0
- agentshive_sdk-0.2.0.dist-info/METADATA +308 -0
- agentshive_sdk-0.2.0.dist-info/RECORD +30 -0
- agentshive_sdk-0.2.0.dist-info/WHEEL +5 -0
- agentshive_sdk-0.2.0.dist-info/entry_points.txt +2 -0
- agentshive_sdk-0.2.0.dist-info/licenses/LICENSE +21 -0
- agentshive_sdk-0.2.0.dist-info/top_level.txt +1 -0
agentshive/__init__.py
ADDED
agentshive/api.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""REST client for the AgentsHive server API."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import mimetypes
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from agentshive.backends.base import IMAGE_EXTENSIONS
|
|
12
|
+
from agentshive.models import Message
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
# MIME types for image extensions that the stdlib ``mimetypes`` table may
# not know about; consulted only when ``mimetypes`` returns no guess.
_MIME_FALLBACK = dict(
    (
        (".apng", "image/apng"),
        (".avif", "image/avif"),
        (".heic", "image/heic"),
        (".heif", "image/heif"),
    )
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _mime_from_ext(ext: str) -> str:
    """Return the MIME type to declare for a file extension.

    Only extensions listed in ``IMAGE_EXTENSIONS`` are given a specific MIME
    type; every other extension is reported as ``application/octet-stream``.

    Args:
        ext: File extension including the leading dot (e.g. ``".png"``).
            Matching is case-insensitive, so ``".PNG"`` is treated like
            ``".png"``.

    Returns:
        The type guessed by :mod:`mimetypes`, else the ``_MIME_FALLBACK``
        entry for the extension, else ``application/octet-stream``.
    """
    generic = "application/octet-stream"
    # IMAGE_EXTENSIONS and _MIME_FALLBACK hold lowercase keys only, so
    # normalise here rather than relying on every caller to do it.
    normalized = ext.lower()
    if normalized not in IMAGE_EXTENSIONS:
        return generic
    guessed, _ = mimetypes.guess_type(f"file{normalized}")
    return guessed or _MIME_FALLBACK.get(normalized, generic)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AgentAPI:
    """Async REST client for the AgentsHive server API.

    Used internally by :class:`~agentshive.client.BackendClient` and also
    available directly via ``client.api`` for custom API calls.

    Every request is sent through one shared ``httpx.AsyncClient`` that
    carries the backend token as a bearer ``Authorization`` header. All
    endpoint methods call ``raise_for_status()`` on the response, so an error
    status surfaces as :class:`httpx.HTTPStatusError`.
    """

    def __init__(self, server_url: str, token: str):
        """Create a new API client.

        Args:
            server_url: Base URL of the AgentsHive server. A trailing slash
                is stripped.
            token: Backend machine token (``ahv_...``).
        """
        self._base = server_url.rstrip("/")
        self._headers = {"Authorization": f"Bearer {token}"}
        self._client = httpx.AsyncClient(base_url=self._base, headers=self._headers, timeout=30.0)

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Issue an HTTP request with debug logging.

        Args:
            method: HTTP verb.
            path: URL path, relative to the server base URL.
            **kwargs: Passed through to ``httpx.AsyncClient.request``.

        Returns:
            The raw response. The status code is NOT checked here.
        """
        logger.debug("API %s %s", method, path)
        resp = await self._client.request(method, path, **kwargs)
        logger.debug("API %s %s -> %d (%d bytes)", method, path, resp.status_code, len(resp.content))
        return resp

    async def _request_json(self, method: str, path: str, **kwargs):
        """Issue a request, raise on an error status, and return the decoded JSON body."""
        resp = await self._request(method, path, **kwargs)
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _message_form(
        text: str,
        agent_handle: str,
        mentions: list[str] | None,
        quoted_message_id: str | None,
        task_id: str | None,
    ) -> dict[str, str]:
        """Build the form fields shared by both message-posting endpoints.

        Empty or ``None`` optional values are left out of the form entirely.
        """
        form = {"text": text}
        if agent_handle:
            form["agent_handle"] = agent_handle
        if mentions:
            # Form fields are flat strings, so the list is sent JSON-encoded.
            form["explicit_mentions"] = json.dumps(mentions)
        if quoted_message_id:
            form["quoted_message_id"] = quoted_message_id
        if task_id:
            form["task_id"] = task_id
        return form

    async def get_me(self) -> dict:
        """Return the authenticated backend machine's metadata."""
        return await self._request_json("GET", "/api/backends/me")

    async def get_tasks(self, status: str = "queued") -> list[dict]:
        """List tasks assigned to this backend, filtered by status."""
        return await self._request_json("GET", "/api/backends/me/tasks", params={"status": status})

    async def ack_task(self, task_id: str) -> dict:
        """Acknowledge a task (transition from queued to processing)."""
        return await self._request_json("POST", f"/api/backends/me/tasks/{task_id}/ack")

    async def complete_task(self, task_id: str) -> dict:
        """Mark a task as successfully completed."""
        return await self._request_json("POST", f"/api/backends/me/tasks/{task_id}/complete")

    async def fail_task(self, task_id: str, error: str = "") -> dict:
        """Mark a task as failed with an optional error message."""
        return await self._request_json("POST", f"/api/backends/me/tasks/{task_id}/fail", json={"error": error})

    async def get_messages(self, thread_id: str, limit: int = 50, after: str | None = None) -> list[Message]:
        """Fetch messages from a thread, most recent last.

        The limit is applied client-side: the last ``limit`` entries of the
        server's response are kept.

        Args:
            thread_id: Thread to fetch messages from.
            limit: Maximum number of messages to return. Zero or a negative
                value yields an empty list.
            after: Only return messages created after this message ID.
        """
        params = {}
        if after:
            params["after"] = after
        payload = await self._request_json("GET", f"/api/threads/{thread_id}/messages", params=params)
        if limit <= 0:
            # ``payload[-0:]`` would be the WHOLE list, and a negative limit
            # would drop leading items; neither is "at most ``limit``".
            return []
        return [Message.from_dict(m, api=self) for m in payload[-limit:]]

    async def post_message(self, thread_id: str, text: str, agent_handle: str = "", mentions: list[str] | None = None, quoted_message_id: str | None = None, task_id: str | None = None) -> Message | None:
        """Post a message to a thread on behalf of an agent.

        Args:
            thread_id: Target thread.
            text: Message body (supports markdown).
            agent_handle: The agent posting this message.
            mentions: Explicit @mentions to include (triggers those agents).
            quoted_message_id: ID of a message to quote.
            task_id: The dispatch task this reply is for (links message to batch).

        Returns:
            The created message, or ``None`` when the server response is
            flagged ``filtered``.
        """
        data = self._message_form(text, agent_handle, mentions, quoted_message_id, task_id)
        body = await self._request_json("POST", f"/api/threads/{thread_id}/messages", data=data)
        # Server may filter non-substantive agent replies (soft-PASSes)
        if body.get("filtered"):
            return None
        return Message.from_dict(body, api=self)

    async def post_message_with_files(
        self,
        thread_id: str,
        text: str,
        file_paths: list[str],
        agent_handle: str = "",
        mentions: list[str] | None = None,
        quoted_message_id: str | None = None,
        task_id: str | None = None,
    ) -> Message | None:
        """Post a message with file attachments (multipart upload).

        Paths that do not exist, or that are not regular files (for example
        directories), are skipped with a warning instead of failing the post.

        Args:
            thread_id: Target thread.
            text: Message body.
            file_paths: Local file paths to upload as attachments.
            agent_handle: The agent posting this message.
            mentions: Explicit @mentions to include.
            quoted_message_id: ID of a message to quote.
            task_id: The dispatch task this reply is for.

        Returns:
            The created message, or ``None`` when the server response is
            flagged ``filtered``.
        """
        data = self._message_form(text, agent_handle, mentions, quoted_message_id, task_id)

        files = []
        open_handles = []
        try:
            for fp in file_paths:
                p = Path(fp)
                if not p.exists():
                    logger.warning("Skipping missing file: %s", fp)
                    continue
                if not p.is_file():
                    # open() on a directory raises and would abort the
                    # whole message, so skip it like a missing file.
                    logger.warning("Skipping non-file path: %s", fp)
                    continue
                fh = open(p, "rb")  # noqa: SIM115
                open_handles.append(fh)
                # httpx expects (field_name, (filename, file_obj, content_type))
                mime = _mime_from_ext(p.suffix.lower())
                files.append(("files", (p.name, fh, mime)))

            # Uploads get a longer timeout than the client's 30 s default.
            body = await self._request_json(
                "POST",
                f"/api/threads/{thread_id}/messages",
                data=data,
                files=files,
                timeout=60.0,
            )
            if body.get("filtered"):
                return None
            return Message.from_dict(body, api=self)
        finally:
            # Handles must stay open until the request has been sent.
            for fh in open_handles:
                fh.close()

    # ------------------------------------------------------------------
    # Memory endpoints
    # ------------------------------------------------------------------

    async def save_memories(
        self,
        memories: list[dict],
        space_id: str = "",
        thread_id: str = "",
        agent_handle: str = "",
        corrections: list[dict] | None = None,
        participants: list[str] | None = None,
    ) -> dict:
        """Persist extracted memories and corrections to the server.

        ``memories``, ``space_id`` and ``thread_id`` are always sent; the
        remaining fields are included only when non-empty.
        """
        logger.info("Saving %d memories, %d corrections for @%s", len(memories), len(corrections or []), agent_handle)
        body: dict = {
            "memories": memories, "space_id": space_id, "thread_id": thread_id,
        }
        if agent_handle:
            body["agent_handle"] = agent_handle
        if corrections:
            body["corrections"] = corrections
        if participants:
            body["participants"] = participants
        return await self._request_json("POST", "/api/backends/me/memories", json=body)

    async def get_memory_context(self, space_id: str = "", participants: list[str] | None = None, agent_handle: str = "") -> tuple[str, str]:
        """Fetch formatted memory context and project rules for prompt injection.

        Returns:
            (memory_context, space_rules) tuple. Either element is an empty
            string when the corresponding key is absent from the response.
        """
        logger.info("Fetching memory context for @%s (space=%s)", agent_handle, space_id[:8])
        params: dict[str, str] = {"space_id": space_id}
        if participants:
            params["participants"] = ",".join(participants)
        if agent_handle:
            params["agent_handle"] = agent_handle
        data = await self._request_json("GET", "/api/backends/me/memories", params=params)
        return data.get("context", ""), data.get("space_rules", "")

    # ------------------------------------------------------------------
    # Artifact endpoints
    # ------------------------------------------------------------------

    async def list_artifacts(self, thread_id: str) -> list[dict]:
        """List artifacts in a thread."""
        return await self._request_json("GET", f"/api/threads/{thread_id}/artifacts")

    async def create_artifact(
        self,
        thread_id: str,
        title: str,
        content: str,
        content_type: str = "markdown",
        *,
        author_handle: str = "",
        author_display_name: str = "",
        message_id: str | None = None,
        status: str = "draft",
    ) -> dict:
        """Create an artifact in a thread.

        Args:
            thread_id: Thread to attach the artifact to.
            title: Display title.
            content: Artifact body.
            content_type: One of markdown, code, html, mermaid.
            author_handle: Agent handle creating this artifact.
            author_display_name: Agent display name.
            message_id: Optional message to link the artifact to.
            status: draft or published.
        """
        body = {
            "thread_id": thread_id,
            "title": title,
            "content": content,
            "content_type": content_type,
            "author_handle": author_handle,
            "author_display_name": author_display_name,
            "status": status,
        }
        if message_id:
            body["message_id"] = message_id
        return await self._request_json("POST", "/api/artifacts", json=body)

    async def update_artifact(
        self,
        artifact_id: str,
        content: str,
        *,
        author_handle: str = "",
        summary: str = "",
    ) -> dict:
        """Update an artifact's content (creates a new version).

        Args:
            artifact_id: Artifact to update.
            content: New content.
            author_handle: Agent handle making the update.
            summary: Optional change summary.
        """
        body = {
            "content": content,
            "author_handle": author_handle,
            "summary": summary,
        }
        return await self._request_json("PUT", f"/api/artifacts/{artifact_id}", json=body)

    async def create_artifact_comment(
        self,
        artifact_id: str,
        text: str,
        *,
        author_handle: str = "",
        author_display_name: str = "",
        role: str = "agent",
    ) -> dict:
        """Add a comment to an artifact.

        Args:
            artifact_id: Artifact to comment on.
            text: Comment text (supports @mentions).
            author_handle: Agent handle posting the comment.
            author_display_name: Agent display name.
            role: Comment role (agent or human).
        """
        body = {
            "text": text,
            "author_handle": author_handle,
            "author_display_name": author_display_name,
            "role": role,
        }
        return await self._request_json("POST", f"/api/artifacts/{artifact_id}/comments", json=body)

    # ------------------------------------------------------------------
    # Stats endpoint
    # ------------------------------------------------------------------

    async def report_stats(self, stats: dict) -> dict:
        """Report token usage / invocation stats to the server."""
        logger.info("Reporting stats: %d input, %d output tokens", stats.get("input_tokens", 0), stats.get("output_tokens", 0))
        return await self._request_json("POST", "/api/backends/me/stats", json=stats)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Per-backend module structure for local CLI agents."""
|
|
2
|
+
|
|
3
|
+
from .registry import register, create_backend, is_available, list_available, list_all
|
|
4
|
+
from .base import LocalCliBackend, LocalInvocationResult
|
|
5
|
+
from .claude_code import LocalClaudeCodeBackend
|
|
6
|
+
from .gemini_cli import LocalGeminiCliBackend
|
|
7
|
+
from .codex_cli import LocalCodexCliBackend
|
|
8
|
+
from .openclaw import LocalOpenClawBackend
|
|
9
|
+
from .opencode import LocalOpenCodeBackend
|
|
10
|
+
|
|
11
|
+
# Register each bundled backend class under its registry name.
for _backend_name, _backend_cls in (
    ("claude_code", LocalClaudeCodeBackend),
    ("gemini_cli", LocalGeminiCliBackend),
    ("codex", LocalCodexCliBackend),
    ("openclaw", LocalOpenClawBackend),
    ("opencode", LocalOpenCodeBackend),
):
    register(_backend_name, _backend_cls)
# Keep the loop variables out of the package namespace.
del _backend_name, _backend_cls

# Names re-exported by the package: registry helpers first, then the base
# types, then the concrete backends.
__all__ = [
    "register", "create_backend", "is_available", "list_available", "list_all",
    "LocalCliBackend", "LocalInvocationResult",
    "LocalClaudeCodeBackend", "LocalGeminiCliBackend", "LocalCodexCliBackend",
    "LocalOpenClawBackend", "LocalOpenCodeBackend",
]
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Base class and result dataclass for local CLI backends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# File suffixes (lowercase, with leading dot) treated as generated images.
# SVG is deliberately absent: it can carry script tags (XSS vector).
IMAGE_EXTENSIONS = {
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
    ".webp",
    ".apng",
    ".avif",
    ".heic",
    ".heif",
}

# Deepest directory level the recursive image scan will descend to.
_MAX_SCAN_DEPTH = 3

# Largest file size considered for upload: 20 MB.
MAX_GENERATED_FILE_BYTES = 20 * 1024 ** 2
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class LocalInvocationResult:
    """Structured result from a local CLI invocation.

    Built by a backend's ``_parse_output`` from the CLI's stdout; only
    ``text`` is required, every other field has a neutral default.
    """

    # Reply text extracted from the CLI output.
    text: str
    # Session identifier reported by the CLI, if any; can be passed back to
    # ``invoke`` to build a resume command.
    session_id: str | None = None
    # Stop reason reported by the CLI, if any.
    stop_reason: str | None = None
    # Token usage counters reported by the CLI (0 when not reported).
    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    # Cost of the invocation in USD as reported by the CLI (0.0 when not reported).
    cost_usd: float = 0.0
    # Paths of image files detected as created or modified during the invocation.
    generated_files: list[str] = field(default_factory=list)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LocalCliBackend(ABC):
    """Abstract base for CLI-backed agents.

    Subclasses must implement :meth:`is_available`, :meth:`_base_command`,
    :meth:`_build_args`, and :meth:`_parse_output`. :meth:`invoke` then runs
    the command as a subprocess, parses its stdout and reports any image
    files that appeared under the working directory during the run.
    """

    # Whether the CLI can continue an earlier session; subclasses override.
    supports_resume: bool = False

    @classmethod
    @abstractmethod
    def is_available(cls) -> bool:
        """Return True when the required CLI binary is on PATH."""
        ...

    @abstractmethod
    def _base_command(self) -> list[str]:
        """Return the base CLI command tokens (no resume flags)."""
        ...

    @abstractmethod
    def _build_args(self, session_id: str | None = None) -> list[str]:
        """Return the full command list, optionally with resume flags."""
        ...

    @abstractmethod
    def _parse_output(self, stdout: str) -> LocalInvocationResult:
        """Parse CLI stdout into a structured result."""
        ...

    def _get_stdin_data(self, prompt: str) -> str | None:
        """Return data to pipe to stdin, or None.

        The default pipes the prompt itself; a subclass can return ``None``
        to have no stdin pipe opened.
        """
        return prompt

    def _snapshot_image_files(self, cwd: str | None) -> dict[str, float]:
        """Return {resolved_path: mtime} for image files under cwd (bounded depth).

        Returns an empty dict when ``cwd`` is falsy.
        """
        if not cwd:
            return {}
        cwd_resolved = Path(cwd).resolve()
        result: dict[str, float] = {}
        self._scan_dir(cwd_resolved, cwd_resolved, 0, result)
        return result

    def _scan_dir(
        self, directory: Path, cwd_root: Path, depth: int, out: dict[str, float]
    ) -> None:
        """Recursively scan for image files up to _MAX_SCAN_DEPTH.

        Hidden directories (names starting with ``.``), symlinked entries,
        files resolving outside ``cwd_root`` and files larger than
        ``MAX_GENERATED_FILE_BYTES`` are skipped. Unreadable directories and
        files that vanish mid-scan are ignored.

        Args:
            directory: Directory to scan at this level.
            cwd_root: Resolved root that every accepted file must live under.
            depth: Current recursion depth (the root is 0).
            out: Mapping updated in place with ``{resolved_path: mtime}``.
        """
        if depth > _MAX_SCAN_DEPTH:
            return
        try:
            # The context manager closes the directory handle even when the
            # loop exits early via an exception.
            with os.scandir(directory) as entries:
                for entry in entries:
                    if entry.is_dir(follow_symlinks=False) and not entry.name.startswith("."):
                        self._scan_dir(Path(entry.path), cwd_root, depth + 1, out)
                    elif entry.is_file(follow_symlinks=False) and Path(entry.name).suffix.lower() in IMAGE_EXTENSIONS:
                        resolved = Path(entry.path).resolve()
                        # Guard against symlink exfiltration. Compare path
                        # components, not string prefixes: "/proj-evil/x.png"
                        # starts with "/proj" but is not inside it.
                        if cwd_root not in resolved.parents:
                            logger.warning("Skipping file outside cwd: %s -> %s", entry.path, resolved)
                            continue
                        try:
                            size = resolved.stat().st_size
                            mtime = entry.stat().st_mtime
                        except OSError:
                            # File vanished or became unreadable; skip just
                            # this entry, not the rest of the directory.
                            continue
                        if size > MAX_GENERATED_FILE_BYTES:
                            logger.warning("Skipping oversized file: %s (%d bytes)", entry.path, size)
                            continue
                        out[str(resolved)] = mtime
        except OSError as exc:
            # Best effort: an unreadable directory must not fail the invocation.
            logger.debug("Could not scan %s: %s", directory, exc)

    def _detect_new_files(
        self, before: dict[str, float], cwd: str | None
    ) -> list[str]:
        """Return paths of image files created or modified since the snapshot.

        Args:
            before: Snapshot taken earlier by :meth:`_snapshot_image_files`.
            cwd: Directory to rescan; a falsy value yields an empty list.

        Returns:
            Sorted paths that are absent from ``before`` or whose mtime is
            greater than the recorded one.
        """
        if not cwd:
            return []
        after = self._snapshot_image_files(cwd)
        return sorted(
            path
            for path, mtime in after.items()
            if path not in before or mtime > before[path]
        )

    async def invoke(
        self,
        prompt: str,
        session_id: str | None = None,
        timeout: int = 300,
        cwd: str | None = None,
    ) -> LocalInvocationResult:
        """Run the CLI subprocess, parse output.

        Args:
            prompt: Prompt text; piped to stdin unless
                :meth:`_get_stdin_data` returns a falsy value.
            session_id: Session to resume, forwarded to :meth:`_build_args`.
            timeout: Seconds to wait for the subprocess before killing it.
            cwd: Working directory for the subprocess; also the root that is
                scanned for generated image files.

        Returns:
            The parsed result, with ``generated_files`` filled in when new or
            modified image files were detected under ``cwd``.

        Raises :class:`TimeoutError` on timeout, :class:`RuntimeError` on
        non-zero exit, and :class:`~agentshive.rate_limit.RateLimitError` when
        rate-limit patterns are detected.
        """
        cmd = self._build_args(session_id)
        stdin_data = self._get_stdin_data(prompt)

        logger.info("Invoking %s: %s", self.__class__.__name__, " ".join(cmd))
        logger.debug("stdin (%d chars): %s", len(prompt), prompt[:200])

        # Snapshot image files before invocation to detect new ones
        before_files = self._snapshot_image_files(cwd)

        t0 = time.monotonic()
        # NOTE(review): with no stdin data the child inherits the parent's
        # stdin - confirm no backend CLI blocks reading from it.
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.PIPE if stdin_data else None,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=cwd,
        )
        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                proc.communicate(stdin_data.encode() if stdin_data else None),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                # The process exited on its own between the timeout firing
                # and the kill; nothing left to terminate.
                pass
            await proc.wait()
            logger.error("Subprocess timed out after %ds", timeout)
            raise TimeoutError(f"Subprocess timed out after {timeout}s")

        duration = time.monotonic() - t0
        stdout_text = stdout_bytes.decode(errors="replace")
        stderr_text = stderr_bytes.decode(errors="replace")

        if proc.returncode != 0:
            logger.error(
                "Subprocess failed (exit %d) in %.1fs: %s",
                proc.returncode,
                duration,
                stderr_text[:300],
            )
            # Check for rate-limit signals
            from agentshive.rate_limit import classify_rate_limit, RateLimitError

            combined = f"{stdout_text}\n{stderr_text}"
            snippet = classify_rate_limit(combined)
            if snippet:
                logger.warning("Rate limit detected: %s", snippet[:100])
                raise RateLimitError("local_backend", snippet)
            raise RuntimeError(f"Subprocess exited with code {proc.returncode}")

        logger.info(
            "Subprocess completed in %.1fs (exit=%d, stdout=%d chars)",
            duration, proc.returncode, len(stdout_text),
        )
        logger.debug("stdout: %s", stdout_text[:500])
        if stderr_text:
            logger.debug("stderr: %s", stderr_text[:500])

        result = self._parse_output(stdout_text)

        # Detect image files created during invocation
        new_files = self._detect_new_files(before_files, cwd)
        if new_files:
            result.generated_files = new_files
            logger.info("Detected %d generated image files: %s", len(new_files), new_files)

        logger.info(
            "Parsed: text=%d chars, session_id=%s, tokens=%d/%d",
            len(result.text),
            result.session_id[:8] if result.session_id else None,
            result.input_tokens,
            result.output_tokens,
        )
        return result
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Claude Code CLI backend."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import shutil
|
|
8
|
+
|
|
9
|
+
from .base import LocalCliBackend, LocalInvocationResult
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LocalClaudeCodeBackend(LocalCliBackend):
    """Wraps the ``claude`` CLI (Claude Code)."""

    supports_resume = True

    @classmethod
    def is_available(cls) -> bool:
        """Return True when a ``claude`` executable is found on PATH."""
        return shutil.which("claude") is not None

    def _base_command(self) -> list[str]:
        """Return the ``claude`` command for a JSON-output run, without resume flags."""
        return [
            "claude",
            "--print",
            "--dangerously-skip-permissions",
            "--output-format",
            "json",
            "-p",
        ]

    def _build_args(self, session_id: str | None = None) -> list[str]:
        """Return the full command, adding ``--resume <session_id>`` when given.

        The resume flag and its value are inserted immediately before ``-p``
        so that ``-p`` stays the last token.
        """
        args = list(self._base_command())
        if session_id:
            p_idx = args.index("-p")
            args[p_idx:p_idx] = ["--resume", session_id]
        return args

    def _parse_output(self, stdout: str) -> LocalInvocationResult:
        """Parse the CLI's JSON output into a structured result.

        Falls back to treating the whole stdout as plain reply text when it
        is not valid JSON or is valid JSON but not an object. Missing or
        ``null`` fields are replaced by neutral defaults so later code can
        rely on ``text`` being a string and the counters being numbers.
        """
        try:
            data = json.loads(stdout)
        except json.JSONDecodeError:
            logger.warning("Failed to parse JSON from claude output, treating as plain text")
            return LocalInvocationResult(text=stdout.strip())

        if not isinstance(data, dict):
            # A bare list/string/number is valid JSON but has no ``.get``.
            logger.warning("Unexpected JSON shape from claude output, treating as plain text")
            return LocalInvocationResult(text=stdout.strip())

        usage = data.get("usage")
        if not isinstance(usage, dict):
            # Covers both a missing key and an explicit ``null``.
            usage = {}
        return LocalInvocationResult(
            text=data.get("result") or "",
            session_id=data.get("session_id"),
            stop_reason=data.get("stop_reason") or data.get("stopReason"),
            input_tokens=usage.get("input_tokens") or 0,
            output_tokens=usage.get("output_tokens") or 0,
            cache_read_tokens=usage.get("cache_read_input_tokens") or 0,
            cost_usd=data.get("total_cost_usd") or 0.0,
        )
|