stirrup 0.1.2-py3-none-any.whl → 0.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stirrup/__init__.py +2 -0
- stirrup/clients/chat_completions_client.py +0 -3
- stirrup/clients/litellm_client.py +20 -11
- stirrup/clients/utils.py +6 -1
- stirrup/constants.py +6 -2
- stirrup/core/agent.py +180 -57
- stirrup/core/cache.py +479 -0
- stirrup/core/models.py +49 -7
- stirrup/prompts/base_system_prompt.txt +1 -1
- stirrup/tools/__init__.py +2 -0
- stirrup/tools/calculator.py +1 -1
- stirrup/tools/code_backends/base.py +7 -0
- stirrup/tools/code_backends/e2b.py +25 -11
- stirrup/tools/code_backends/local.py +2 -2
- stirrup/tools/finish.py +1 -1
- stirrup/tools/user_input.py +130 -0
- stirrup/tools/web.py +1 -0
- stirrup/utils/logging.py +24 -0
- {stirrup-0.1.2.dist-info → stirrup-0.1.3.dist-info}/METADATA +1 -1
- stirrup-0.1.3.dist-info/RECORD +36 -0
- {stirrup-0.1.2.dist-info → stirrup-0.1.3.dist-info}/WHEEL +1 -1
- stirrup-0.1.2.dist-info/RECORD +0 -34
stirrup/__init__.py
CHANGED

```diff
@@ -35,6 +35,7 @@ from stirrup.core.models import (
     AssistantMessage,
     AudioContentBlock,
     ChatMessage,
+    EmptyParams,
     ImageContentBlock,
     LLMClient,
     SubAgentMetadata,
@@ -58,6 +59,7 @@ __all__ = [
     "AudioContentBlock",
     "ChatMessage",
     "ContextOverflowError",
+    "EmptyParams",
     "ImageContentBlock",
     "LLMClient",
     "SubAgentMetadata",
```
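The new `EmptyParams` export pairs with the `to_openai_tools()` change in stirrup/clients/utils.py below: it is treated as the sentinel parameters model for tools that take no arguments. A minimal sketch of how a downstream tool might use it; the `executor` field and the top-level `Tool` import are assumptions, since this diff does not show `Tool`'s full signature:

```python
# Hedged sketch: to_openai_tools() checks "if t.parameters is not EmptyParams",
# so a no-argument tool points parameters at the class itself and no
# "parameters" schema is emitted in the OpenAI tool payload.
from stirrup import EmptyParams, Tool  # Tool export assumed, not shown in this diff

async def ping(_params: EmptyParams) -> str:
    return "pong"

ping_tool = Tool(
    name="ping",
    description="Health check that takes no arguments.",
    parameters=EmptyParams,  # sentinel: no schema attached to the payload
    executor=ping,           # hypothetical field name
)
```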
stirrup/clients/chat_completions_client.py
CHANGED

```diff
@@ -67,7 +67,6 @@ class ChatCompletionsClient(LLMClient):
         *,
         base_url: str | None = None,
         api_key: str | None = None,
-        supports_audio_input: bool = False,
         reasoning_effort: str | None = None,
         timeout: float | None = None,
         max_retries: int = 2,
@@ -82,7 +81,6 @@ class ChatCompletionsClient(LLMClient):
                 Use for OpenAI-compatible providers (e.g., 'http://localhost:8000/v1').
             api_key: API key for authentication. If None, reads from OPENROUTER_API_KEY
                 environment variable.
-            supports_audio_input: Whether the model supports audio inputs. Defaults to False.
             reasoning_effort: Reasoning effort level for extended thinking models
                 (e.g., 'low', 'medium', 'high'). Only used with o1/o3 style models.
             timeout: Request timeout in seconds. If None, uses OpenAI SDK default.
@@ -92,7 +90,6 @@ class ChatCompletionsClient(LLMClient):
         """
         self._model = model
         self._max_tokens = max_tokens
-        self._supports_audio_input = supports_audio_input
         self._reasoning_effort = reasoning_effort
         self._kwargs = kwargs or {}
```
stirrup/clients/litellm_client.py
CHANGED

```diff
@@ -7,7 +7,7 @@ Requires the litellm extra: `pip install stirrup[litellm]`
 """
 
 import logging
-from typing import Any
+from typing import Any, Literal
 
 try:
     from litellm import acompletion
@@ -38,6 +38,8 @@ __all__ = [
 
 LOGGER = logging.getLogger(__name__)
 
+type ReasoningEffort = Literal["none", "minimal", "low", "medium", "high", "xhigh", "default"]
+
 
 class LiteLLMClient(LLMClient):
     """LiteLLM-based client supporting multiple LLM providers with unified interface.
@@ -49,8 +51,8 @@ class LiteLLMClient(LLMClient):
         self,
         model_slug: str,
         max_tokens: int,
-        supports_audio_input: bool = False,
-        reasoning_effort: str | None = None,
+        api_key: str | None = None,
+        reasoning_effort: ReasoningEffort | None = None,
         kwargs: dict[str, Any] | None = None,
     ) -> None:
         """Initialize LiteLLM client with model configuration and capabilities.
@@ -58,15 +60,13 @@ class LiteLLMClient(LLMClient):
         Args:
             model_slug: Model identifier for LiteLLM (e.g., 'anthropic/claude-3-5-sonnet-20241022')
             max_tokens: Maximum context window size in tokens
-            supports_audio_input: Whether the model supports audio inputs
             reasoning_effort: Reasoning effort level for extended thinking models (e.g., 'medium', 'high')
             kwargs: Additional arguments to pass to LiteLLM completion calls
         """
         self._model_slug = model_slug
-        self._supports_video_input = False
-        self._supports_audio_input = supports_audio_input
         self._max_tokens = max_tokens
-        self._reasoning_effort = reasoning_effort
+        self._reasoning_effort: ReasoningEffort | None = reasoning_effort
+        self._api_key = api_key
         self._kwargs = kwargs or {}
 
     @property
@@ -92,6 +92,8 @@ class LiteLLMClient(LLMClient):
             tools=to_openai_tools(tools) if tools else None,
             tool_choice="auto" if tools else None,
             max_tokens=self._max_tokens,
+            reasoning_effort=self._reasoning_effort,
+            api_key=self._api_key,
             **self._kwargs,
         )
 
@@ -103,14 +105,20 @@
         )
 
         msg = choice["message"]
-
         reasoning: Reasoning | None = None
         if getattr(msg, "reasoning_content", None) is not None:
             reasoning = Reasoning(content=msg.reasoning_content)
         if getattr(msg, "thinking_blocks", None) is not None and len(msg.thinking_blocks) > 0:
-            [three deleted lines not captured in the diff source]
+            if len(msg.thinking_blocks) > 1:
+                raise ValueError("Found multiple thinking blocks in the response")
+
+            signature = msg.thinking_blocks[0].get("thinking_signature", None)
+            content = msg.thinking_blocks[0].get("thinking", None)
+
+            if signature is None and content is None:
+                raise ValueError("Signature and content not found in the thinking block response")
+
+            reasoning = Reasoning(signature=signature, content=content)
 
         usage = r["usage"]
 
@@ -119,6 +127,7 @@
                 tool_call_id=tc.get("id"),
                 name=tc["function"]["name"],
                 arguments=tc["function"].get("arguments", "") or "",
+                signature=tc.get("provider_specific_fields", {}).get("thought_signature", None),
             )
             for tc in (msg.get("tool_calls") or [])
         ]
```
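Taken together, the two new constructor arguments flow straight into the `acompletion` call. A minimal usage sketch based only on the signature shown above:

```python
# Sketch of the new LiteLLMClient surface: reasoning_effort is now checked
# against the ReasoningEffort literal at type-check time, and api_key is
# forwarded to litellm.acompletion instead of relying only on provider
# environment variables.
from stirrup.clients.litellm_client import LiteLLMClient

client = LiteLLMClient(
    model_slug="anthropic/claude-3-5-sonnet-20241022",
    max_tokens=200_000,
    api_key=None,             # None falls back to the provider's env var
    reasoning_effort="high",  # Literal[...]; a typo like "hgih" is now a type error
)
```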
stirrup/clients/utils.py
CHANGED

```diff
@@ -12,6 +12,7 @@ from stirrup.core.models import (
     AudioContentBlock,
     ChatMessage,
     Content,
+    EmptyParams,
     ImageContentBlock,
     SystemMessage,
     Tool,
@@ -47,7 +48,7 @@ def to_openai_tools(tools: dict[str, Tool]) -> list[dict[str, Any]]:
             "name": t.name,
             "description": t.description,
         }
-        if t.parameters is not None:
+        if t.parameters is not EmptyParams:
             function["parameters"] = t.parameters.model_json_schema()
         tool_payload: dict[str, Any] = {
             "type": "function",
@@ -139,6 +140,10 @@ def to_openai_messages(msgs: list[ChatMessage]) -> list[dict[str, Any]]:
             tool_dict = tool.model_dump()
             tool_dict["id"] = tool.tool_call_id
             tool_dict["type"] = "function"
+            if tool.signature is not None:
+                tool_dict["provider_specific_fields"] = {
+                    "thought_signature": tool.signature,
+                }
             tool_dict["function"] = {
                 "name": tool.name,
                 "arguments": tool.arguments,
```
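These additions close the loop with the litellm_client.py change above: a `thought_signature` read off an incoming tool call is echoed back when the message history is serialized. Illustrative payload only; the field names come from the diff, the values are made up:

```python
# What the new to_openai_messages() branch produces for a tool call that
# carried a signature (id/signature values are fabricated for illustration).
tool_dict = {
    "id": "call_123",
    "type": "function",
    "provider_specific_fields": {"thought_signature": "sig_abc"},  # added in 0.1.3
    "function": {"name": "calculator", "arguments": '{"expression": "1+1"}'},
}
```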
stirrup/constants.py
CHANGED

```diff
@@ -1,14 +1,18 @@
+from typing import Literal
+
 # Tool naming
-FINISH_TOOL_NAME = "finish"
+FINISH_TOOL_NAME: Literal["finish"] = "finish"
 
 # Agent execution limits
 AGENT_MAX_TURNS = 30  # Maximum agent turns before forced termination
 CONTEXT_SUMMARIZATION_CUTOFF = 0.7  # Context window usage threshold (0.0-1.0) that triggers message summarization
+TURNS_REMAINING_WARNING_THRESHOLD = 20
 
 # Media resolution limits
 RESOLUTION_1MP = 1_000_000  # 1 megapixel - default max resolution for images
 RESOLUTION_480P = 640 * 480  # 480p video resolution
 
 # Code execution
-[deleted line not captured in the diff source]
+SANDBOX_TIMEOUT = 60 * 10  # 10 minutes
+SANDBOX_REQUEST_TIMEOUT = 60 * 3  # 3 minutes
 E2B_SANDBOX_TEMPLATE_ALIAS = "e2b-sandbox"
```
stirrup/core/agent.py
CHANGED

```diff
@@ -2,9 +2,9 @@
 import contextvars
 import glob as glob_module
 import inspect
-import json
 import logging
 import re
+import signal
 from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
 from itertools import chain, takewhile
@@ -19,7 +19,9 @@ from stirrup.constants import (
     AGENT_MAX_TURNS,
     CONTEXT_SUMMARIZATION_CUTOFF,
     FINISH_TOOL_NAME,
+    TURNS_REMAINING_WARNING_THRESHOLD,
 )
+from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -72,6 +74,7 @@ class SessionState:
     depth: int = 0
     uploaded_file_paths: list[str] = field(default_factory=list)  # Paths of files uploaded to exec_env
     skills_metadata: list[SkillMetadata] = field(default_factory=list)  # Loaded skills metadata
+    logger: AgentLoggerBase | None = None  # Logger for pause/resume during user input
 
 
 _SESSION_STATE: contextvars.ContextVar[SessionState] = contextvars.ContextVar("session_state")
@@ -112,17 +115,19 @@ def _handle_text_only_tool_responses(tool_messages: list[ToolMessage]) -> tuple[
     return tool_messages, user_messages
 
 
-def _get_total_token_usage(messages: list[list[ChatMessage]]) -> TokenUsage:
-    """
+def _get_total_token_usage(messages: list[list[ChatMessage]]) -> list[TokenUsage]:
+    """
+    Returns a list of TokenUsage objects aggregated from all AssistantMessage
+    instances across the provided grouped message history.
 
     Args:
-        messages:
+        messages: A list where each item is a list of ChatMessage objects representing a segment
+            or turn group of the conversation history.
 
+    Returns:
+        List of TokenUsage corresponding to each AssistantMessage in the flattened conversation history.
     """
-    return sum(
-        [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)],
-        start=TokenUsage(),
-    )
+    return [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)]
@@ -176,6 +181,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         finish_tool: Tool[FinishParams, FinishMeta] | None = None,
         # Agent options
         context_summarization_cutoff: float = CONTEXT_SUMMARIZATION_CUTOFF,
+        turns_remaining_warning_threshold: int = TURNS_REMAINING_WARNING_THRESHOLD,
         run_sync_in_thread: bool = True,
         text_only_tool_responses: bool = True,
         # Logging
@@ -215,6 +221,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._tools = tools if tools is not None else DEFAULT_TOOLS
         self._finish_tool: Tool = finish_tool if finish_tool is not None else SIMPLE_FINISH_TOOL
         self._context_summarization_cutoff = context_summarization_cutoff
+        self._turns_remaining_warning_threshold = turns_remaining_warning_threshold
         self._run_sync_in_thread = run_sync_in_thread
         self._text_only_tool_responses = text_only_tool_responses
 
@@ -225,6 +232,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._pending_output_dir: Path | None = None
         self._pending_input_files: str | Path | list[str | Path] | None = None
         self._pending_skills_dir: Path | None = None
+        self._resume: bool = False
+        self._clear_cache_on_success: bool = True
 
         # Instance-scoped state (populated during __aenter__, isolated per agent instance)
         self._active_tools: dict[str, Tool] = {}
@@ -232,6 +241,10 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._last_run_metadata: dict[str, list[Any]] = {}
         self._transferred_paths: list[str] = []  # Paths transferred to parent (for subagents)
 
+        # Cache state for resumption (set during run(), used in __aexit__ for caching on interrupt)
+        self._current_task_hash: str | None = None
+        self._current_run_state: CacheState | None = None
+
     @property
     def name(self) -> str:
         """The name of this agent."""
@@ -262,6 +275,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         output_dir: Path | str | None = None,
         input_files: str | Path | list[str | Path] | None = None,
         skills_dir: Path | str | None = None,
+        resume: bool = False,
+        clear_cache_on_success: bool = True,
     ) -> Self:
         """Configure a session and return self for use as async context manager.
 
@@ -277,6 +292,13 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             skills_dir: Directory containing skill definitions to load and make available
                 to the agent. Skills are uploaded to the execution environment
                 and their metadata is included in the system prompt.
+            resume: If True, attempt to resume from cached state if available.
+                The cache is identified by hashing the init_msgs passed to run().
+                Cached state includes message history, current turn, and execution
+                environment files from a previous interrupted run.
+            clear_cache_on_success: If True (default), automatically clear the cache
+                when the agent completes successfully. Set to False
+                to preserve caches for inspection or debugging.
 
         Returns:
             Self, for use with `async with agent.session(...) as session:`
@@ -293,8 +315,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._pending_output_dir = Path(output_dir) if output_dir else None
         self._pending_input_files = input_files
         self._pending_skills_dir = Path(skills_dir) if skills_dir else None
+        self._resume = resume
+        self._clear_cache_on_success = clear_cache_on_success
         return self
 
+    def _handle_interrupt(self, _signum: int, _frame: object) -> None:
+        """Handle SIGINT to ensure caching before exit.
+
+        Converts the signal to a KeyboardInterrupt exception so that __aexit__
+        is properly called and can cache the state before cleanup.
+        """
+        raise KeyboardInterrupt("Agent interrupted - state will be cached")
+
     def _resolve_input_files(self, input_files: str | Path | list[str | Path]) -> list[Path]:
         """Resolve input file paths, expanding globs and normalizing to Path objects.
 
@@ -410,6 +442,15 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Base prompt with max_turns
        parts.append(BASE_SYSTEM_PROMPT_TEMPLATE.format(max_turns=self._max_turns))
 
+        # User interaction guidance based on whether user_input tool is available
+        if "user_input" in self._active_tools:
+            parts.append(
+                " You have access to the user_input tool which allows you to ask the user "
+                "questions when you need clarification or are uncertain about something."
+            )
+        else:
+            parts.append(" You are not able to interact with the user during the task.")
+
         # Input files section (if any were uploaded)
         state = _SESSION_STATE.get(None)
         if state and state.uploaded_file_paths:
@@ -514,6 +555,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             output_dir=str(self._pending_output_dir) if self._pending_output_dir else None,
             parent_exec_env=parent_state.exec_env if parent_state else None,
             depth=current_depth,
+            logger=self._logger,
         )
         _SESSION_STATE.set(state)
 
@@ -621,6 +663,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             # depth is already set (0 for main agent, passed in for sub-agents)
            self._logger.__enter__()
 
+            # Set up signal handler for graceful caching on interrupt (root agent only)
+            if current_depth == 0:
+                self._original_sigint = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, self._handle_interrupt)
+
             return self
 
         except Exception:
@@ -642,6 +689,47 @@
         state = _SESSION_STATE.get()
 
         try:
+            # Cache state on non-success exit (only at root level)
+            should_cache = (
+                state.depth == 0
+                and (exc_type is not None or self._last_finish_params is None)
+                and self._current_task_hash is not None
+                and self._current_run_state is not None
+            )
+
+            logger.debug(
+                "[%s __aexit__] Cache decision: should_cache=%s, depth=%d, exc_type=%s, "
+                "finish_params=%s, task_hash=%s, run_state=%s",
+                self._name,
+                should_cache,
+                state.depth,
+                exc_type,
+                self._last_finish_params is not None,
+                self._current_task_hash,
+                self._current_run_state is not None,
+            )
+
+            if should_cache:
+                cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+
+                exec_env_dir = state.exec_env.temp_dir if state.exec_env else None
+
+                # Explicit checks to keep type checker happy - should_cache condition guarantees these
+                if self._current_task_hash is None or self._current_run_state is None:
+                    raise ValueError("Cache state is unexpectedly None after should_cache check")
+
+                # Temporarily block SIGINT during cache save to prevent interruption
+                original_handler = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, signal.SIG_IGN)
+                try:
+                    cache_manager.save_state(
+                        self._current_task_hash,
+                        self._current_run_state,
+                        exec_env_dir,
+                    )
+                finally:
+                    signal.signal(signal.SIGINT, original_handler)
+                self._logger.info(f"Cached state for task {self._current_task_hash}")
             # Save files from finish_params.paths based on depth
             if state.output_dir and self._last_finish_params and state.exec_env:
                 paths = getattr(self._last_finish_params, "paths", None)
@@ -696,6 +784,11 @@
                     state.depth,
                 )
         finally:
+            # Restore original signal handler (root agent only)
+            if hasattr(self, "_original_sigint"):
+                signal.signal(signal.SIGINT, self._original_sigint)
+                del self._original_sigint
+
             # Exit logger context
             self._logger.finish_params = self._last_finish_params
             self._logger.run_metadata = self._last_run_metadata
@@ -721,10 +814,7 @@
 
         if tool:
             try:
-
-                params = (
-                    tool.parameters.model_validate_json(tool_call.arguments) if tool.parameters is not None else None
-                )
+                params = tool.parameters.model_validate_json(tool_call.arguments)
 
                 # Set parent depth for sub-agent tools to read
                 prev_depth = _PARENT_DEPTH.set(self._logger.depth)
@@ -749,17 +839,18 @@
                     tool_call.name,
                     tool_call.arguments,
                 )
-                result = ToolResult(content="Tool arguments are not valid")
+                result = ToolResult(content="Tool arguments are not valid", success=False)
                 args_valid = False
         else:
            LOGGER.debug(f"LLMClient tried to use the tool {tool_call.name} which is not in the tools list")
-            result = ToolResult(content=f"{tool_call.name} is not a valid tool")
+            result = ToolResult(content=f"{tool_call.name} is not a valid tool", success=False)
 
         return ToolMessage(
             content=result.content,
             tool_call_id=tool_call.tool_call_id,
             name=tool_call.name,
             args_was_valid=args_valid,
+            success=result.success,
         )
 
     async def step(
@@ -768,7 +859,7 @@
         run_metadata: dict[str, list[Any]],
         turn: int = 0,
         max_turns: int = 0,
-    ) -> tuple[AssistantMessage, list[ToolMessage], ToolCall | None]:
+    ) -> tuple[AssistantMessage, list[ToolMessage], FinishParams | None]:
         """Execute one agent step: generate assistant message and run any requested tool calls.
 
         Args:
@@ -786,24 +877,21 @@
         if turn > 0:
             self._logger.assistant_message(turn, max_turns, assistant_message)
 
+        finish_params: FinishParams | None = None
         tool_messages: list[ToolMessage] = []
-        finish_call: ToolCall | None = None
-
         if assistant_message.tool_calls:
-            finish_call = next(
-                (tc for tc in assistant_message.tool_calls if tc.name == FINISH_TOOL_NAME),
-                None,
-            )
-
             tool_messages = []
             for tool_call in assistant_message.tool_calls:
                 tool_message = await self.run_tool(tool_call, run_metadata)
                 tool_messages.append(tool_message)
 
+                if tool_message.success and tool_message.name == FINISH_TOOL_NAME:
+                    finish_params = self._finish_tool.parameters.model_validate_json(tool_call.arguments)
+
                 # Log tool result immediately
                 self._logger.tool_result(tool_message)
 
-        return assistant_message, tool_messages, finish_call
+        return assistant_message, tool_messages, finish_params
 
     async def summarize_messages(self, messages: list[ChatMessage]) -> list[ChatMessage]:
         """Condense message history using LLM to stay within context window."""
@@ -829,7 +917,7 @@
         init_msgs: str | list[ChatMessage],
         *,
         depth: int | None = None,
-    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, list[Any]]]:
+    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, Any]]:
         """Execute the agent loop until finish tool is called or max_turns reached.
 
         A base system prompt is automatically prepended to all runs, including:
@@ -859,23 +947,59 @@
         ])
 
         """
-        msgs: list[ChatMessage] = []
 
-        # Build the complete system prompt (base + input files + user instructions)
-        full_system_prompt = self._build_system_prompt()
-        msgs.append(SystemMessage(content=full_system_prompt))
+        # Compute task hash for caching/resume
+        task_hash = compute_task_hash(init_msgs)
+        self._current_task_hash = task_hash
+
+        # Initialize cache manager
+        cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+        start_turn = 0
+        resumed = False
+
+        # Try to resume from cache if requested
+        if self._resume:
+            state = _SESSION_STATE.get()
+            cached = cache_manager.load_state(task_hash)
+            if cached:
+                # Restore files to exec env
+                if state.exec_env and state.exec_env.temp_dir:
+                    cache_manager.restore_files(task_hash, state.exec_env.temp_dir)
+
+                # Restore state
+                msgs = cached.msgs
+                full_msg_history = cached.full_msg_history
+                run_metadata = cached.run_metadata
+                start_turn = cached.turn
+                resumed = True
+                self._logger.info(f"Resuming from cached state at turn {start_turn}")
+            else:
+                self._logger.info(f"No cache found for task {task_hash}, starting fresh")
 
-        if isinstance(init_msgs, str):
-            msgs.append(UserMessage(content=init_msgs))
-        else:
-            msgs.extend(init_msgs)
+        if not resumed:
+            msgs: list[ChatMessage] = []
+
+            # Build the complete system prompt (base + input files + user instructions)
+            full_system_prompt = self._build_system_prompt()
+            msgs.append(SystemMessage(content=full_system_prompt))
+
+            if isinstance(init_msgs, str):
+                msgs.append(UserMessage(content=init_msgs))
+            else:
+                msgs.extend(init_msgs)
+
+            # Local metadata storage - isolated per run() invocation for thread safety
+            run_metadata: dict[str, list[Any]] = {}
+
+            full_msg_history: list[list[ChatMessage]] = []
 
         # Set logger depth if provided (for sub-agent runs)
         if depth is not None:
             self._logger.depth = depth
 
-        # Log the task at run start
-        self._logger.task_message(msgs[-1].content)
+        # Log the task at run start (only if not resuming)
+        if not resumed:
+            self._logger.task_message(msgs[-1].content)
 
         # Show warnings (top-level only, if logger supports it)
         if self._logger.depth == 0 and isinstance(self._logger, AgentLogger):
@@ -886,25 +1010,30 @@
         # Use logger callback if available and not overridden
         step_callback = self._logger.on_step
 
-        # Local metadata storage - isolated per run() invocation for thread safety
-        run_metadata: dict[str, list[Any]] = {}
-
         full_msg_history: list[list[ChatMessage]] = []
-        finish_params: FinishParams | None = None
 
         # Cumulative stats for spinner
         total_tool_calls = 0
         total_input_tokens = 0
        total_output_tokens = 0
 
-        for i in range(self._max_turns):
-            [deleted line not captured in the diff source]
+        for i in range(start_turn, self._max_turns):
+            # Capture current state for potential caching (before any async work)
+            self._current_run_state = CacheState(
+                msgs=list(msgs),
+                full_msg_history=[list(group) for group in full_msg_history],
+                turn=i,
+                run_metadata=dict(run_metadata),
+                task_hash=task_hash,
+                agent_name=self._name,
+            )
+            if self._max_turns - i <= self._turns_remaining_warning_threshold and i != 0:
                 num_turns_remaining_msg = _num_turns_remaining_msg(self._max_turns - i)
                 msgs.append(num_turns_remaining_msg)
                 self._logger.user_message(num_turns_remaining_msg)
 
             # Pass turn info to step() for real-time logging
-            assistant_message, tool_messages, finish_call = await self.step(
+            assistant_message, tool_messages, finish_params = await self.step(
                 msgs,
                 run_metadata,
                 turn=i + 1,
@@ -930,18 +1059,8 @@
 
             msgs.extend([assistant_message, *tool_messages, *user_messages])
 
-            if finish_call:
-                try:
-                    finish_arguments = json.loads(finish_call.arguments)
-                    if self._finish_tool.parameters is not None:
-                        finish_params = self._finish_tool.parameters.model_validate(finish_arguments)
-                    break
-                except (json.JSONDecodeError, ValidationError, TypeError):
-                    LOGGER.debug(
-                        "Agent tried to use the finish tool but the tool call is not valid: %r",
-                        finish_call.arguments,
-                    )
-                    # continue until the finish tool call is valid
+            if finish_params:
+                break
 
             pct_context_used = assistant_message.token_usage.total / self._client.max_tokens
             if pct_context_used >= self._context_summarization_cutoff and i + 1 != self._max_turns:
@@ -956,15 +1075,18 @@
         full_msg_history.append(msgs)
 
         # Add agent's own token usage to run_metadata under "token_usage" key
-        agent_token_usage = _get_total_token_usage(full_msg_history)
-        if "token_usage" not in run_metadata:
-            run_metadata["token_usage"] = []
-        run_metadata["token_usage"].append(agent_token_usage)
+        run_metadata["token_usage"] = _get_total_token_usage(full_msg_history)
 
         # Store for __aexit__ to access (on instance for this agent)
         self._last_finish_params = finish_params
         self._last_run_metadata = run_metadata
 
+        # Clear cache on successful completion (finish_params is set)
+        if finish_params is not None and cache_manager.clear_on_success:
+            cache_manager.clear_cache(task_hash)
+            self._current_task_hash = None
+            self._current_run_state = None
+
        return finish_params, full_msg_history, run_metadata
 
     def to_tool(
@@ -1092,6 +1214,7 @@
                 )
                 return ToolResult(
                     content=f"<sub_agent_result>\n<error>{e!s}</error>\n</sub_agent_result>",
+                    success=False,
                     metadata=error_metadata,
                 )
             finally:
```