stirrup-0.1.1-py3-none-any.whl → stirrup-0.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stirrup/__init__.py +2 -0
- stirrup/clients/chat_completions_client.py +0 -3
- stirrup/clients/litellm_client.py +20 -11
- stirrup/clients/utils.py +6 -1
- stirrup/constants.py +6 -2
- stirrup/core/agent.py +206 -57
- stirrup/core/cache.py +479 -0
- stirrup/core/models.py +53 -7
- stirrup/prompts/base_system_prompt.txt +1 -1
- stirrup/skills/__init__.py +24 -0
- stirrup/skills/skills.py +145 -0
- stirrup/tools/__init__.py +2 -0
- stirrup/tools/calculator.py +1 -1
- stirrup/tools/code_backends/base.py +7 -0
- stirrup/tools/code_backends/docker.py +16 -4
- stirrup/tools/code_backends/e2b.py +32 -13
- stirrup/tools/code_backends/local.py +16 -4
- stirrup/tools/finish.py +1 -1
- stirrup/tools/user_input.py +130 -0
- stirrup/tools/web.py +1 -0
- stirrup/utils/logging.py +24 -0
- {stirrup-0.1.1.dist-info → stirrup-0.1.3.dist-info}/METADATA +36 -16
- stirrup-0.1.3.dist-info/RECORD +36 -0
- {stirrup-0.1.1.dist-info → stirrup-0.1.3.dist-info}/WHEEL +1 -1
- stirrup-0.1.1.dist-info/RECORD +0 -32
stirrup/core/agent.py
CHANGED
@@ -2,9 +2,9 @@
 import contextvars
 import glob as glob_module
 import inspect
-import json
 import logging
 import re
+import signal
 from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
 from itertools import chain, takewhile
@@ -19,7 +19,9 @@ from stirrup.constants import (
     AGENT_MAX_TURNS,
     CONTEXT_SUMMARIZATION_CUTOFF,
     FINISH_TOOL_NAME,
+    TURNS_REMAINING_WARNING_THRESHOLD,
 )
+from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -36,6 +38,7 @@ from stirrup.core.models import (
     UserMessage,
 )
 from stirrup.prompts import MESSAGE_SUMMARIZER, MESSAGE_SUMMARIZER_BRIDGE_TEMPLATE
+from stirrup.skills import SkillMetadata, format_skills_section, load_skills_metadata
 from stirrup.tools import DEFAULT_TOOLS
 from stirrup.tools.code_backends.base import CodeExecToolProvider
 from stirrup.tools.code_backends.local import LocalCodeExecToolProvider
@@ -70,6 +73,8 @@ class SessionState:
     parent_exec_env: CodeExecToolProvider | None = None
     depth: int = 0
     uploaded_file_paths: list[str] = field(default_factory=list)  # Paths of files uploaded to exec_env
+    skills_metadata: list[SkillMetadata] = field(default_factory=list)  # Loaded skills metadata
+    logger: AgentLoggerBase | None = None  # Logger for pause/resume during user input


 _SESSION_STATE: contextvars.ContextVar[SessionState] = contextvars.ContextVar("session_state")
@@ -110,17 +115,19 @@ def _handle_text_only_tool_responses(tool_messages: list[ToolMessage]) -> tuple[
     return tool_messages, user_messages


-def _get_total_token_usage(messages: list[list[ChatMessage]]) -> TokenUsage:
-    """
+def _get_total_token_usage(messages: list[list[ChatMessage]]) -> list[TokenUsage]:
+    """
+    Returns a list of TokenUsage objects aggregated from all AssistantMessage
+    instances across the provided grouped message history.

     Args:
-        messages:
+        messages: A list where each item is a list of ChatMessage objects representing a segment
+            or turn group of the conversation history.

+    Returns:
+        List of TokenUsage corresponding to each AssistantMessage in the flattened conversation history.
     """
-    return sum(
-        [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)],
-        start=TokenUsage(),
-    )
+    return [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)]


 class SubAgentParams(BaseModel):
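Note: 0.1.1 folded per-message usage into a single total inside this helper; 0.1.3 returns the raw list and leaves aggregation to callers. A minimal sketch of recovering the old total from the new return value, assuming (as the removed sum(..., start=TokenUsage()) implies) that TokenUsage is default-constructible and addable:

from stirrup.core.models import TokenUsage

def aggregate(usages: list[TokenUsage]) -> TokenUsage:
    # Fold per-assistant-message usages back into one running total,
    # as the removed 0.1.1 implementation did internally.
    return sum(usages, start=TokenUsage())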
@@ -174,6 +181,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         finish_tool: Tool[FinishParams, FinishMeta] | None = None,
         # Agent options
         context_summarization_cutoff: float = CONTEXT_SUMMARIZATION_CUTOFF,
+        turns_remaining_warning_threshold: int = TURNS_REMAINING_WARNING_THRESHOLD,
         run_sync_in_thread: bool = True,
         text_only_tool_responses: bool = True,
         # Logging
@@ -213,6 +221,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._tools = tools if tools is not None else DEFAULT_TOOLS
         self._finish_tool: Tool = finish_tool if finish_tool is not None else SIMPLE_FINISH_TOOL
         self._context_summarization_cutoff = context_summarization_cutoff
+        self._turns_remaining_warning_threshold = turns_remaining_warning_threshold
         self._run_sync_in_thread = run_sync_in_thread
         self._text_only_tool_responses = text_only_tool_responses

@@ -222,6 +231,9 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Session configuration (set during session(), used in __aenter__)
         self._pending_output_dir: Path | None = None
         self._pending_input_files: str | Path | list[str | Path] | None = None
+        self._pending_skills_dir: Path | None = None
+        self._resume: bool = False
+        self._clear_cache_on_success: bool = True

         # Instance-scoped state (populated during __aenter__, isolated per agent instance)
         self._active_tools: dict[str, Tool] = {}
@@ -229,6 +241,10 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._last_run_metadata: dict[str, list[Any]] = {}
         self._transferred_paths: list[str] = []  # Paths transferred to parent (for subagents)

+        # Cache state for resumption (set during run(), used in __aexit__ for caching on interrupt)
+        self._current_task_hash: str | None = None
+        self._current_run_state: CacheState | None = None
+
     @property
     def name(self) -> str:
         """The name of this agent."""
@@ -258,6 +274,9 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self,
         output_dir: Path | str | None = None,
         input_files: str | Path | list[str | Path] | None = None,
+        skills_dir: Path | str | None = None,
+        resume: bool = False,
+        clear_cache_on_success: bool = True,
     ) -> Self:
         """Configure a session and return self for use as async context manager.

@@ -270,6 +289,16 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                 - Glob patterns (e.g., "data/*.csv", "**/*.py")
                 Raises ValueError if no CodeExecToolProvider is configured
                 or if a glob pattern matches no files.
+            skills_dir: Directory containing skill definitions to load and make available
+                to the agent. Skills are uploaded to the execution environment
+                and their metadata is included in the system prompt.
+            resume: If True, attempt to resume from cached state if available.
+                The cache is identified by hashing the init_msgs passed to run().
+                Cached state includes message history, current turn, and execution
+                environment files from a previous interrupted run.
+            clear_cache_on_success: If True (default), automatically clear the cache
+                when the agent completes successfully. Set to False
+                to preserve caches for inspection or debugging.

         Returns:
             Self, for use with `async with agent.session(...) as session:`
@@ -285,8 +314,19 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         """
         self._pending_output_dir = Path(output_dir) if output_dir else None
         self._pending_input_files = input_files
+        self._pending_skills_dir = Path(skills_dir) if skills_dir else None
+        self._resume = resume
+        self._clear_cache_on_success = clear_cache_on_success
         return self

+    def _handle_interrupt(self, _signum: int, _frame: object) -> None:
+        """Handle SIGINT to ensure caching before exit.
+
+        Converts the signal to a KeyboardInterrupt exception so that __aexit__
+        is properly called and can cache the state before cleanup.
+        """
+        raise KeyboardInterrupt("Agent interrupted - state will be cached")
+
     def _resolve_input_files(self, input_files: str | Path | list[str | Path]) -> list[Path]:
         """Resolve input file paths, expanding globs and normalizing to Path objects.

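Note: taken together, the new session() options enable interrupt-and-resume workflows. A usage sketch (the agent construction and task string are illustrative, not from the package):

from pathlib import Path

async def main(agent) -> None:
    async with agent.session(
        output_dir=Path("out"),
        skills_dir=Path("skills"),     # uploaded to the exec env and listed in the system prompt
        resume=True,                   # reuse cached state keyed on a hash of the run() init_msgs
        clear_cache_on_success=False,  # keep the cache around for inspection
    ):
        finish_params, history, metadata = await agent.run("Summarize the uploaded files")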
@@ -402,6 +442,15 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Base prompt with max_turns
         parts.append(BASE_SYSTEM_PROMPT_TEMPLATE.format(max_turns=self._max_turns))

+        # User interaction guidance based on whether user_input tool is available
+        if "user_input" in self._active_tools:
+            parts.append(
+                " You have access to the user_input tool which allows you to ask the user "
+                "questions when you need clarification or are uncertain about something."
+            )
+        else:
+            parts.append(" You are not able to interact with the user during the task.")
+
         # Input files section (if any were uploaded)
         state = _SESSION_STATE.get(None)
         if state and state.uploaded_file_paths:
@@ -410,6 +459,12 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                 files_section += f"\n- {file_path}"
             parts.append(files_section)

+        # Skills section (if skills were loaded)
+        if state and state.skills_metadata:
+            skills_section = format_skills_section(state.skills_metadata)
+            if skills_section:
+                parts.append(f"\n\n{skills_section}")
+
         # User's custom system prompt (if provided)
         if self._system_prompt:
             parts.append(f"\n\nFollow these instructions from the User:\n{self._system_prompt}")
@@ -500,6 +555,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                 output_dir=str(self._pending_output_dir) if self._pending_output_dir else None,
                 parent_exec_env=parent_state.exec_env if parent_state else None,
                 depth=current_depth,
+                logger=self._logger,
             )
             _SESSION_STATE.set(state)

@@ -588,6 +644,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                     raise RuntimeError(f"Failed to upload files: {result.failed}")
                 self._pending_input_files = None  # Clear pending state

+            # Upload skills directory if it exists and load metadata
+            if self._pending_skills_dir:
+                skills_path = self._pending_skills_dir
+                if skills_path.exists() and skills_path.is_dir():
+                    if state.exec_env:
+                        logger.debug("[%s __aenter__] Uploading skills directory: %s", self._name, skills_path)
+                        await state.exec_env.upload_files(skills_path, dest_dir="skills")
+                    # Load skills metadata (even if no exec_env, for system prompt)
+                    state.skills_metadata = load_skills_metadata(skills_path)
+                    logger.debug("[%s __aenter__] Loaded %d skills", self._name, len(state.skills_metadata))
+                self._pending_skills_dir = None  # Clear pending state
+
             # Configure and enter logger context
             self._logger.name = self._name
             self._logger.model = self._client.model_slug
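Note: based on the signatures used above, the skills helpers can also be driven standalone; a small sketch (the "skills" directory content is hypothetical):

from pathlib import Path

from stirrup.skills import format_skills_section, load_skills_metadata

metadata = load_skills_metadata(Path("skills"))  # parse skill definitions from disk
print(f"loaded {len(metadata)} skills")
print(format_skills_section(metadata))           # the same text the agent appends to its system prompt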
@@ -595,6 +663,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             # depth is already set (0 for main agent, passed in for sub-agents)
             self._logger.__enter__()

+            # Set up signal handler for graceful caching on interrupt (root agent only)
+            if current_depth == 0:
+                self._original_sigint = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, self._handle_interrupt)
+
             return self

         except Exception:
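Note: raising from the SIGINT handler is what lets __aexit__ observe the interrupt. A stdlib-only sketch of the same save/raise/restore pattern:

import signal
import time

def on_sigint(signum, frame):
    # Turning the signal into an exception unwinds try/finally blocks and
    # context managers instead of killing the process mid-write.
    raise KeyboardInterrupt("interrupted - state will be cached")

original = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, on_sigint)
try:
    time.sleep(60)  # Ctrl-C here raises KeyboardInterrupt at this line
except KeyboardInterrupt:
    print("caching state before exit")  # stand-in for CacheManager.save_state
finally:
    signal.signal(signal.SIGINT, original)  # always restore the previous handler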
@@ -616,6 +689,47 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         state = _SESSION_STATE.get()

         try:
+            # Cache state on non-success exit (only at root level)
+            should_cache = (
+                state.depth == 0
+                and (exc_type is not None or self._last_finish_params is None)
+                and self._current_task_hash is not None
+                and self._current_run_state is not None
+            )
+
+            logger.debug(
+                "[%s __aexit__] Cache decision: should_cache=%s, depth=%d, exc_type=%s, "
+                "finish_params=%s, task_hash=%s, run_state=%s",
+                self._name,
+                should_cache,
+                state.depth,
+                exc_type,
+                self._last_finish_params is not None,
+                self._current_task_hash,
+                self._current_run_state is not None,
+            )
+
+            if should_cache:
+                cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+
+                exec_env_dir = state.exec_env.temp_dir if state.exec_env else None
+
+                # Explicit checks to keep type checker happy - should_cache condition guarantees these
+                if self._current_task_hash is None or self._current_run_state is None:
+                    raise ValueError("Cache state is unexpectedly None after should_cache check")
+
+                # Temporarily block SIGINT during cache save to prevent interruption
+                original_handler = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, signal.SIG_IGN)
+                try:
+                    cache_manager.save_state(
+                        self._current_task_hash,
+                        self._current_run_state,
+                        exec_env_dir,
+                    )
+                finally:
+                    signal.signal(signal.SIGINT, original_handler)
+                self._logger.info(f"Cached state for task {self._current_task_hash}")
             # Save files from finish_params.paths based on depth
             if state.output_dir and self._last_finish_params and state.exec_env:
                 paths = getattr(self._last_finish_params, "paths", None)
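Note: the save path above shields itself from a second Ctrl-C with SIG_IGN. The pattern in isolation, as a reusable sketch:

import signal

def run_uninterruptible(fn, *args, **kwargs):
    # Ignore SIGINT for the duration of fn, then restore whatever handler
    # was installed before (not necessarily the default one).
    previous = signal.getsignal(signal.SIGINT)
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        return fn(*args, **kwargs)
    finally:
        signal.signal(signal.SIGINT, previous)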
@@ -670,6 +784,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                     state.depth,
                 )
         finally:
+            # Restore original signal handler (root agent only)
+            if hasattr(self, "_original_sigint"):
+                signal.signal(signal.SIGINT, self._original_sigint)
+                del self._original_sigint
+
             # Exit logger context
             self._logger.finish_params = self._last_finish_params
             self._logger.run_metadata = self._last_run_metadata
@@ -695,10 +814,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:

         if tool:
             try:
-
-                params = (
-                    tool.parameters.model_validate_json(tool_call.arguments) if tool.parameters is not None else None
-                )
+                params = tool.parameters.model_validate_json(tool_call.arguments)

                 # Set parent depth for sub-agent tools to read
                 prev_depth = _PARENT_DEPTH.set(self._logger.depth)
@@ -723,17 +839,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                     tool_call.name,
                     tool_call.arguments,
                 )
-                result = ToolResult(content="Tool arguments are not valid")
+                result = ToolResult(content="Tool arguments are not valid", success=False)
                 args_valid = False
         else:
             LOGGER.debug(f"LLMClient tried to use the tool {tool_call.name} which is not in the tools list")
-            result = ToolResult(content=f"{tool_call.name} is not a valid tool")
+            result = ToolResult(content=f"{tool_call.name} is not a valid tool", success=False)

         return ToolMessage(
             content=result.content,
             tool_call_id=tool_call.tool_call_id,
             name=tool_call.name,
             args_was_valid=args_valid,
+            success=result.success,
         )

     async def step(
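Note: tool failures now flow through ToolResult.success into ToolMessage.success. A condensed sketch of the validation half, using pydantic directly (EchoParams stands in for a real tool's parameters model):

from pydantic import BaseModel, ValidationError

class EchoParams(BaseModel):
    text: str

def validate_args(raw: str) -> tuple[EchoParams | None, bool]:
    try:
        return EchoParams.model_validate_json(raw), True
    except ValidationError:
        return None, False  # the False maps to ToolResult(..., success=False) above

assert validate_args('{"text": "hi"}')[1] is True
assert validate_args('{"wrong": 1}')[1] is False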
@@ -742,7 +859,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         run_metadata: dict[str, list[Any]],
         turn: int = 0,
         max_turns: int = 0,
-    ) -> tuple[AssistantMessage, list[ToolMessage], ToolCall | None]:
+    ) -> tuple[AssistantMessage, list[ToolMessage], FinishParams | None]:
         """Execute one agent step: generate assistant message and run any requested tool calls.

         Args:
@@ -760,24 +877,21 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         if turn > 0:
             self._logger.assistant_message(turn, max_turns, assistant_message)

+        finish_params: FinishParams | None = None
         tool_messages: list[ToolMessage] = []
-        finish_call: ToolCall | None = None
-
         if assistant_message.tool_calls:
-            finish_call = next(
-                (tc for tc in assistant_message.tool_calls if tc.name == FINISH_TOOL_NAME),
-                None,
-            )
-
             tool_messages = []
             for tool_call in assistant_message.tool_calls:
                 tool_message = await self.run_tool(tool_call, run_metadata)
                 tool_messages.append(tool_message)

+                if tool_message.success and tool_message.name == FINISH_TOOL_NAME:
+                    finish_params = self._finish_tool.parameters.model_validate_json(tool_call.arguments)
+
                 # Log tool result immediately
                 self._logger.tool_result(tool_message)

-        return assistant_message, tool_messages, finish_call
+        return assistant_message, tool_messages, finish_params

     async def summarize_messages(self, messages: list[ChatMessage]) -> list[ChatMessage]:
         """Condense message history using LLM to stay within context window."""
@@ -803,7 +917,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         init_msgs: str | list[ChatMessage],
         *,
         depth: int | None = None,
-    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, list[Any]]]:
+    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, Any]]:
         """Execute the agent loop until finish tool is called or max_turns reached.

         A base system prompt is automatically prepended to all runs, including:
@@ -833,23 +947,59 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         ])

         """
-        msgs: list[ChatMessage] = []

-        # Build the complete system prompt (base + input files + user instructions)
-        full_system_prompt = self._build_system_prompt()
-        msgs.append(SystemMessage(content=full_system_prompt))
+        # Compute task hash for caching/resume
+        task_hash = compute_task_hash(init_msgs)
+        self._current_task_hash = task_hash
+
+        # Initialize cache manager
+        cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+        start_turn = 0
+        resumed = False
+
+        # Try to resume from cache if requested
+        if self._resume:
+            state = _SESSION_STATE.get()
+            cached = cache_manager.load_state(task_hash)
+            if cached:
+                # Restore files to exec env
+                if state.exec_env and state.exec_env.temp_dir:
+                    cache_manager.restore_files(task_hash, state.exec_env.temp_dir)
+
+                # Restore state
+                msgs = cached.msgs
+                full_msg_history = cached.full_msg_history
+                run_metadata = cached.run_metadata
+                start_turn = cached.turn
+                resumed = True
+                self._logger.info(f"Resuming from cached state at turn {start_turn}")
+            else:
+                self._logger.info(f"No cache found for task {task_hash}, starting fresh")

-        if isinstance(init_msgs, str):
-            msgs.append(UserMessage(content=init_msgs))
-        else:
-            msgs.extend(init_msgs)
+        if not resumed:
+            msgs: list[ChatMessage] = []
+
+            # Build the complete system prompt (base + input files + user instructions)
+            full_system_prompt = self._build_system_prompt()
+            msgs.append(SystemMessage(content=full_system_prompt))
+
+            if isinstance(init_msgs, str):
+                msgs.append(UserMessage(content=init_msgs))
+            else:
+                msgs.extend(init_msgs)
+
+            # Local metadata storage - isolated per run() invocation for thread safety
+            run_metadata: dict[str, list[Any]] = {}
+
+            full_msg_history: list[list[ChatMessage]] = []

         # Set logger depth if provided (for sub-agent runs)
         if depth is not None:
             self._logger.depth = depth

-        # Log the task at run start
-        self._logger.task_message(msgs[-1].content)
+        # Log the task at run start (only if not resuming)
+        if not resumed:
+            self._logger.task_message(msgs[-1].content)

         # Show warnings (top-level only, if logger supports it)
         if self._logger.depth == 0 and isinstance(self._logger, AgentLogger):
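Note: compute_task_hash lives in the new stirrup/core/cache.py, which this excerpt does not show. Purely as an assumption about its shape, a stable content digest over the initial messages would look like:

import hashlib

def compute_task_hash_sketch(init_msgs) -> str:
    # Hypothetical stand-in, not the package's implementation: derive a stable
    # key from the task text so identical runs map to the same cache entry.
    if isinstance(init_msgs, str):
        payload = init_msgs
    else:
        payload = "\n".join(str(getattr(m, "content", m)) for m in init_msgs)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]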
@@ -860,25 +1010,30 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Use logger callback if available and not overridden
         step_callback = self._logger.on_step

-        # Local metadata storage - isolated per run() invocation for thread safety
-        run_metadata: dict[str, list[Any]] = {}
-
         full_msg_history: list[list[ChatMessage]] = []
-        finish_params: FinishParams | None = None

         # Cumulative stats for spinner
         total_tool_calls = 0
         total_input_tokens = 0
         total_output_tokens = 0

-        for i in range(self._max_turns):
-
+        for i in range(start_turn, self._max_turns):
+            # Capture current state for potential caching (before any async work)
+            self._current_run_state = CacheState(
+                msgs=list(msgs),
+                full_msg_history=[list(group) for group in full_msg_history],
+                turn=i,
+                run_metadata=dict(run_metadata),
+                task_hash=task_hash,
+                agent_name=self._name,
+            )
+            if self._max_turns - i <= self._turns_remaining_warning_threshold and i != 0:
                 num_turns_remaining_msg = _num_turns_remaining_msg(self._max_turns - i)
                 msgs.append(num_turns_remaining_msg)
                 self._logger.user_message(num_turns_remaining_msg)

             # Pass turn info to step() for real-time logging
-            assistant_message, tool_messages, finish_call = await self.step(
+            assistant_message, tool_messages, finish_params = await self.step(
                 msgs,
                 run_metadata,
                 turn=i + 1,
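Note: the snapshot takes shallow copies (list(msgs), dict(run_metadata)) so that messages appended later in the turn do not leak into the captured state:

msgs = ["system", "user"]
snapshot = list(msgs)       # new list object, same elements
msgs.append("assistant")    # the live history grows...
assert snapshot == ["system", "user"]  # ...but the snapshot is frozen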
@@ -904,18 +1059,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:

             msgs.extend([assistant_message, *tool_messages, *user_messages])

-            if finish_call:
-                try:
-                    finish_arguments = json.loads(finish_call.arguments)
-                    if self._finish_tool.parameters is not None:
-                        finish_params = self._finish_tool.parameters.model_validate(finish_arguments)
-                    break
-                except (json.JSONDecodeError, ValidationError, TypeError):
-                    LOGGER.debug(
-                        "Agent tried to use the finish tool but the tool call is not valid: %r",
-                        finish_call.arguments,
-                    )
-                    # continue until the finish tool call is valid
+            if finish_params:
+                break

             pct_context_used = assistant_message.token_usage.total / self._client.max_tokens
             if pct_context_used >= self._context_summarization_cutoff and i + 1 != self._max_turns:
@@ -930,15 +1075,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         full_msg_history.append(msgs)

         # Add agent's own token usage to run_metadata under "token_usage" key
-        agent_token_usage = _get_total_token_usage(full_msg_history)
-        if "token_usage" not in run_metadata:
-            run_metadata["token_usage"] = []
-        run_metadata["token_usage"].append(agent_token_usage)
+        run_metadata["token_usage"] = _get_total_token_usage(full_msg_history)

         # Store for __aexit__ to access (on instance for this agent)
         self._last_finish_params = finish_params
         self._last_run_metadata = run_metadata

+        # Clear cache on successful completion (finish_params is set)
+        if finish_params is not None and cache_manager.clear_on_success:
+            cache_manager.clear_cache(task_hash)
+            self._current_task_hash = None
+            self._current_run_state = None
+
         return finish_params, full_msg_history, run_metadata

     def to_tool(
@@ -1066,6 +1214,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             )
             return ToolResult(
                 content=f"<sub_agent_result>\n<error>{e!s}</error>\n</sub_agent_result>",
+                success=False,
                 metadata=error_metadata,
             )
         finally: