stirrup 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stirrup/core/cache.py ADDED
@@ -0,0 +1,479 @@
1
+ """Cache module for persisting and resuming agent state.
2
+
3
+ Provides functionality to cache agent state (messages, run metadata, execution environment files)
4
+ on non-success exits and restore that state for resumption in new runs.
5
+ """
6
+
7
+ import base64
8
+ import hashlib
9
+ import json
10
+ import logging
11
+ import os
12
+ import shutil
13
+ from dataclasses import dataclass, field
14
+ from datetime import UTC, datetime
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ from pydantic import TypeAdapter
19
+
20
+ from stirrup.core.models import (
21
+ AudioContentBlock,
22
+ ChatMessage,
23
+ ImageContentBlock,
24
+ VideoContentBlock,
25
+ )
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Default cache directory relative to the project root
30
+ DEFAULT_CACHE_DIR = Path("~/.cache/stirrup/").expanduser()
31
+
32
+ # TypeAdapter for deserializing ChatMessage discriminated union
33
+ ChatMessageAdapter: TypeAdapter[ChatMessage] = TypeAdapter(ChatMessage)
34
+
35
+
36
def compute_task_hash(init_msgs: str | list[ChatMessage]) -> str:
    """Derive a short, deterministic cache key from the initial messages.

    Args:
        init_msgs: Raw prompt string, or the list of ChatMessage objects
            that seeds the run.

    Returns:
        The first 12 hex characters of the SHA256 digest — short enough to
        read, stable enough to identify a task.
    """
    if isinstance(init_msgs, str):
        payload = init_msgs
    else:
        serialized = [serialize_message(m) for m in init_msgs]
        # Canonical JSON (sorted keys, ASCII-only) keeps the hash stable
        # regardless of dict ordering or platform encoding quirks.
        payload = json.dumps(serialized, sort_keys=True, ensure_ascii=True)

    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12]
57
+
58
+
59
+ def _serialize_content_block(block: Any) -> dict | str: # noqa: ANN401
60
+ """Serialize a content block, encoding binary data as base64.
61
+
62
+ Args:
63
+ block: A content block (string, ImageContentBlock, VideoContentBlock, AudioContentBlock).
64
+
65
+ Returns:
66
+ JSON-serializable representation with base64-encoded binary data.
67
+ """
68
+ if isinstance(block, str):
69
+ return block
70
+ elif isinstance(block, ImageContentBlock):
71
+ return {
72
+ "kind": "image_content_block",
73
+ "data": base64.b64encode(block.data).decode("ascii"),
74
+ }
75
+ elif isinstance(block, VideoContentBlock):
76
+ return {
77
+ "kind": "video_content_block",
78
+ "data": base64.b64encode(block.data).decode("ascii"),
79
+ }
80
+ elif isinstance(block, AudioContentBlock):
81
+ return {
82
+ "kind": "audio_content_block",
83
+ "data": base64.b64encode(block.data).decode("ascii"),
84
+ }
85
+ elif isinstance(block, dict):
86
+ # Handle dict from model_dump that might contain unencoded bytes
87
+ # This can happen when Pydantic fails to base64-encode bytes in mode="json"
88
+ if "data" in block and isinstance(block["data"], bytes):
89
+ return {
90
+ **block,
91
+ "data": base64.b64encode(block["data"]).decode("ascii"),
92
+ }
93
+ return block
94
+ else:
95
+ raise ValueError(f"Unknown content block type: {type(block)}")
96
+
97
+
98
+ def _deserialize_content_block(data: dict | str) -> Any: # noqa: ANN401
99
+ """Deserialize a content block, decoding base64 binary data.
100
+
101
+ Args:
102
+ data: JSON-serialized content block.
103
+
104
+ Returns:
105
+ Restored content block with decoded binary data.
106
+ """
107
+ if isinstance(data, str):
108
+ return data
109
+ if not isinstance(data, dict):
110
+ return data
111
+
112
+ kind = data.get("kind")
113
+ if kind == "image_content_block":
114
+ return ImageContentBlock(data=base64.b64decode(data["data"]))
115
+ elif kind == "video_content_block":
116
+ return VideoContentBlock(data=base64.b64decode(data["data"]))
117
+ elif kind == "audio_content_block":
118
+ return AudioContentBlock(data=base64.b64decode(data["data"]))
119
+ else:
120
+ # Unknown or already-processed block
121
+ return data
122
+
123
+
124
def serialize_message(msg: ChatMessage) -> dict:
    """Dump a ChatMessage to a JSON-compatible dictionary.

    Binary content blocks (image/video/audio) are base64-encoded so the
    result can be written with json.dump.

    Args:
        msg: Any ChatMessage variant (System/User/Assistant/Tool message).

    Returns:
        A JSON-serializable dictionary.
    """
    dumped = msg.model_dump(mode="json")

    # model_dump covers the scalar fields; content may still carry binary
    # blocks that need explicit base64 handling.
    content = dumped.get("content")
    if isinstance(content, list):
        dumped["content"] = [_serialize_content_block(b) for b in content]
    elif not (content is None or isinstance(content, str)):
        dumped["content"] = _serialize_content_block(content)

    return dumped
146
+
147
+
148
def deserialize_message(data: dict) -> ChatMessage:
    """Rebuild a ChatMessage from its JSON dictionary form.

    Base64-encoded binary content blocks are decoded back to objects before
    Pydantic validation. Note: mutates ``data["content"]`` in place.

    Args:
        data: Dictionary produced by :func:`serialize_message`.

    Returns:
        The validated ChatMessage instance.
    """
    content = data.get("content")
    if isinstance(content, list):
        data["content"] = [_deserialize_content_block(b) for b in content]
    elif not (content is None or isinstance(content, str)):
        data["content"] = _deserialize_content_block(content)

    # The role-discriminated union is resolved by the module-level TypeAdapter.
    return ChatMessageAdapter.validate_python(data)
168
+
169
+
170
def serialize_messages(msgs: list[ChatMessage]) -> list[dict]:
    """Serialize every ChatMessage in *msgs* to a JSON-compatible dict.

    Args:
        msgs: List of ChatMessage objects.

    Returns:
        List of JSON-serializable dictionaries, in the same order.
    """
    return list(map(serialize_message, msgs))
180
+
181
+
182
def _serialize_metadata_item(item: Any) -> Any:  # noqa: ANN401
    """Recursively convert one metadata value to a JSON-compatible form.

    Pydantic models are dumped via model_dump(mode="json"); raw bytes are
    base64-encoded; dicts and lists are converted element-wise; everything
    else passes through unchanged.
    """
    # NOTE(review): local import mirrors the original; pydantic is already
    # imported at module top, so this could likely move there — confirm.
    from pydantic import BaseModel

    if isinstance(item, BaseModel):
        return item.model_dump(mode="json")
    if isinstance(item, bytes):
        # bytes are not JSON-serializable; store as ASCII base64 text.
        return base64.b64encode(item).decode("ascii")
    if isinstance(item, dict):
        return {key: _serialize_metadata_item(value) for key, value in item.items()}
    if isinstance(item, list):
        return [_serialize_metadata_item(element) for element in item]
    return item
201
+
202
+
203
def _serialize_run_metadata(run_metadata: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """Convert accumulated run metadata into a JSON-compatible mapping.

    Args:
        run_metadata: Tool name -> list of per-call metadata entries, which
            may include Pydantic models or raw bytes.

    Returns:
        A mapping of the same shape whose entries are all JSON-serializable.
    """
    serialized: dict[str, list[Any]] = {}
    for tool_name, entries in run_metadata.items():
        serialized[tool_name] = [_serialize_metadata_item(entry) for entry in entries]
    return serialized
216
+
217
+
218
def deserialize_messages(data: list[dict]) -> list[ChatMessage]:
    """Deserialize every dict in *data* back into a ChatMessage.

    Args:
        data: List of JSON dictionaries representing ChatMessages.

    Returns:
        List of restored ChatMessage objects, in the same order.
    """
    return list(map(deserialize_message, data))
228
+
229
+
230
@dataclass
class CacheState:
    """Serializable snapshot of an agent run, sufficient to resume it.

    Captures the conversation, turn counter, and accumulated tool metadata
    so a new process can pick up execution at the recorded turn.
    """

    # Messages in the active run loop.
    msgs: list[ChatMessage]
    # Message groups, split whenever context summarization occurred.
    full_msg_history: list[list[ChatMessage]]
    # 0-indexed turn number; resume starts from this turn.
    turn: int
    # Tool name -> accumulated metadata entries from the run.
    run_metadata: dict[str, list[Any]]
    # Hash of the original init_msgs, checked on resume.
    task_hash: str
    # ISO timestamp recorded when the cache was created.
    timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat())
    # Name of the agent that produced this cache.
    agent_name: str = ""

    def to_dict(self) -> dict:
        """Return a JSON-serializable dictionary for this state."""
        history = [serialize_messages(group) for group in self.full_msg_history]
        return {
            "msgs": serialize_messages(self.msgs),
            "full_msg_history": history,
            "turn": self.turn,
            "run_metadata": _serialize_run_metadata(self.run_metadata),
            "task_hash": self.task_hash,
            "timestamp": self.timestamp,
            "agent_name": self.agent_name,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "CacheState":
        """Rebuild a CacheState from :meth:`to_dict` output.

        NOTE(review): run_metadata is restored as plain JSON values — the
        Pydantic models / bytes that to_dict serialized are NOT reconstructed
        here. Confirm downstream consumers tolerate that asymmetry.
        """
        history = [deserialize_messages(group) for group in data["full_msg_history"]]
        return cls(
            msgs=deserialize_messages(data["msgs"]),
            full_msg_history=history,
            turn=data["turn"],
            run_metadata=data["run_metadata"],
            task_hash=data["task_hash"],
            timestamp=data.get("timestamp", ""),
            agent_name=data.get("agent_name", ""),
        )
282
+
283
+
284
+ class CacheManager:
285
+ """Manages cache operations for agent sessions.
286
+
287
+ Handles saving/loading cache state and execution environment files.
288
+ """
289
+
290
+ def __init__(
291
+ self,
292
+ cache_base_dir: Path | None = None,
293
+ clear_on_success: bool = True,
294
+ ) -> None:
295
+ """Initialize CacheManager.
296
+
297
+ Args:
298
+ cache_base_dir: Base directory for cache storage.
299
+ Defaults to ~/.cache/stirrup/
300
+ clear_on_success: If True (default), automatically clear the cache when
301
+ the agent completes successfully. Set to False to preserve
302
+ caches for inspection or manual management.
303
+ """
304
+ self._cache_base_dir = cache_base_dir or DEFAULT_CACHE_DIR
305
+ self.clear_on_success = clear_on_success
306
+
307
+ def _get_cache_dir(self, task_hash: str) -> Path:
308
+ """Get cache directory path for a task hash."""
309
+ return self._cache_base_dir / task_hash
310
+
311
+ def _get_state_file(self, task_hash: str) -> Path:
312
+ """Get state.json file path for a task hash."""
313
+ return self._get_cache_dir(task_hash) / "state.json"
314
+
315
+ def _get_files_dir(self, task_hash: str) -> Path:
316
+ """Get files directory path for a task hash."""
317
+ return self._get_cache_dir(task_hash) / "files"
318
+
319
+ def save_state(
320
+ self,
321
+ task_hash: str,
322
+ state: CacheState,
323
+ exec_env_dir: Path | None = None,
324
+ ) -> None:
325
+ """Save cache state and optionally archive execution environment files.
326
+
327
+ Uses atomic writes to prevent corrupted cache files if interrupted mid-write.
328
+
329
+ Args:
330
+ task_hash: Unique identifier for this task/cache.
331
+ state: CacheState to persist.
332
+ exec_env_dir: Optional path to execution environment temp directory.
333
+ If provided, all files will be copied to cache.
334
+ """
335
+ cache_dir = self._get_cache_dir(task_hash)
336
+ cache_dir.mkdir(parents=True, exist_ok=True)
337
+
338
+ # Save state JSON using atomic write (write to temp file, then rename)
339
+ state_file = self._get_state_file(task_hash)
340
+ temp_file = state_file.with_suffix(".json.tmp")
341
+
342
+ try:
343
+ state_data = state.to_dict()
344
+ logger.debug("Serialized cache state: turn=%d, msgs=%d", state.turn, len(state.msgs))
345
+
346
+ with open(temp_file, "w", encoding="utf-8") as f:
347
+ json.dump(state_data, f, indent=2, ensure_ascii=False)
348
+ f.flush()
349
+ os.fsync(f.fileno()) # Ensure data is written to disk
350
+
351
+ logger.debug("Wrote temp file: %s", temp_file)
352
+
353
+ # Atomic rename (on POSIX systems)
354
+ temp_file.replace(state_file)
355
+ logger.info("Saved cache state to %s (turn %d)", state_file, state.turn)
356
+ except Exception as e:
357
+ logger.exception("Failed to save cache state: %s", e)
358
+ # Try direct write as fallback
359
+ try:
360
+ logger.warning("Attempting direct write as fallback")
361
+ with open(state_file, "w", encoding="utf-8") as f:
362
+ json.dump(state_data, f, indent=2, ensure_ascii=False)
363
+ f.flush()
364
+ os.fsync(f.fileno())
365
+ logger.info("Fallback write succeeded to %s", state_file)
366
+ except Exception as e2:
367
+ logger.exception("Fallback write also failed: %s", e2)
368
+ # Clean up temp file if it exists
369
+ if temp_file.exists():
370
+ temp_file.unlink()
371
+ raise
372
+
373
+ # Copy execution environment files if provided
374
+ if exec_env_dir and exec_env_dir.exists():
375
+ files_dir = self._get_files_dir(task_hash)
376
+ if files_dir.exists():
377
+ shutil.rmtree(files_dir) # Clear existing files
378
+ shutil.copytree(exec_env_dir, files_dir, dirs_exist_ok=True)
379
+ logger.info("Saved execution environment files to %s", files_dir)
380
+
381
+ def load_state(self, task_hash: str) -> CacheState | None:
382
+ """Load cached state for a task hash.
383
+
384
+ Args:
385
+ task_hash: Unique identifier for the task/cache.
386
+
387
+ Returns:
388
+ CacheState if cache exists, None otherwise.
389
+ """
390
+ state_file = self._get_state_file(task_hash)
391
+ if not state_file.exists():
392
+ logger.debug("No cache found for task %s", task_hash)
393
+ return None
394
+
395
+ try:
396
+ with open(state_file, encoding="utf-8") as f:
397
+ data = json.load(f)
398
+ state = CacheState.from_dict(data)
399
+ logger.info("Loaded cache state from %s (turn %d)", state_file, state.turn)
400
+ return state
401
+ except (json.JSONDecodeError, KeyError, ValueError) as e:
402
+ logger.warning("Failed to load cache for task %s: %s", task_hash, e)
403
+ return None
404
+
405
+ def restore_files(self, task_hash: str, dest_dir: Path) -> bool:
406
+ """Restore cached files to the destination directory.
407
+
408
+ Args:
409
+ task_hash: Unique identifier for the task/cache.
410
+ dest_dir: Destination directory (typically the new exec env temp dir).
411
+
412
+ Returns:
413
+ True if files were restored, False if no files cache exists.
414
+ """
415
+ files_dir = self._get_files_dir(task_hash)
416
+ if not files_dir.exists():
417
+ logger.debug("No cached files for task %s", task_hash)
418
+ return False
419
+
420
+ # Copy all files from cache to destination
421
+ for item in files_dir.iterdir():
422
+ dest_item = dest_dir / item.name
423
+ if item.is_file():
424
+ shutil.copy2(item, dest_item)
425
+ else:
426
+ shutil.copytree(item, dest_item, dirs_exist_ok=True)
427
+
428
+ logger.info("Restored cached files from %s to %s", files_dir, dest_dir)
429
+ return True
430
+
431
+ def clear_cache(self, task_hash: str) -> None:
432
+ """Remove cache for a specific task.
433
+
434
+ Called after successful completion to clean up.
435
+
436
+ Args:
437
+ task_hash: Unique identifier for the task/cache.
438
+ """
439
+ cache_dir = self._get_cache_dir(task_hash)
440
+ if cache_dir.exists():
441
+ shutil.rmtree(cache_dir)
442
+ logger.info("Cleared cache for task %s", task_hash)
443
+
444
+ def list_caches(self) -> list[str]:
445
+ """List all available cache hashes.
446
+
447
+ Returns:
448
+ List of task hashes with existing caches.
449
+ """
450
+ if not self._cache_base_dir.exists():
451
+ return []
452
+
453
+ return [d.name for d in self._cache_base_dir.iterdir() if d.is_dir() and (d / "state.json").exists()]
454
+
455
+ def get_cache_info(self, task_hash: str) -> dict | None:
456
+ """Get metadata about a cache without fully loading it.
457
+
458
+ Args:
459
+ task_hash: Unique identifier for the task/cache.
460
+
461
+ Returns:
462
+ Dictionary with cache info (turn, timestamp, agent_name) or None.
463
+ """
464
+ state_file = self._get_state_file(task_hash)
465
+ if not state_file.exists():
466
+ return None
467
+
468
+ try:
469
+ with open(state_file, encoding="utf-8") as f:
470
+ data = json.load(f)
471
+ return {
472
+ "task_hash": task_hash,
473
+ "turn": data.get("turn", 0),
474
+ "timestamp": data.get("timestamp", ""),
475
+ "agent_name": data.get("agent_name", ""),
476
+ "has_files": self._get_files_dir(task_hash).exists(),
477
+ }
478
+ except (json.JSONDecodeError, KeyError):
479
+ return None
stirrup/core/models.py CHANGED
@@ -1,3 +1,4 @@
1
+ import base64
1
2
  import mimetypes
2
3
  import warnings
3
4
  from abc import ABC, abstractmethod
@@ -15,7 +16,7 @@ import filetype
15
16
  from moviepy import AudioFileClip, VideoFileClip
16
17
  from moviepy.video.fx import Resize
17
18
  from PIL import Image
18
- from pydantic import BaseModel, Field, model_validator
19
+ from pydantic import BaseModel, Field, PlainSerializer, PlainValidator, model_validator
19
20
 
20
21
  from stirrup.constants import RESOLUTION_1MP, RESOLUTION_480P
21
22
 
@@ -27,6 +28,7 @@ __all__ = [
27
28
  "ChatMessage",
28
29
  "Content",
29
30
  "ContentBlock",
31
+ "EmptyParams",
30
32
  "ImageContentBlock",
31
33
  "LLMClient",
32
34
  "SubAgentMetadata",
@@ -44,6 +46,25 @@ __all__ = [
44
46
  ]
45
47
 
46
48
 
49
+ def _bytes_to_b64(v: bytes) -> str:
50
+ return base64.b64encode(v).decode("ascii")
51
+
52
+
53
+ def _b64_to_bytes(v: bytes | str) -> bytes:
54
+ if isinstance(v, bytes):
55
+ return v
56
+ if isinstance(v, str):
57
+ return base64.b64decode(v.encode("ascii"))
58
+ raise TypeError("Invalid bytes value")
59
+
60
+
61
# Bytes field type that accepts either raw bytes or a base64 string on
# validation, and serializes to base64 text when dumping in JSON mode
# (plain bytes are kept in Python mode, per when_used="json").
Base64Bytes = Annotated[
    bytes,
    PlainValidator(_b64_to_bytes),
    PlainSerializer(_bytes_to_b64, when_used="json"),
]
66
+
67
+
47
68
  def downscale_image(w: int, h: int, max_pixels: int | None = 1_000_000) -> tuple[int, int]:
48
69
  """Downscale image dimensions to fit within max pixel count while maintaining aspect ratio.
49
70
 
@@ -58,7 +79,7 @@ def downscale_image(w: int, h: int, max_pixels: int | None = 1_000_000) -> tuple
58
79
  class BinaryContentBlock(BaseModel, ABC):
59
80
  """Base class for binary content (images, video, audio) with MIME type validation."""
60
81
 
61
- data: bytes
82
+ data: Base64Bytes
62
83
  allowed_mime_types: ClassVar[set[str]]
63
84
 
64
85
  @property
@@ -413,17 +434,27 @@ class ToolResult[M](BaseModel):
413
434
 
414
435
  Generic over metadata type M. M should implement Addable protocol for aggregation support,
415
436
  but this is not enforced at the class level due to Pydantic schema generation limitations.
437
+
438
+ Attributes:
439
+ content: The result content (string, list of content blocks, or images)
440
+ success: Whether the tool call was successful. For finish tools, controls if agent terminates.
441
+ metadata: Optional metadata (e.g., usage stats) that implements Addable for aggregation
416
442
  """
417
443
 
418
444
  content: Content
445
+ success: bool = True
419
446
  metadata: M | None = None
420
447
 
421
448
 
449
class EmptyParams(BaseModel):
    """Empty parameter model for tools that don't require parameters.

    Used as the default value of ``Tool.parameters`` so parameterless tools
    need no explicit parameter schema.
    """
451
+
452
+
422
453
  class Tool[P: BaseModel, M](BaseModel):
423
454
  """Tool definition with name, description, parameter schema, and executor function.
424
455
 
425
456
  Generic over:
426
- P: Parameter model type (must be a Pydantic BaseModel, or None for parameterless tools)
457
+ P: Parameter model type (Pydantic BaseModel subclass, or EmptyParams for parameterless tools)
427
458
  M: Metadata type (should implement Addable for aggregation; use None for tools without metadata)
428
459
 
429
460
  Tools are simple, stateless callables. For tools requiring lifecycle management
@@ -442,9 +473,9 @@ class Tool[P: BaseModel, M](BaseModel):
442
473
  )
443
474
  ```
444
475
 
445
- Example without parameters:
476
+ Example without parameters (uses EmptyParams by default):
446
477
  ```python
447
- time_tool = Tool[None, None](
478
+ time_tool = Tool[EmptyParams, None](
448
479
  name="time",
449
480
  description="Get current time",
450
481
  executor=lambda _: ToolResult(content=datetime.now().isoformat()),
@@ -454,7 +485,7 @@ class Tool[P: BaseModel, M](BaseModel):
454
485
 
455
486
  name: str
456
487
  description: str
457
- parameters: type[P] | None = None
488
+ parameters: type[P] = EmptyParams # type: ignore[assignment]
458
489
  executor: Callable[[P], ToolResult[M] | Awaitable[ToolResult[M]]]
459
490
 
460
491
 
@@ -527,6 +558,7 @@ class ToolCall(BaseModel):
527
558
  tool_call_id: Unique identifier for tracking this tool call and its result
528
559
  """
529
560
 
561
+ signature: str | None = None
530
562
  name: str
531
563
  arguments: str
532
564
  tool_call_id: str | None = None
@@ -564,13 +596,23 @@ class AssistantMessage(BaseModel):
564
596
 
565
597
 
566
598
class ToolMessage(BaseModel):
    """Tool execution result returned to the LLM.

    Attributes:
        role: Always "tool"
        content: The tool result content
        tool_call_id: ID linking this result to the corresponding tool call
        name: Name of the tool that was called
        args_was_valid: Whether the tool arguments were valid
        success: Whether the tool executed successfully (used by finish tool to control termination)

    Note:
        ``success`` defaults to False here while ``ToolResult.success``
        defaults to True — presumably deliberate (a ToolMessage must be
        marked successful explicitly when the result is propagated), but
        confirm against the code that builds ToolMessages.
    """

    role: Literal["tool"] = "tool"
    content: Content
    tool_call_id: str | None = None
    name: str | None = None
    args_was_valid: bool = True
    success: bool = False
574
616
 
575
617
 
576
618
  type ChatMessage = Annotated[SystemMessage | UserMessage | AssistantMessage | ToolMessage, Field(discriminator="role")]
@@ -1 +1 @@
1
- You are an AI agent that will be given a specific task. You are to complete that task using the tools provided in {max_turns} steps. You will need to call the finish tool as your last step, where you will pass your finish reason and paths to any files that you wish to return to the user. You are not able to interact with the user during the task.
1
+ You are an AI agent that will be given a specific task. You are to complete that task using the tools provided in {max_turns} steps. You will need to call the finish tool as your last step, where you will pass your finish reason and paths to any files that you wish to return to the user.
stirrup/tools/__init__.py CHANGED
@@ -55,6 +55,7 @@ from stirrup.core.models import Tool, ToolProvider
55
55
  from stirrup.tools.calculator import CALCULATOR_TOOL
56
56
  from stirrup.tools.code_backends import CodeExecToolProvider, LocalCodeExecToolProvider
57
57
  from stirrup.tools.finish import SIMPLE_FINISH_TOOL, FinishParams
58
+ from stirrup.tools.user_input import USER_INPUT_TOOL
58
59
  from stirrup.tools.view_image import ViewImageToolProvider
59
60
  from stirrup.tools.web import WebToolProvider
60
61
 
@@ -69,6 +70,7 @@ __all__ = [
69
70
  "CALCULATOR_TOOL",
70
71
  "DEFAULT_TOOLS",
71
72
  "SIMPLE_FINISH_TOOL",
73
+ "USER_INPUT_TOOL",
72
74
  "CodeExecToolProvider",
73
75
  "FinishParams",
74
76
  "LocalCodeExecToolProvider",
@@ -21,7 +21,7 @@ def calculator_executor(params: CalculatorParams) -> ToolResult[ToolUseCountMeta
21
21
  result = eval(params.expression, {"__builtins__": {}}, {})
22
22
  return ToolResult(content=f"Result: {result}", metadata=ToolUseCountMetadata())
23
23
  except Exception as e:
24
- return ToolResult(content=f"Error evaluating expression: {e!s}", metadata=ToolUseCountMetadata())
24
+ return ToolResult(content=f"Error evaluating expression: {e!s}", success=False, metadata=ToolUseCountMetadata())
25
25
 
26
26
 
27
27
  CALCULATOR_TOOL: Tool[CalculatorParams, ToolUseCountMetadata] = Tool[CalculatorParams, ToolUseCountMetadata](
@@ -160,6 +160,11 @@ class CodeExecToolProvider(ToolProvider, ABC):
160
160
  if allowed_commands is not None:
161
161
  self._compiled_allowed = [re.compile(p) for p in allowed_commands]
162
162
 
163
@property
def temp_dir(self) -> Path | None:
    """Return the temporary directory for this execution environment, if any.

    The base implementation always returns None; presumably backends that
    stage files on disk override this — confirm against subclasses.
    """
    return None
167
+
163
168
  def _check_allowed(self, cmd: str) -> bool:
164
169
  """Check if command is allowed based on the allowlist.
165
170
 
@@ -419,11 +424,13 @@ class CodeExecToolProvider(ToolProvider, ABC):
419
424
  except FileNotFoundError:
420
425
  return ToolResult(
421
426
  content=f"Image `{params.path}` not found.",
427
+ success=False,
422
428
  metadata=ToolUseCountMetadata(),
423
429
  )
424
430
  except ValueError as e:
425
431
  return ToolResult(
426
432
  content=str(e),
433
+ success=False,
427
434
  metadata=ToolUseCountMetadata(),
428
435
  )
429
436