stirrup 0.1.2-py3-none-any.whl → 0.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stirrup/__init__.py CHANGED
@@ -35,6 +35,7 @@ from stirrup.core.models import (
     AssistantMessage,
     AudioContentBlock,
     ChatMessage,
+    EmptyParams,
     ImageContentBlock,
     LLMClient,
     SubAgentMetadata,
@@ -58,6 +59,7 @@ __all__ = [
     "AudioContentBlock",
     "ChatMessage",
     "ContextOverflowError",
+    "EmptyParams",
     "ImageContentBlock",
     "LLMClient",
     "SubAgentMetadata",
@@ -67,7 +67,6 @@ class ChatCompletionsClient(LLMClient):
         *,
         base_url: str | None = None,
         api_key: str | None = None,
-        supports_audio_input: bool = False,
         reasoning_effort: str | None = None,
         timeout: float | None = None,
         max_retries: int = 2,
@@ -82,7 +81,6 @@ class ChatCompletionsClient(LLMClient):
                 Use for OpenAI-compatible providers (e.g., 'http://localhost:8000/v1').
             api_key: API key for authentication. If None, reads from OPENROUTER_API_KEY
                 environment variable.
-            supports_audio_input: Whether the model supports audio inputs. Defaults to False.
             reasoning_effort: Reasoning effort level for extended thinking models
                 (e.g., 'low', 'medium', 'high'). Only used with o1/o3 style models.
             timeout: Request timeout in seconds. If None, uses OpenAI SDK default.
@@ -92,7 +90,6 @@ class ChatCompletionsClient(LLMClient):
         """
         self._model = model
         self._max_tokens = max_tokens
-        self._supports_audio_input = supports_audio_input
         self._reasoning_effort = reasoning_effort
         self._kwargs = kwargs or {}

@@ -7,7 +7,7 @@ Requires the litellm extra: `pip install stirrup[litellm]`
 """

 import logging
-from typing import Any
+from typing import Any, Literal

 try:
     from litellm import acompletion
@@ -38,6 +38,8 @@ __all__ = [

 LOGGER = logging.getLogger(__name__)

+type ReasoningEffort = Literal["none", "minimal", "low", "medium", "high", "xhigh", "default"]
+

 class LiteLLMClient(LLMClient):
     """LiteLLM-based client supporting multiple LLM providers with unified interface.
@@ -49,8 +51,8 @@ class LiteLLMClient(LLMClient):
         self,
         model_slug: str,
         max_tokens: int,
-        supports_audio_input: bool = False,
-        reasoning_effort: str | None = None,
+        api_key: str | None = None,
+        reasoning_effort: ReasoningEffort | None = None,
         kwargs: dict[str, Any] | None = None,
     ) -> None:
         """Initialize LiteLLM client with model configuration and capabilities.
@@ -58,15 +60,13 @@ class LiteLLMClient(LLMClient):
         Args:
             model_slug: Model identifier for LiteLLM (e.g., 'anthropic/claude-3-5-sonnet-20241022')
             max_tokens: Maximum context window size in tokens
-            supports_audio_input: Whether the model supports audio inputs
             reasoning_effort: Reasoning effort level for extended thinking models (e.g., 'medium', 'high')
             kwargs: Additional arguments to pass to LiteLLM completion calls
         """
         self._model_slug = model_slug
-        self._supports_video_input = False
-        self._supports_audio_input = supports_audio_input
         self._max_tokens = max_tokens
-        self._reasoning_effort = reasoning_effort
+        self._reasoning_effort: ReasoningEffort | None = reasoning_effort
+        self._api_key = api_key
         self._kwargs = kwargs or {}

     @property
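For reference, a minimal construction sketch under the revised signature — the model slug is the docstring's own example, while the token count and key are placeholders, not values from this diff. (The `type` alias statement used for ReasoningEffort requires Python 3.12+.)

    # Hypothetical usage: api_key replaces the removed supports_audio_input flag,
    # and reasoning_effort is now constrained to the ReasoningEffort literals.
    client = LiteLLMClient(
        model_slug="anthropic/claude-3-5-sonnet-20241022",
        max_tokens=200_000,       # illustrative context window size
        api_key="sk-...",         # optional; forwarded to acompletion()
        reasoning_effort="high",  # one of: none, minimal, low, medium, high, xhigh, default
    )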
@@ -92,6 +92,8 @@ class LiteLLMClient(LLMClient):
             tools=to_openai_tools(tools) if tools else None,
             tool_choice="auto" if tools else None,
             max_tokens=self._max_tokens,
+            reasoning_effort=self._reasoning_effort,
+            api_key=self._api_key,
             **self._kwargs,
         )

@@ -103,14 +105,20 @@ class LiteLLMClient(LLMClient):
         )

         msg = choice["message"]
-
         reasoning: Reasoning | None = None
         if getattr(msg, "reasoning_content", None) is not None:
             reasoning = Reasoning(content=msg.reasoning_content)
         if getattr(msg, "thinking_blocks", None) is not None and len(msg.thinking_blocks) > 0:
-            reasoning = Reasoning(
-                signature=msg.thinking_blocks[0]["signature"], content=msg.thinking_blocks[0]["content"]
-            )
+            if len(msg.thinking_blocks) > 1:
+                raise ValueError("Found multiple thinking blocks in the response")
+
+            signature = msg.thinking_blocks[0].get("thinking_signature", None)
+            content = msg.thinking_blocks[0].get("thinking", None)
+
+            if signature is None and content is None:
+                raise ValueError("Signature and content not found in the thinking block response")
+
+            reasoning = Reasoning(signature=signature, content=content)

         usage = r["usage"]

@@ -119,6 +127,7 @@ class LiteLLMClient(LLMClient):
                 tool_call_id=tc.get("id"),
                 name=tc["function"]["name"],
                 arguments=tc["function"].get("arguments", "") or "",
+                signature=tc.get("provider_specific_fields", {}).get("thought_signature", None),
             )
             for tc in (msg.get("tool_calls") or [])
         ]
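Note that the thinking-block handling above also switches key names: the old code indexed "signature" and "content", while the new code reads "thinking_signature" and "thinking" via .get() and fails loudly on multiple blocks or a fully empty block. A self-contained sketch against a hand-written payload — the dict shape is an assumption based only on the keys the code reads:

    # Hypothetical thinking block mirroring the keys the new parser expects.
    block = {"thinking_signature": "sig-abc", "thinking": "step-by-step reasoning"}

    signature = block.get("thinking_signature")
    content = block.get("thinking")
    if signature is None and content is None:
        raise ValueError("Signature and content not found in the thinking block response")
    # -> Reasoning(signature="sig-abc", content="step-by-step reasoning")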
stirrup/clients/utils.py CHANGED
@@ -12,6 +12,7 @@ from stirrup.core.models import (
     AudioContentBlock,
     ChatMessage,
     Content,
+    EmptyParams,
     ImageContentBlock,
     SystemMessage,
     Tool,
@@ -47,7 +48,7 @@ def to_openai_tools(tools: dict[str, Tool]) -> list[dict[str, Any]]:
             "name": t.name,
             "description": t.description,
         }
-        if t.parameters is not None:
+        if t.parameters is not EmptyParams:
             function["parameters"] = t.parameters.model_json_schema()
         tool_payload: dict[str, Any] = {
             "type": "function",
@@ -139,6 +140,10 @@ def to_openai_messages(msgs: list[ChatMessage]) -> list[dict[str, Any]]:
             tool_dict = tool.model_dump()
             tool_dict["id"] = tool.tool_call_id
             tool_dict["type"] = "function"
+            if tool.signature is not None:
+                tool_dict["provider_specific_fields"] = {
+                    "thought_signature": tool.signature,
+                }
             tool_dict["function"] = {
                 "name": tool.name,
                 "arguments": tool.arguments,
stirrup/constants.py CHANGED
@@ -1,14 +1,18 @@
+from typing import Literal
+
 # Tool naming
-FINISH_TOOL_NAME = "finish"
+FINISH_TOOL_NAME: Literal["finish"] = "finish"

 # Agent execution limits
 AGENT_MAX_TURNS = 30  # Maximum agent turns before forced termination
 CONTEXT_SUMMARIZATION_CUTOFF = 0.7  # Context window usage threshold (0.0-1.0) that triggers message summarization
+TURNS_REMAINING_WARNING_THRESHOLD = 20

 # Media resolution limits
 RESOLUTION_1MP = 1_000_000  # 1 megapixel - default max resolution for images
 RESOLUTION_480P = 640 * 480  # 480p video resolution

 # Code execution
-SUBMISSION_SANDBOX_TIMEOUT = 60 * 10  # 10 minutes
+SANDBOX_TIMEOUT = 60 * 10  # 10 minutes
+SANDBOX_REQUEST_TIMEOUT = 60 * 3  # 3 minutes
 E2B_SANDBOX_TEMPLATE_ALIAS = "e2b-sandbox"
stirrup/core/agent.py CHANGED
@@ -2,9 +2,9 @@
 import contextvars
 import glob as glob_module
 import inspect
-import json
 import logging
 import re
+import signal
 from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
 from itertools import chain, takewhile
@@ -19,7 +19,9 @@ from stirrup.constants import (
     AGENT_MAX_TURNS,
     CONTEXT_SUMMARIZATION_CUTOFF,
     FINISH_TOOL_NAME,
+    TURNS_REMAINING_WARNING_THRESHOLD,
 )
+from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -72,6 +74,7 @@ class SessionState:
     depth: int = 0
     uploaded_file_paths: list[str] = field(default_factory=list)  # Paths of files uploaded to exec_env
     skills_metadata: list[SkillMetadata] = field(default_factory=list)  # Loaded skills metadata
+    logger: AgentLoggerBase | None = None  # Logger for pause/resume during user input


 _SESSION_STATE: contextvars.ContextVar[SessionState] = contextvars.ContextVar("session_state")
@@ -112,17 +115,19 @@ def _handle_text_only_tool_responses(tool_messages: list[ToolMessage]) -> tuple[
     return tool_messages, user_messages


-def _get_total_token_usage(messages: list[list[ChatMessage]]) -> TokenUsage:
-    """Aggregate token usage across all assistant messages in grouped conversation history.
+def _get_total_token_usage(messages: list[list[ChatMessage]]) -> list[TokenUsage]:
+    """
+    Returns a list of TokenUsage objects aggregated from all AssistantMessage
+    instances across the provided grouped message history.

     Args:
-        messages: List of message groups, where each group represents a segment of conversation.
+        messages: A list where each item is a list of ChatMessage objects representing a segment
+            or turn group of the conversation history.

+    Returns:
+        List of TokenUsage corresponding to each AssistantMessage in the flattened conversation history.
     """
-    return sum(
-        [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)],
-        start=TokenUsage(),
-    )
+    return [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)]


 class SubAgentParams(BaseModel):
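Callers that relied on the old aggregate can recover it by summing the returned list — assuming TokenUsage still supports addition with a zero-value start, as the removed sum(..., start=TokenUsage()) implies:

    per_message = _get_total_token_usage(full_msg_history)
    total = sum(per_message, start=TokenUsage())  # reproduces the old aggregate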
@@ -176,6 +181,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         finish_tool: Tool[FinishParams, FinishMeta] | None = None,
         # Agent options
         context_summarization_cutoff: float = CONTEXT_SUMMARIZATION_CUTOFF,
+        turns_remaining_warning_threshold: int = TURNS_REMAINING_WARNING_THRESHOLD,
         run_sync_in_thread: bool = True,
         text_only_tool_responses: bool = True,
         # Logging
@@ -215,6 +221,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._tools = tools if tools is not None else DEFAULT_TOOLS
         self._finish_tool: Tool = finish_tool if finish_tool is not None else SIMPLE_FINISH_TOOL
         self._context_summarization_cutoff = context_summarization_cutoff
+        self._turns_remaining_warning_threshold = turns_remaining_warning_threshold
         self._run_sync_in_thread = run_sync_in_thread
         self._text_only_tool_responses = text_only_tool_responses

@@ -225,6 +232,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._pending_output_dir: Path | None = None
         self._pending_input_files: str | Path | list[str | Path] | None = None
         self._pending_skills_dir: Path | None = None
+        self._resume: bool = False
+        self._clear_cache_on_success: bool = True

         # Instance-scoped state (populated during __aenter__, isolated per agent instance)
         self._active_tools: dict[str, Tool] = {}
@@ -232,6 +241,10 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._last_run_metadata: dict[str, list[Any]] = {}
         self._transferred_paths: list[str] = []  # Paths transferred to parent (for subagents)

+        # Cache state for resumption (set during run(), used in __aexit__ for caching on interrupt)
+        self._current_task_hash: str | None = None
+        self._current_run_state: CacheState | None = None
+
     @property
     def name(self) -> str:
         """The name of this agent."""
@@ -262,6 +275,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         output_dir: Path | str | None = None,
         input_files: str | Path | list[str | Path] | None = None,
         skills_dir: Path | str | None = None,
+        resume: bool = False,
+        clear_cache_on_success: bool = True,
     ) -> Self:
         """Configure a session and return self for use as async context manager.

@@ -277,6 +292,13 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             skills_dir: Directory containing skill definitions to load and make available
                 to the agent. Skills are uploaded to the execution environment
                 and their metadata is included in the system prompt.
+            resume: If True, attempt to resume from cached state if available.
+                The cache is identified by hashing the init_msgs passed to run().
+                Cached state includes message history, current turn, and execution
+                environment files from a previous interrupted run.
+            clear_cache_on_success: If True (default), automatically clear the cache
+                when the agent completes successfully. Set to False
+                to preserve caches for inspection or debugging.

         Returns:
             Self, for use with `async with agent.session(...) as session:`
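A sketch of how the new session flags compose with run(), based on the docstring above (the agent value and task string are placeholders):

    async with agent.session(output_dir="out", resume=True, clear_cache_on_success=False) as session:
        # If a cache keyed by the hash of init_msgs exists, run() resumes from the
        # saved turn; otherwise it starts fresh. The cache is kept for inspection.
        finish_params, history, metadata = await session.run("Summarize the input files")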
@@ -293,8 +315,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._pending_output_dir = Path(output_dir) if output_dir else None
         self._pending_input_files = input_files
         self._pending_skills_dir = Path(skills_dir) if skills_dir else None
+        self._resume = resume
+        self._clear_cache_on_success = clear_cache_on_success
         return self

+    def _handle_interrupt(self, _signum: int, _frame: object) -> None:
+        """Handle SIGINT to ensure caching before exit.
+
+        Converts the signal to a KeyboardInterrupt exception so that __aexit__
+        is properly called and can cache the state before cleanup.
+        """
+        raise KeyboardInterrupt("Agent interrupted - state will be cached")
+
     def _resolve_input_files(self, input_files: str | Path | list[str | Path]) -> list[Path]:
         """Resolve input file paths, expanding globs and normalizing to Path objects.

@@ -410,6 +442,15 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Base prompt with max_turns
         parts.append(BASE_SYSTEM_PROMPT_TEMPLATE.format(max_turns=self._max_turns))

+        # User interaction guidance based on whether user_input tool is available
+        if "user_input" in self._active_tools:
+            parts.append(
+                " You have access to the user_input tool which allows you to ask the user "
+                "questions when you need clarification or are uncertain about something."
+            )
+        else:
+            parts.append(" You are not able to interact with the user during the task.")
+
         # Input files section (if any were uploaded)
         state = _SESSION_STATE.get(None)
         if state and state.uploaded_file_paths:
@@ -514,6 +555,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             output_dir=str(self._pending_output_dir) if self._pending_output_dir else None,
             parent_exec_env=parent_state.exec_env if parent_state else None,
             depth=current_depth,
+            logger=self._logger,
         )
         _SESSION_STATE.set(state)

@@ -621,6 +663,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             # depth is already set (0 for main agent, passed in for sub-agents)
             self._logger.__enter__()

+            # Set up signal handler for graceful caching on interrupt (root agent only)
+            if current_depth == 0:
+                self._original_sigint = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, self._handle_interrupt)
+
             return self

         except Exception:
@@ -642,6 +689,47 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         state = _SESSION_STATE.get()

         try:
+            # Cache state on non-success exit (only at root level)
+            should_cache = (
+                state.depth == 0
+                and (exc_type is not None or self._last_finish_params is None)
+                and self._current_task_hash is not None
+                and self._current_run_state is not None
+            )
+
+            logger.debug(
+                "[%s __aexit__] Cache decision: should_cache=%s, depth=%d, exc_type=%s, "
+                "finish_params=%s, task_hash=%s, run_state=%s",
+                self._name,
+                should_cache,
+                state.depth,
+                exc_type,
+                self._last_finish_params is not None,
+                self._current_task_hash,
+                self._current_run_state is not None,
+            )
+
+            if should_cache:
+                cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+
+                exec_env_dir = state.exec_env.temp_dir if state.exec_env else None
+
+                # Explicit checks to keep type checker happy - should_cache condition guarantees these
+                if self._current_task_hash is None or self._current_run_state is None:
+                    raise ValueError("Cache state is unexpectedly None after should_cache check")
+
+                # Temporarily block SIGINT during cache save to prevent interruption
+                original_handler = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, signal.SIG_IGN)
+                try:
+                    cache_manager.save_state(
+                        self._current_task_hash,
+                        self._current_run_state,
+                        exec_env_dir,
+                    )
+                finally:
+                    signal.signal(signal.SIGINT, original_handler)
+                self._logger.info(f"Cached state for task {self._current_task_hash}")
             # Save files from finish_params.paths based on depth
             if state.output_dir and self._last_finish_params and state.exec_env:
                 paths = getattr(self._last_finish_params, "paths", None)
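The save runs inside a temporary SIG_IGN window so a second Ctrl-C cannot interrupt a half-written cache. The same pattern as a reusable helper — a generic sketch, not stirrup API:

    import signal
    from contextlib import contextmanager

    @contextmanager
    def sigint_blocked():
        """Ignore SIGINT for the duration of a critical section, then restore."""
        original = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            yield
        finally:
            signal.signal(signal.SIGINT, original)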
@@ -696,6 +784,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                 state.depth,
             )
         finally:
+            # Restore original signal handler (root agent only)
+            if hasattr(self, "_original_sigint"):
+                signal.signal(signal.SIGINT, self._original_sigint)
+                del self._original_sigint
+
             # Exit logger context
             self._logger.finish_params = self._last_finish_params
             self._logger.run_metadata = self._last_run_metadata
@@ -721,10 +814,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:

         if tool:
             try:
-                # Parse parameters if the tool has them, otherwise use None
-                params = (
-                    tool.parameters.model_validate_json(tool_call.arguments) if tool.parameters is not None else None
-                )
+                params = tool.parameters.model_validate_json(tool_call.arguments)

                 # Set parent depth for sub-agent tools to read
                 prev_depth = _PARENT_DEPTH.set(self._logger.depth)
@@ -749,17 +839,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                     tool_call.name,
                     tool_call.arguments,
                 )
-                result = ToolResult(content="Tool arguments are not valid")
+                result = ToolResult(content="Tool arguments are not valid", success=False)
                 args_valid = False
         else:
             LOGGER.debug(f"LLMClient tried to use the tool {tool_call.name} which is not in the tools list")
-            result = ToolResult(content=f"{tool_call.name} is not a valid tool")
+            result = ToolResult(content=f"{tool_call.name} is not a valid tool", success=False)

         return ToolMessage(
             content=result.content,
             tool_call_id=tool_call.tool_call_id,
             name=tool_call.name,
             args_was_valid=args_valid,
+            success=result.success,
         )

     async def step(
@@ -768,7 +859,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         run_metadata: dict[str, list[Any]],
         turn: int = 0,
         max_turns: int = 0,
-    ) -> tuple[AssistantMessage, list[ToolMessage], ToolCall | None]:
+    ) -> tuple[AssistantMessage, list[ToolMessage], FinishParams | None]:
         """Execute one agent step: generate assistant message and run any requested tool calls.

         Args:
@@ -786,24 +877,21 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         if turn > 0:
             self._logger.assistant_message(turn, max_turns, assistant_message)

+        finish_params: FinishParams | None = None
         tool_messages: list[ToolMessage] = []
-        finish_call: ToolCall | None = None
-
         if assistant_message.tool_calls:
-            finish_call = next(
-                (tc for tc in assistant_message.tool_calls if tc.name == FINISH_TOOL_NAME),
-                None,
-            )
-
             tool_messages = []
             for tool_call in assistant_message.tool_calls:
                 tool_message = await self.run_tool(tool_call, run_metadata)
                 tool_messages.append(tool_message)

+                if tool_message.success and tool_message.name == FINISH_TOOL_NAME:
+                    finish_params = self._finish_tool.parameters.model_validate_json(tool_call.arguments)
+
                 # Log tool result immediately
                 self._logger.tool_result(tool_message)

-        return assistant_message, tool_messages, finish_call
+        return assistant_message, tool_messages, finish_params

     async def summarize_messages(self, messages: list[ChatMessage]) -> list[ChatMessage]:
         """Condense message history using LLM to stay within context window."""
@@ -829,7 +917,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         init_msgs: str | list[ChatMessage],
         *,
         depth: int | None = None,
-    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, list[Any]]]:
+    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, Any]]:
         """Execute the agent loop until finish tool is called or max_turns reached.

         A base system prompt is automatically prepended to all runs, including:
@@ -859,23 +947,59 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             ])

         """
-        msgs: list[ChatMessage] = []

-        # Build the complete system prompt (base + input files + user instructions)
-        full_system_prompt = self._build_system_prompt()
-        msgs.append(SystemMessage(content=full_system_prompt))
+        # Compute task hash for caching/resume
+        task_hash = compute_task_hash(init_msgs)
+        self._current_task_hash = task_hash
+
+        # Initialize cache manager
+        cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+        start_turn = 0
+        resumed = False
+
+        # Try to resume from cache if requested
+        if self._resume:
+            state = _SESSION_STATE.get()
+            cached = cache_manager.load_state(task_hash)
+            if cached:
+                # Restore files to exec env
+                if state.exec_env and state.exec_env.temp_dir:
+                    cache_manager.restore_files(task_hash, state.exec_env.temp_dir)
+
+                # Restore state
+                msgs = cached.msgs
+                full_msg_history = cached.full_msg_history
+                run_metadata = cached.run_metadata
+                start_turn = cached.turn
+                resumed = True
+                self._logger.info(f"Resuming from cached state at turn {start_turn}")
+            else:
+                self._logger.info(f"No cache found for task {task_hash}, starting fresh")

-        if isinstance(init_msgs, str):
-            msgs.append(UserMessage(content=init_msgs))
-        else:
-            msgs.extend(init_msgs)
+        if not resumed:
+            msgs: list[ChatMessage] = []
+
+            # Build the complete system prompt (base + input files + user instructions)
+            full_system_prompt = self._build_system_prompt()
+            msgs.append(SystemMessage(content=full_system_prompt))
+
+            if isinstance(init_msgs, str):
+                msgs.append(UserMessage(content=init_msgs))
+            else:
+                msgs.extend(init_msgs)
+
+            # Local metadata storage - isolated per run() invocation for thread safety
+            run_metadata: dict[str, list[Any]] = {}
+
+            full_msg_history: list[list[ChatMessage]] = []

         # Set logger depth if provided (for sub-agent runs)
         if depth is not None:
             self._logger.depth = depth

-        # Log the task at run start
-        self._logger.task_message(msgs[-1].content)
+        # Log the task at run start (only if not resuming)
+        if not resumed:
+            self._logger.task_message(msgs[-1].content)


         # Show warnings (top-level only, if logger supports it)
@@ -886,25 +1010,30 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Use logger callback if available and not overridden
         step_callback = self._logger.on_step

-        # Local metadata storage - isolated per run() invocation for thread safety
-        run_metadata: dict[str, list[Any]] = {}
-
         full_msg_history: list[list[ChatMessage]] = []
-        finish_params: FinishParams | None = None

         # Cumulative stats for spinner
         total_tool_calls = 0
         total_input_tokens = 0
         total_output_tokens = 0

-        for i in range(self._max_turns):
-            if self._max_turns - i <= 30 and i != 0:
+        for i in range(start_turn, self._max_turns):
+            # Capture current state for potential caching (before any async work)
+            self._current_run_state = CacheState(
+                msgs=list(msgs),
+                full_msg_history=[list(group) for group in full_msg_history],
+                turn=i,
+                run_metadata=dict(run_metadata),
+                task_hash=task_hash,
+                agent_name=self._name,
+            )
+            if self._max_turns - i <= self._turns_remaining_warning_threshold and i != 0:
                 num_turns_remaining_msg = _num_turns_remaining_msg(self._max_turns - i)
                 msgs.append(num_turns_remaining_msg)
                 self._logger.user_message(num_turns_remaining_msg)

             # Pass turn info to step() for real-time logging
-            assistant_message, tool_messages, finish_call = await self.step(
+            assistant_message, tool_messages, finish_params = await self.step(
                 msgs,
                 run_metadata,
                 turn=i + 1,
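stirrup/core/cache.py itself is not part of this diff; from the constructor call above, CacheState plausibly looks something like the following — field types are inferred from usage here, not confirmed:

    class CacheState(BaseModel):
        msgs: list[ChatMessage]
        full_msg_history: list[list[ChatMessage]]
        turn: int
        run_metadata: dict[str, Any]
        task_hash: str
        agent_name: str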
@@ -930,18 +1059,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:

             msgs.extend([assistant_message, *tool_messages, *user_messages])

-            if finish_call:
-                try:
-                    finish_arguments = json.loads(finish_call.arguments)
-                    if self._finish_tool.parameters is not None:
-                        finish_params = self._finish_tool.parameters.model_validate(finish_arguments)
-                        break
-                except (json.JSONDecodeError, ValidationError, TypeError):
-                    LOGGER.debug(
-                        "Agent tried to use the finish tool but the tool call is not valid: %r",
-                        finish_call.arguments,
-                    )
-                    # continue until the finish tool call is valid
+            if finish_params:
+                break

             pct_context_used = assistant_message.token_usage.total / self._client.max_tokens
             if pct_context_used >= self._context_summarization_cutoff and i + 1 != self._max_turns:
@@ -956,15 +1075,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         full_msg_history.append(msgs)

         # Add agent's own token usage to run_metadata under "token_usage" key
-        agent_token_usage = _get_total_token_usage(full_msg_history)
-        if "token_usage" not in run_metadata:
-            run_metadata["token_usage"] = []
-        run_metadata["token_usage"].append(agent_token_usage)
+        run_metadata["token_usage"] = _get_total_token_usage(full_msg_history)

         # Store for __aexit__ to access (on instance for this agent)
         self._last_finish_params = finish_params
         self._last_run_metadata = run_metadata

+        # Clear cache on successful completion (finish_params is set)
+        if finish_params is not None and cache_manager.clear_on_success:
+            cache_manager.clear_cache(task_hash)
+            self._current_task_hash = None
+            self._current_run_state = None
+
         return finish_params, full_msg_history, run_metadata

     def to_tool(
@@ -1092,6 +1214,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             )
             return ToolResult(
                 content=f"<sub_agent_result>\n<error>{e!s}</error>\n</sub_agent_result>",
+                success=False,
                 metadata=error_metadata,
             )
         finally: