stirrup 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stirrup/__init__.py +2 -0
- stirrup/clients/__init__.py +5 -0
- stirrup/clients/chat_completions_client.py +0 -3
- stirrup/clients/litellm_client.py +20 -11
- stirrup/clients/open_responses_client.py +434 -0
- stirrup/clients/utils.py +6 -1
- stirrup/constants.py +6 -2
- stirrup/core/agent.py +196 -57
- stirrup/core/cache.py +479 -0
- stirrup/core/models.py +53 -9
- stirrup/prompts/base_system_prompt.txt +1 -1
- stirrup/tools/__init__.py +3 -0
- stirrup/tools/browser_use.py +591 -0
- stirrup/tools/calculator.py +1 -1
- stirrup/tools/code_backends/base.py +24 -0
- stirrup/tools/code_backends/docker.py +19 -0
- stirrup/tools/code_backends/e2b.py +43 -11
- stirrup/tools/code_backends/local.py +19 -2
- stirrup/tools/finish.py +27 -1
- stirrup/tools/user_input.py +130 -0
- stirrup/tools/web.py +1 -0
- stirrup/utils/logging.py +32 -7
- {stirrup-0.1.2.dist-info → stirrup-0.1.4.dist-info}/METADATA +16 -13
- stirrup-0.1.4.dist-info/RECORD +38 -0
- {stirrup-0.1.2.dist-info → stirrup-0.1.4.dist-info}/WHEEL +2 -2
- stirrup-0.1.2.dist-info/RECORD +0 -34
stirrup/core/cache.py
ADDED
@@ -0,0 +1,479 @@
+"""Cache module for persisting and resuming agent state.
+
+Provides functionality to cache agent state (messages, run metadata, execution environment files)
+on non-success exits and restore that state for resumption in new runs.
+"""
+
+import base64
+import hashlib
+import json
+import logging
+import os
+import shutil
+from dataclasses import dataclass, field
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import Any
+
+from pydantic import TypeAdapter
+
+from stirrup.core.models import (
+    AudioContentBlock,
+    ChatMessage,
+    ImageContentBlock,
+    VideoContentBlock,
+)
+
+logger = logging.getLogger(__name__)
+
+# Default cache directory relative to the project root
+DEFAULT_CACHE_DIR = Path("~/.cache/stirrup/").expanduser()
+
+# TypeAdapter for deserializing ChatMessage discriminated union
+ChatMessageAdapter: TypeAdapter[ChatMessage] = TypeAdapter(ChatMessage)
+
+
+def compute_task_hash(init_msgs: str | list[ChatMessage]) -> str:
+    """Compute deterministic hash from initial messages for cache identification.
+
+    Args:
+        init_msgs: Either a string prompt or list of ChatMessage objects.
+
+    Returns:
+        First 12 characters of SHA256 hash (hex) for readability.
+    """
+    if isinstance(init_msgs, str):
+        content = init_msgs
+    else:
+        # Serialize messages to JSON for hashing
+        content = json.dumps(
+            [serialize_message(msg) for msg in init_msgs],
+            sort_keys=True,
+            ensure_ascii=True,
+        )
+
+    hash_bytes = hashlib.sha256(content.encode("utf-8")).hexdigest()
+    return hash_bytes[:12]
+
+
+def _serialize_content_block(block: Any) -> dict | str:  # noqa: ANN401
+    """Serialize a content block, encoding binary data as base64.
+
+    Args:
+        block: A content block (string, ImageContentBlock, VideoContentBlock, AudioContentBlock).
+
+    Returns:
+        JSON-serializable representation with base64-encoded binary data.
+    """
+    if isinstance(block, str):
+        return block
+    elif isinstance(block, ImageContentBlock):
+        return {
+            "kind": "image_content_block",
+            "data": base64.b64encode(block.data).decode("ascii"),
+        }
+    elif isinstance(block, VideoContentBlock):
+        return {
+            "kind": "video_content_block",
+            "data": base64.b64encode(block.data).decode("ascii"),
+        }
+    elif isinstance(block, AudioContentBlock):
+        return {
+            "kind": "audio_content_block",
+            "data": base64.b64encode(block.data).decode("ascii"),
+        }
+    elif isinstance(block, dict):
+        # Handle dict from model_dump that might contain unencoded bytes
+        # This can happen when Pydantic fails to base64-encode bytes in mode="json"
+        if "data" in block and isinstance(block["data"], bytes):
+            return {
+                **block,
+                "data": base64.b64encode(block["data"]).decode("ascii"),
+            }
+        return block
+    else:
+        raise ValueError(f"Unknown content block type: {type(block)}")
+
+
+def _deserialize_content_block(data: dict | str) -> Any:  # noqa: ANN401
+    """Deserialize a content block, decoding base64 binary data.
+
+    Args:
+        data: JSON-serialized content block.
+
+    Returns:
+        Restored content block with decoded binary data.
+    """
+    if isinstance(data, str):
+        return data
+    if not isinstance(data, dict):
+        return data
+
+    kind = data.get("kind")
+    if kind == "image_content_block":
+        return ImageContentBlock(data=base64.b64decode(data["data"]))
+    elif kind == "video_content_block":
+        return VideoContentBlock(data=base64.b64decode(data["data"]))
+    elif kind == "audio_content_block":
+        return AudioContentBlock(data=base64.b64decode(data["data"]))
+    else:
+        # Unknown or already-processed block
+        return data
+
+
+def serialize_message(msg: ChatMessage) -> dict:
+    """Serialize a ChatMessage to JSON-compatible format.
+
+    Handles binary content blocks (images, video, audio) by base64 encoding.
+
+    Args:
+        msg: A ChatMessage (SystemMessage, UserMessage, AssistantMessage, ToolMessage).
+
+    Returns:
+        JSON-serializable dictionary.
+    """
+    # Use Pydantic's model_dump for base serialization
+    data = msg.model_dump(mode="json")
+
+    # Handle content field which may contain binary blocks
+    content = data.get("content")
+    if isinstance(content, list):
+        data["content"] = [_serialize_content_block(block) for block in content]
+    elif content is not None and not isinstance(content, str):
+        data["content"] = _serialize_content_block(content)
+
+    return data
+
+
+def deserialize_message(data: dict) -> ChatMessage:
+    """Deserialize a ChatMessage from JSON format.
+
+    Handles base64-encoded binary content blocks.
+
+    Args:
+        data: JSON dictionary representing a ChatMessage.
+
+    Returns:
+        Restored ChatMessage object.
+    """
+    # Handle content field which may contain base64-encoded binary blocks
+    content = data.get("content")
+    if isinstance(content, list):
+        data["content"] = [_deserialize_content_block(block) for block in content]
+    elif content is not None and not isinstance(content, str):
+        data["content"] = _deserialize_content_block(content)
+
+    # Use TypeAdapter for discriminated union deserialization
+    return ChatMessageAdapter.validate_python(data)
+
+
+def serialize_messages(msgs: list[ChatMessage]) -> list[dict]:
+    """Serialize a list of ChatMessages to JSON-compatible format.
+
+    Args:
+        msgs: List of ChatMessage objects.
+
+    Returns:
+        List of JSON-serializable dictionaries.
+    """
+    return [serialize_message(msg) for msg in msgs]
+
+
+def _serialize_metadata_item(item: Any) -> Any:  # noqa: ANN401
+    """Serialize a single metadata item to JSON-compatible format.
+
+    Handles Pydantic models by calling model_dump(mode='json').
+    Handles bytes by base64 encoding them.
+    """
+    from pydantic import BaseModel
+
+    if isinstance(item, BaseModel):
+        return item.model_dump(mode="json")
+    elif isinstance(item, bytes):
+        # Base64 encode raw bytes to make them JSON-serializable
+        return base64.b64encode(item).decode("ascii")
+    elif isinstance(item, dict):
+        return {k: _serialize_metadata_item(v) for k, v in item.items()}
+    elif isinstance(item, list):
+        return [_serialize_metadata_item(i) for i in item]
+    else:
+        return item
+
+
+def _serialize_run_metadata(run_metadata: dict[str, list[Any]]) -> dict[str, list[Any]]:
+    """Serialize run_metadata dict containing Pydantic models to JSON-compatible format.
+
+    Args:
+        run_metadata: Dict mapping tool names to lists of metadata (may contain Pydantic models).
+
+    Returns:
+        JSON-serializable dictionary.
+    """
+    return {
+        tool_name: [_serialize_metadata_item(item) for item in metadata_list]
+        for tool_name, metadata_list in run_metadata.items()
+    }
+
+
+def deserialize_messages(data: list[dict]) -> list[ChatMessage]:
+    """Deserialize a list of ChatMessages from JSON format.
+
+    Args:
+        data: List of JSON dictionaries representing ChatMessages.
+
+    Returns:
+        List of restored ChatMessage objects.
+    """
+    return [deserialize_message(msg_data) for msg_data in data]
+
+
+@dataclass
+class CacheState:
+    """Serializable state for resuming an agent run.
+
+    Captures all necessary state to resume execution from a specific turn.
+    """
+
+    msgs: list[ChatMessage]
+    """Current conversation messages in the active run loop."""
+
+    full_msg_history: list[list[ChatMessage]]
+    """Groups of messages (separated when context summarization occurs)."""
+
+    turn: int
+    """Current turn number (0-indexed) - resume will start from this turn."""
+
+    run_metadata: dict[str, list[Any]]
+    """Accumulated tool metadata from the run."""
+
+    task_hash: str
+    """Hash of the original init_msgs for verification on resume."""
+
+    timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat())
+    """ISO timestamp when cache was created."""
+
+    agent_name: str = ""
+    """Name of the agent that created this cache."""
+
+    def to_dict(self) -> dict:
+        """Convert to JSON-serializable dictionary."""
+        return {
+            "msgs": serialize_messages(self.msgs),
+            "full_msg_history": [serialize_messages(group) for group in self.full_msg_history],
+            "turn": self.turn,
+            "run_metadata": _serialize_run_metadata(self.run_metadata),
+            "task_hash": self.task_hash,
+            "timestamp": self.timestamp,
+            "agent_name": self.agent_name,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "CacheState":
+        """Create CacheState from JSON dictionary."""
+        return cls(
+            msgs=deserialize_messages(data["msgs"]),
+            full_msg_history=[deserialize_messages(group) for group in data["full_msg_history"]],
+            turn=data["turn"],
+            run_metadata=data["run_metadata"],
+            task_hash=data["task_hash"],
+            timestamp=data.get("timestamp", ""),
+            agent_name=data.get("agent_name", ""),
+        )
+
+
+class CacheManager:
+    """Manages cache operations for agent sessions.
+
+    Handles saving/loading cache state and execution environment files.
+    """
+
+    def __init__(
+        self,
+        cache_base_dir: Path | None = None,
+        clear_on_success: bool = True,
+    ) -> None:
+        """Initialize CacheManager.
+
+        Args:
+            cache_base_dir: Base directory for cache storage.
+                Defaults to ~/.cache/stirrup/
+            clear_on_success: If True (default), automatically clear the cache when
+                the agent completes successfully. Set to False to preserve
+                caches for inspection or manual management.
+        """
+        self._cache_base_dir = cache_base_dir or DEFAULT_CACHE_DIR
+        self.clear_on_success = clear_on_success
+
+    def _get_cache_dir(self, task_hash: str) -> Path:
+        """Get cache directory path for a task hash."""
+        return self._cache_base_dir / task_hash
+
+    def _get_state_file(self, task_hash: str) -> Path:
+        """Get state.json file path for a task hash."""
+        return self._get_cache_dir(task_hash) / "state.json"
+
+    def _get_files_dir(self, task_hash: str) -> Path:
+        """Get files directory path for a task hash."""
+        return self._get_cache_dir(task_hash) / "files"
+
+    def save_state(
+        self,
+        task_hash: str,
+        state: CacheState,
+        exec_env_dir: Path | None = None,
+    ) -> None:
+        """Save cache state and optionally archive execution environment files.
+
+        Uses atomic writes to prevent corrupted cache files if interrupted mid-write.
+
+        Args:
+            task_hash: Unique identifier for this task/cache.
+            state: CacheState to persist.
+            exec_env_dir: Optional path to execution environment temp directory.
+                If provided, all files will be copied to cache.
+        """
+        cache_dir = self._get_cache_dir(task_hash)
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Save state JSON using atomic write (write to temp file, then rename)
+        state_file = self._get_state_file(task_hash)
+        temp_file = state_file.with_suffix(".json.tmp")
+
+        try:
+            state_data = state.to_dict()
+            logger.debug("Serialized cache state: turn=%d, msgs=%d", state.turn, len(state.msgs))
+
+            with open(temp_file, "w", encoding="utf-8") as f:
+                json.dump(state_data, f, indent=2, ensure_ascii=False)
+                f.flush()
+                os.fsync(f.fileno())  # Ensure data is written to disk
+
+            logger.debug("Wrote temp file: %s", temp_file)
+
+            # Atomic rename (on POSIX systems)
+            temp_file.replace(state_file)
+            logger.info("Saved cache state to %s (turn %d)", state_file, state.turn)
+        except Exception as e:
+            logger.exception("Failed to save cache state: %s", e)
+            # Try direct write as fallback
+            try:
+                logger.warning("Attempting direct write as fallback")
+                with open(state_file, "w", encoding="utf-8") as f:
+                    json.dump(state_data, f, indent=2, ensure_ascii=False)
+                    f.flush()
+                    os.fsync(f.fileno())
+                logger.info("Fallback write succeeded to %s", state_file)
+            except Exception as e2:
+                logger.exception("Fallback write also failed: %s", e2)
+                # Clean up temp file if it exists
+                if temp_file.exists():
+                    temp_file.unlink()
+                raise
+
+        # Copy execution environment files if provided
+        if exec_env_dir and exec_env_dir.exists():
+            files_dir = self._get_files_dir(task_hash)
+            if files_dir.exists():
+                shutil.rmtree(files_dir)  # Clear existing files
+            shutil.copytree(exec_env_dir, files_dir, dirs_exist_ok=True)
+            logger.info("Saved execution environment files to %s", files_dir)
+
+    def load_state(self, task_hash: str) -> CacheState | None:
+        """Load cached state for a task hash.
+
+        Args:
+            task_hash: Unique identifier for the task/cache.
+
+        Returns:
+            CacheState if cache exists, None otherwise.
+        """
+        state_file = self._get_state_file(task_hash)
+        if not state_file.exists():
+            logger.debug("No cache found for task %s", task_hash)
+            return None
+
+        try:
+            with open(state_file, encoding="utf-8") as f:
+                data = json.load(f)
+            state = CacheState.from_dict(data)
+            logger.info("Loaded cache state from %s (turn %d)", state_file, state.turn)
+            return state
+        except (json.JSONDecodeError, KeyError, ValueError) as e:
+            logger.warning("Failed to load cache for task %s: %s", task_hash, e)
+            return None
+
+    def restore_files(self, task_hash: str, dest_dir: Path) -> bool:
+        """Restore cached files to the destination directory.
+
+        Args:
+            task_hash: Unique identifier for the task/cache.
+            dest_dir: Destination directory (typically the new exec env temp dir).
+
+        Returns:
+            True if files were restored, False if no files cache exists.
+        """
+        files_dir = self._get_files_dir(task_hash)
+        if not files_dir.exists():
+            logger.debug("No cached files for task %s", task_hash)
+            return False
+
+        # Copy all files from cache to destination
+        for item in files_dir.iterdir():
+            dest_item = dest_dir / item.name
+            if item.is_file():
+                shutil.copy2(item, dest_item)
+            else:
+                shutil.copytree(item, dest_item, dirs_exist_ok=True)
+
+        logger.info("Restored cached files from %s to %s", files_dir, dest_dir)
+        return True
+
+    def clear_cache(self, task_hash: str) -> None:
+        """Remove cache for a specific task.
+
+        Called after successful completion to clean up.
+
+        Args:
+            task_hash: Unique identifier for the task/cache.
+        """
+        cache_dir = self._get_cache_dir(task_hash)
+        if cache_dir.exists():
+            shutil.rmtree(cache_dir)
+            logger.info("Cleared cache for task %s", task_hash)
+
+    def list_caches(self) -> list[str]:
+        """List all available cache hashes.
+
+        Returns:
+            List of task hashes with existing caches.
+        """
+        if not self._cache_base_dir.exists():
+            return []
+
+        return [d.name for d in self._cache_base_dir.iterdir() if d.is_dir() and (d / "state.json").exists()]
+
+    def get_cache_info(self, task_hash: str) -> dict | None:
+        """Get metadata about a cache without fully loading it.
+
+        Args:
+            task_hash: Unique identifier for the task/cache.
+
+        Returns:
+            Dictionary with cache info (turn, timestamp, agent_name) or None.
+        """
+        state_file = self._get_state_file(task_hash)
+        if not state_file.exists():
+            return None
+
+        try:
+            with open(state_file, encoding="utf-8") as f:
+                data = json.load(f)
+            return {
+                "task_hash": task_hash,
+                "turn": data.get("turn", 0),
+                "timestamp": data.get("timestamp", ""),
+                "agent_name": data.get("agent_name", ""),
+                "has_files": self._get_files_dir(task_hash).exists(),
+            }
+        except (json.JSONDecodeError, KeyError):
+            return None
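For orientation, a minimal sketch of how the new cache module could be driven. Only the `stirrup.core.cache` names come from the diff above; the task prompt, directories, and placeholder message lists are illustrative assumptions, not stirrup's documented workflow.

```python
from pathlib import Path

from stirrup.core.cache import CacheManager, CacheState, compute_task_hash

task_prompt = "Summarize the attached report"   # hypothetical task
task_hash = compute_task_hash(task_prompt)       # 12-char SHA256 prefix

manager = CacheManager(clear_on_success=True)    # defaults to ~/.cache/stirrup/

# On a non-success exit, persist the run so it can be resumed later.
state = CacheState(
    msgs=[],                # current conversation (ChatMessage objects in practice)
    full_msg_history=[[]],  # message groups, split at context-summarization points
    turn=3,
    run_metadata={},
    task_hash=task_hash,
    agent_name="example-agent",
)
manager.save_state(task_hash, state, exec_env_dir=Path("/tmp/exec-env"))  # hypothetical dir

# On the next run, check for a resumable cache.
if (cached := manager.load_state(task_hash)) is not None:
    manager.restore_files(task_hash, dest_dir=Path("/tmp/new-exec-env"))  # hypothetical dir
    print(f"Resuming from turn {cached.turn} ({cached.timestamp})")
```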
stirrup/core/models.py
CHANGED
@@ -1,3 +1,4 @@
+import base64
 import mimetypes
 import warnings
 from abc import ABC, abstractmethod
@@ -15,7 +16,7 @@ import filetype
 from moviepy import AudioFileClip, VideoFileClip
 from moviepy.video.fx import Resize
 from PIL import Image
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, PlainSerializer, PlainValidator, model_validator
 
 from stirrup.constants import RESOLUTION_1MP, RESOLUTION_480P
 
@@ -27,6 +28,7 @@ __all__ = [
     "ChatMessage",
     "Content",
     "ContentBlock",
+    "EmptyParams",
     "ImageContentBlock",
     "LLMClient",
     "SubAgentMetadata",
@@ -44,6 +46,25 @@ __all__ = [
 ]
 
 
+def _bytes_to_b64(v: bytes) -> str:
+    return base64.b64encode(v).decode("ascii")
+
+
+def _b64_to_bytes(v: bytes | str) -> bytes:
+    if isinstance(v, bytes):
+        return v
+    if isinstance(v, str):
+        return base64.b64decode(v.encode("ascii"))
+    raise TypeError("Invalid bytes value")
+
+
+Base64Bytes = Annotated[
+    bytes,
+    PlainValidator(_b64_to_bytes),
+    PlainSerializer(_bytes_to_b64, when_used="json"),
+]
+
+
 def downscale_image(w: int, h: int, max_pixels: int | None = 1_000_000) -> tuple[int, int]:
     """Downscale image dimensions to fit within max pixel count while maintaining aspect ratio.
 
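A self-contained sketch of what the new `Base64Bytes` annotation does, shown with a stand-in `Blob` model rather than stirrup's own `BinaryContentBlock`; the model name and sample bytes are invented for illustration.

```python
import base64
from typing import Annotated

from pydantic import BaseModel, PlainSerializer, PlainValidator


def _b64_to_bytes(v: bytes | str) -> bytes:
    # Accept raw bytes as-is, or decode a base64 string back into bytes.
    if isinstance(v, bytes):
        return v
    if isinstance(v, str):
        return base64.b64decode(v.encode("ascii"))
    raise TypeError("Invalid bytes value")


Base64Bytes = Annotated[
    bytes,
    PlainValidator(_b64_to_bytes),
    PlainSerializer(lambda v: base64.b64encode(v).decode("ascii"), when_used="json"),
]


class Blob(BaseModel):  # stand-in for BinaryContentBlock
    data: Base64Bytes


blob = Blob(data=b"\x89PNG...")
as_json = blob.model_dump_json()            # bytes come out base64-encoded in JSON
restored = Blob.model_validate_json(as_json)
assert restored.data == blob.data           # round-trips back to raw bytes
```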
@@ -58,7 +79,7 @@ def downscale_image(w: int, h: int, max_pixels: int | None = 1_000_000) -> tuple
 class BinaryContentBlock(BaseModel, ABC):
     """Base class for binary content (images, video, audio) with MIME type validation."""
 
-    data:
+    data: Base64Bytes
     allowed_mime_types: ClassVar[set[str]]
 
     @property
@@ -400,12 +421,14 @@ class ToolUseCountMetadata(BaseModel):
 
     Implements Addable protocol for aggregation. Use this for tools that only need
     to track how many times they were called.
+
+    Subclasses can override __add__ with their own type thanks to Self typing.
     """
 
     num_uses: int = 1
 
-    def __add__(self, other:
-        return
+    def __add__(self, other: Self) -> Self:
+        return self.__class__(num_uses=self.num_uses + other.num_uses)
 
 
 class ToolResult[M](BaseModel):
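To illustrate the `Self`-typed `__add__` noted in the hunk above: a subclass now aggregates into its own type without redefining the operator. The `WebSearchMetadata` subclass below is invented for illustration and assumes `ToolUseCountMetadata` is importable from `stirrup.core.models`.

```python
from stirrup.core.models import ToolUseCountMetadata


class WebSearchMetadata(ToolUseCountMetadata):
    """Hypothetical per-tool metadata that only counts uses."""


total = WebSearchMetadata() + WebSearchMetadata() + WebSearchMetadata()
assert isinstance(total, WebSearchMetadata)  # __add__ returns self.__class__(...)
assert total.num_uses == 3                   # each instance defaults to num_uses=1
```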
@@ -413,17 +436,27 @@ class ToolResult[M](BaseModel):
 
     Generic over metadata type M. M should implement Addable protocol for aggregation support,
     but this is not enforced at the class level due to Pydantic schema generation limitations.
+
+    Attributes:
+        content: The result content (string, list of content blocks, or images)
+        success: Whether the tool call was successful. For finish tools, controls if agent terminates.
+        metadata: Optional metadata (e.g., usage stats) that implements Addable for aggregation
     """
 
     content: Content
+    success: bool = True
     metadata: M | None = None
 
 
+class EmptyParams(BaseModel):
+    """Empty parameter model for tools that don't require parameters."""
+
+
 class Tool[P: BaseModel, M](BaseModel):
     """Tool definition with name, description, parameter schema, and executor function.
 
     Generic over:
-        P: Parameter model type (
+        P: Parameter model type (Pydantic BaseModel subclass, or EmptyParams for parameterless tools)
         M: Metadata type (should implement Addable for aggregation; use None for tools without metadata)
 
     Tools are simple, stateless callables. For tools requiring lifecycle management
@@ -442,9 +475,9 @@ class Tool[P: BaseModel, M](BaseModel):
     )
     ```
 
-    Example without parameters:
+    Example without parameters (uses EmptyParams by default):
     ```python
-    time_tool = Tool[
+    time_tool = Tool[EmptyParams, None](
         name="time",
         description="Get current time",
         executor=lambda _: ToolResult(content=datetime.now().isoformat()),
@@ -454,7 +487,7 @@ class Tool[P: BaseModel, M](BaseModel):
 
     name: str
     description: str
-    parameters: type[P]
+    parameters: type[P] = EmptyParams  # type: ignore[assignment]
     executor: Callable[[P], ToolResult[M] | Awaitable[ToolResult[M]]]
 
 
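A hedged sketch of the new `EmptyParams` default and the `ToolResult.success` flag in use; the dice-rolling tool itself is made up, and only `Tool`, `ToolResult`, and `EmptyParams` come from the diff.

```python
import random

from stirrup.core.models import EmptyParams, Tool, ToolResult

# `parameters` can now be omitted because it defaults to EmptyParams.
roll_die_tool = Tool[EmptyParams, None](
    name="roll_die",
    description="Roll a six-sided die",
    executor=lambda _: ToolResult(content=str(random.randint(1, 6)), success=True),
)
```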
@@ -527,6 +560,7 @@ class ToolCall(BaseModel):
         tool_call_id: Unique identifier for tracking this tool call and its result
     """
 
+    signature: str | None = None
     name: str
     arguments: str
     tool_call_id: str | None = None
@@ -564,13 +598,23 @@ class AssistantMessage(BaseModel):
 
 
 class ToolMessage(BaseModel):
-    """Tool execution result returned to the LLM.
+    """Tool execution result returned to the LLM.
+
+    Attributes:
+        role: Always "tool"
+        content: The tool result content
+        tool_call_id: ID linking this result to the corresponding tool call
+        name: Name of the tool that was called
+        args_was_valid: Whether the tool arguments were valid
+        success: Whether the tool executed successfully (used by finish tool to control termination)
+    """
 
     role: Literal["tool"] = "tool"
     content: Content
     tool_call_id: str | None = None
     name: str | None = None
     args_was_valid: bool = True
+    success: bool = False
 
 
 type ChatMessage = Annotated[SystemMessage | UserMessage | AssistantMessage | ToolMessage, Field(discriminator="role")]
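For reference, an illustrative construction of the extended `ToolMessage`; the field values are made up, and whether `success=True` is set by the finish tool itself or by the agent loop is not shown in this diff.

```python
from stirrup.core.models import ToolMessage

# Hypothetical result for a finish-tool call; success controls whether the agent terminates.
finish_result = ToolMessage(
    content="Task complete, report written to report.md",
    tool_call_id="call_0",   # made-up id
    name="finish",
    success=True,
)
```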
stirrup/prompts/base_system_prompt.txt
CHANGED
@@ -1 +1 @@
-You are an AI agent that will be given a specific task. You are to complete that task using the tools provided in {max_turns} steps. You will need to call the finish tool as your last step, where you will pass your finish reason and paths to any files that you wish to return to the user.
+You are an AI agent that will be given a specific task. You are to complete that task using the tools provided in {max_turns} steps. You will need to call the finish tool as your last step, where you will pass your finish reason and paths to any files that you wish to return to the user.
stirrup/tools/__init__.py
CHANGED
@@ -47,6 +47,7 @@ Optional tool providers require explicit imports from their submodules:
 - DockerCodeExecToolProvider: `from stirrup.tools.code_backends.docker import DockerCodeExecToolProvider`
 - E2BCodeExecToolProvider: `from stirrup.tools.code_backends.e2b import E2BCodeExecToolProvider`
 - MCPToolProvider: `from stirrup.tools.mcp import MCPToolProvider`
+- BrowserUseToolProvider: `from stirrup.tools.browser_use import BrowserUseToolProvider`
 """
 
 from typing import Any
@@ -55,6 +56,7 @@ from stirrup.core.models import Tool, ToolProvider
 from stirrup.tools.calculator import CALCULATOR_TOOL
 from stirrup.tools.code_backends import CodeExecToolProvider, LocalCodeExecToolProvider
 from stirrup.tools.finish import SIMPLE_FINISH_TOOL, FinishParams
+from stirrup.tools.user_input import USER_INPUT_TOOL
 from stirrup.tools.view_image import ViewImageToolProvider
 from stirrup.tools.web import WebToolProvider
 
@@ -69,6 +71,7 @@ __all__ = [
     "CALCULATOR_TOOL",
     "DEFAULT_TOOLS",
     "SIMPLE_FINISH_TOOL",
+    "USER_INPUT_TOOL",
     "CodeExecToolProvider",
     "FinishParams",
     "LocalCodeExecToolProvider",
|