stirrup 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl

This diff compares the contents of two publicly released package versions from one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
stirrup/core/agent.py CHANGED
@@ -2,9 +2,9 @@
 import contextvars
 import glob as glob_module
 import inspect
-import json
 import logging
 import re
+import signal
 from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
 from itertools import chain, takewhile
@@ -19,7 +19,9 @@ from stirrup.constants import (
     AGENT_MAX_TURNS,
     CONTEXT_SUMMARIZATION_CUTOFF,
     FINISH_TOOL_NAME,
+    TURNS_REMAINING_WARNING_THRESHOLD,
 )
+from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -72,6 +74,7 @@ class SessionState:
     depth: int = 0
     uploaded_file_paths: list[str] = field(default_factory=list)  # Paths of files uploaded to exec_env
     skills_metadata: list[SkillMetadata] = field(default_factory=list)  # Loaded skills metadata
+    logger: AgentLoggerBase | None = None  # Logger for pause/resume during user input


 _SESSION_STATE: contextvars.ContextVar[SessionState] = contextvars.ContextVar("session_state")
@@ -112,17 +115,19 @@ def _handle_text_only_tool_responses(tool_messages: list[ToolMessage]) -> tuple[
     return tool_messages, user_messages


-def _get_total_token_usage(messages: list[list[ChatMessage]]) -> TokenUsage:
-    """Aggregate token usage across all assistant messages in grouped conversation history.
+def _get_total_token_usage(messages: list[list[ChatMessage]]) -> list[TokenUsage]:
+    """
+    Returns a list of TokenUsage objects aggregated from all AssistantMessage
+    instances across the provided grouped message history.

     Args:
-        messages: List of message groups, where each group represents a segment of conversation.
+        messages: A list where each item is a list of ChatMessage objects representing a segment
+            or turn group of the conversation history.

+    Returns:
+        List of TokenUsage corresponding to each AssistantMessage in the flattened conversation history.
     """
-    return sum(
-        [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)],
-        start=TokenUsage(),
-    )
+    return [msg.token_usage for msg in chain.from_iterable(messages) if isinstance(msg, AssistantMessage)]


 class SubAgentParams(BaseModel):
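Note the behavioral change in this hunk: `_get_total_token_usage` no longer returns a single pre-summed `TokenUsage` but one entry per `AssistantMessage`, so any caller that wants the old aggregate must rebuild it. A minimal caller-side sketch, assuming `TokenUsage` still supports the `sum(..., start=TokenUsage())` pattern used by the removed code:

```python
# Hedged sketch: rebuild the pre-0.1.4 single-total behavior on the caller side.
usages = _get_total_token_usage(full_msg_history)  # now list[TokenUsage]
total = sum(usages, start=TokenUsage())            # the removed aggregation
```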
@@ -176,6 +181,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         finish_tool: Tool[FinishParams, FinishMeta] | None = None,
         # Agent options
         context_summarization_cutoff: float = CONTEXT_SUMMARIZATION_CUTOFF,
+        turns_remaining_warning_threshold: int = TURNS_REMAINING_WARNING_THRESHOLD,
         run_sync_in_thread: bool = True,
         text_only_tool_responses: bool = True,
         # Logging
@@ -215,6 +221,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._tools = tools if tools is not None else DEFAULT_TOOLS
         self._finish_tool: Tool = finish_tool if finish_tool is not None else SIMPLE_FINISH_TOOL
         self._context_summarization_cutoff = context_summarization_cutoff
+        self._turns_remaining_warning_threshold = turns_remaining_warning_threshold
         self._run_sync_in_thread = run_sync_in_thread
         self._text_only_tool_responses = text_only_tool_responses

@@ -225,6 +232,9 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._pending_output_dir: Path | None = None
         self._pending_input_files: str | Path | list[str | Path] | None = None
         self._pending_skills_dir: Path | None = None
+        self._resume: bool = False
+        self._clear_cache_on_success: bool = True
+        self._cache_on_interrupt: bool = True

         # Instance-scoped state (populated during __aenter__, isolated per agent instance)
         self._active_tools: dict[str, Tool] = {}
@@ -232,6 +242,10 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._last_run_metadata: dict[str, list[Any]] = {}
         self._transferred_paths: list[str] = []  # Paths transferred to parent (for subagents)

+        # Cache state for resumption (set during run(), used in __aexit__ for caching on interrupt)
+        self._current_task_hash: str | None = None
+        self._current_run_state: CacheState | None = None
+
     @property
     def name(self) -> str:
         """The name of this agent."""
@@ -262,6 +276,9 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         output_dir: Path | str | None = None,
         input_files: str | Path | list[str | Path] | None = None,
         skills_dir: Path | str | None = None,
+        resume: bool = False,
+        clear_cache_on_success: bool = True,
+        cache_on_interrupt: bool = True,
     ) -> Self:
         """Configure a session and return self for use as async context manager.

@@ -277,6 +294,17 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             skills_dir: Directory containing skill definitions to load and make available
                 to the agent. Skills are uploaded to the execution environment
                 and their metadata is included in the system prompt.
+            resume: If True, attempt to resume from cached state if available.
+                The cache is identified by hashing the init_msgs passed to run().
+                Cached state includes message history, current turn, and execution
+                environment files from a previous interrupted run.
+            clear_cache_on_success: If True (default), automatically clear the cache
+                when the agent completes successfully. Set to False to preserve
+                caches for inspection or debugging.
+            cache_on_interrupt: If True (default), set up a SIGINT handler to cache
+                state on Ctrl+C. Set to False when running agents in threads or
+                subprocesses where signal handlers cannot be registered from
+                non-main threads.

         Returns:
             Self, for use with `async with agent.session(...) as session:`
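The three new flags documented above combine at session setup. A hypothetical usage sketch (the task string and `output_dir` value are illustrative, not from the package; `run()`'s three-tuple return matches the signature shown later in this diff):

```python
# Hedged sketch of the new session flags.
async with agent.session(
    output_dir="out",
    resume=True,                   # reuse cached state keyed on init_msgs
    clear_cache_on_success=False,  # keep the cache around for inspection
    cache_on_interrupt=True,       # Ctrl+C caches state before exiting
) as session:
    finish_params, history, metadata = await session.run("Summarize report.pdf")
```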
@@ -293,8 +321,19 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         self._pending_output_dir = Path(output_dir) if output_dir else None
         self._pending_input_files = input_files
         self._pending_skills_dir = Path(skills_dir) if skills_dir else None
+        self._resume = resume
+        self._clear_cache_on_success = clear_cache_on_success
+        self._cache_on_interrupt = cache_on_interrupt
         return self

+    def _handle_interrupt(self, _signum: int, _frame: object) -> None:
+        """Handle SIGINT to ensure caching before exit.
+
+        Converts the signal to a KeyboardInterrupt exception so that __aexit__
+        is properly called and can cache the state before cleanup.
+        """
+        raise KeyboardInterrupt("Agent interrupted - state will be cached")
+
     def _resolve_input_files(self, input_files: str | Path | list[str | Path]) -> list[Path]:
         """Resolve input file paths, expanding globs and normalizing to Path objects.

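Because CPython only allows `signal.signal()` from the main thread, the docstring's advice about threaded runs matters in practice. A hedged sketch of a worker-thread launcher that opts out of the SIGINT hook (`run_in_thread` is illustrative, not part of stirrup):

```python
import asyncio
import threading

def run_in_thread(agent, task: str) -> threading.Thread:
    """Run an agent off the main thread; registering the SIGINT handler
    there would raise ValueError, so cache_on_interrupt is disabled."""
    def _target() -> None:
        async def _main() -> None:
            async with agent.session(cache_on_interrupt=False) as session:
                await session.run(task)
        asyncio.run(_main())

    thread = threading.Thread(target=_target, daemon=True)
    thread.start()
    return thread
```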
@@ -410,6 +449,15 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Base prompt with max_turns
         parts.append(BASE_SYSTEM_PROMPT_TEMPLATE.format(max_turns=self._max_turns))

+        # User interaction guidance based on whether user_input tool is available
+        if "user_input" in self._active_tools:
+            parts.append(
+                " You have access to the user_input tool which allows you to ask the user "
+                "questions when you need clarification or are uncertain about something."
+            )
+        else:
+            parts.append(" You are not able to interact with the user during the task.")
+
         # Input files section (if any were uploaded)
         state = _SESSION_STATE.get(None)
         if state and state.uploaded_file_paths:
@@ -514,6 +562,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             output_dir=str(self._pending_output_dir) if self._pending_output_dir else None,
             parent_exec_env=parent_state.exec_env if parent_state else None,
             depth=current_depth,
+            logger=self._logger,
         )
         _SESSION_STATE.set(state)

@@ -613,6 +662,13 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                 state.skills_metadata = load_skills_metadata(skills_path)
                 logger.debug("[%s __aenter__] Loaded %d skills", self._name, len(state.skills_metadata))
                 self._pending_skills_dir = None  # Clear pending state
+            elif parent_state and parent_state.skills_metadata:
+                # Sub-agent: inherit skills from parent
+                state.skills_metadata = parent_state.skills_metadata
+                logger.debug("[%s __aenter__] Inherited %d skills from parent", self._name, len(state.skills_metadata))
+                # Transfer skills directory from parent's exec_env to sub-agent's exec_env
+                if state.exec_env and parent_state.exec_env:
+                    await state.exec_env.upload_files("skills", source_env=parent_state.exec_env)

             # Configure and enter logger context
             self._logger.name = self._name
@@ -621,6 +677,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             # depth is already set (0 for main agent, passed in for sub-agents)
             self._logger.__enter__()

+            # Set up signal handler for graceful caching on interrupt (root agent only)
+            if current_depth == 0 and self._cache_on_interrupt:
+                self._original_sigint = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, self._handle_interrupt)
+
             return self

         except Exception:
@@ -642,6 +703,47 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         state = _SESSION_STATE.get()

         try:
+            # Cache state on non-success exit (only at root level)
+            should_cache = (
+                state.depth == 0
+                and (exc_type is not None or self._last_finish_params is None)
+                and self._current_task_hash is not None
+                and self._current_run_state is not None
+            )
+
+            logger.debug(
+                "[%s __aexit__] Cache decision: should_cache=%s, depth=%d, exc_type=%s, "
+                "finish_params=%s, task_hash=%s, run_state=%s",
+                self._name,
+                should_cache,
+                state.depth,
+                exc_type,
+                self._last_finish_params is not None,
+                self._current_task_hash,
+                self._current_run_state is not None,
+            )
+
+            if should_cache:
+                cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+
+                exec_env_dir = state.exec_env.temp_dir if state.exec_env else None
+
+                # Explicit checks to keep type checker happy - should_cache condition guarantees these
+                if self._current_task_hash is None or self._current_run_state is None:
+                    raise ValueError("Cache state is unexpectedly None after should_cache check")
+
+                # Temporarily block SIGINT during cache save to prevent interruption
+                original_handler = signal.getsignal(signal.SIGINT)
+                signal.signal(signal.SIGINT, signal.SIG_IGN)
+                try:
+                    cache_manager.save_state(
+                        self._current_task_hash,
+                        self._current_run_state,
+                        exec_env_dir,
+                    )
+                finally:
+                    signal.signal(signal.SIGINT, original_handler)
+                self._logger.info(f"Cached state for task {self._current_task_hash}")
             # Save files from finish_params.paths based on depth
             if state.output_dir and self._last_finish_params and state.exec_env:
                 paths = getattr(self._last_finish_params, "paths", None)
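The save above is deliberately bracketed by `SIG_IGN` so a second Ctrl+C cannot land mid-write and corrupt the cache. The same guard can be factored into a context manager; a minimal sketch (`uninterruptible` is illustrative, not part of stirrup):

```python
import signal
from contextlib import contextmanager

@contextmanager
def uninterruptible():
    """Ignore SIGINT for the duration of a critical section (main thread only)."""
    original = signal.getsignal(signal.SIGINT)
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original)

# usage: with uninterruptible(): cache_manager.save_state(...)
```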
@@ -696,6 +798,11 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                 state.depth,
             )
         finally:
+            # Restore original signal handler (root agent only)
+            if hasattr(self, "_original_sigint"):
+                signal.signal(signal.SIGINT, self._original_sigint)
+                del self._original_sigint
+
             # Exit logger context
             self._logger.finish_params = self._last_finish_params
             self._logger.run_metadata = self._last_run_metadata
@@ -721,10 +828,9 @@ class Agent[FinishParams: BaseModel, FinishMeta]:

         if tool:
             try:
-                # Parse parameters if the tool has them, otherwise use None
-                params = (
-                    tool.parameters.model_validate_json(tool_call.arguments) if tool.parameters is not None else None
-                )
+                # Normalize empty arguments to valid empty JSON object
+                args = tool_call.arguments if tool_call.arguments and tool_call.arguments.strip() else "{}"
+                params = tool.parameters.model_validate_json(args)

                 # Set parent depth for sub-agent tools to read
                 prev_depth = _PARENT_DEPTH.set(self._logger.depth)
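The normalization matters because providers commonly emit `""` for tools that take no parameters; Pydantic rejects an empty string as invalid JSON but accepts `"{}"` against an empty model. A small sketch of the failure mode (`NoParams` is a hypothetical parameter model):

```python
from pydantic import BaseModel, ValidationError

class NoParams(BaseModel):
    """Parameter model for a tool that takes no arguments."""

print(NoParams.model_validate_json("{}"))  # -> NoParams()

try:
    NoParams.model_validate_json("")       # empty string is not valid JSON
except ValidationError as e:
    print("rejected:", e.error_count(), "error(s)")
```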
@@ -749,17 +855,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
                     tool_call.name,
                     tool_call.arguments,
                 )
-                result = ToolResult(content="Tool arguments are not valid")
+                result = ToolResult(content="Tool arguments are not valid", success=False)
                 args_valid = False
         else:
             LOGGER.debug(f"LLMClient tried to use the tool {tool_call.name} which is not in the tools list")
-            result = ToolResult(content=f"{tool_call.name} is not a valid tool")
+            result = ToolResult(content=f"{tool_call.name} is not a valid tool", success=False)

         return ToolMessage(
             content=result.content,
             tool_call_id=tool_call.tool_call_id,
             name=tool_call.name,
             args_was_valid=args_valid,
+            success=result.success,
         )

     async def step(
@@ -768,7 +875,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         run_metadata: dict[str, list[Any]],
         turn: int = 0,
         max_turns: int = 0,
-    ) -> tuple[AssistantMessage, list[ToolMessage], ToolCall | None]:
+    ) -> tuple[AssistantMessage, list[ToolMessage], FinishParams | None]:
         """Execute one agent step: generate assistant message and run any requested tool calls.

         Args:
@@ -786,24 +893,21 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         if turn > 0:
             self._logger.assistant_message(turn, max_turns, assistant_message)

+        finish_params: FinishParams | None = None
         tool_messages: list[ToolMessage] = []
-        finish_call: ToolCall | None = None
-
         if assistant_message.tool_calls:
-            finish_call = next(
-                (tc for tc in assistant_message.tool_calls if tc.name == FINISH_TOOL_NAME),
-                None,
-            )
-
             tool_messages = []
             for tool_call in assistant_message.tool_calls:
                 tool_message = await self.run_tool(tool_call, run_metadata)
                 tool_messages.append(tool_message)

+                if tool_message.success and tool_message.name == FINISH_TOOL_NAME:
+                    finish_params = self._finish_tool.parameters.model_validate_json(tool_call.arguments)
+
                 # Log tool result immediately
                 self._logger.tool_result(tool_message)

-        return assistant_message, tool_messages, finish_call
+        return assistant_message, tool_messages, finish_params

     async def summarize_messages(self, messages: list[ChatMessage]) -> list[ChatMessage]:
         """Condense message history using LLM to stay within context window."""
@@ -829,7 +933,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         init_msgs: str | list[ChatMessage],
         *,
         depth: int | None = None,
-    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, list[Any]]]:
+    ) -> tuple[FinishParams | None, list[list[ChatMessage]], dict[str, Any]]:
         """Execute the agent loop until finish tool is called or max_turns reached.

         A base system prompt is automatically prepended to all runs, including:
@@ -859,23 +963,59 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         ])

         """
-        msgs: list[ChatMessage] = []

-        # Build the complete system prompt (base + input files + user instructions)
-        full_system_prompt = self._build_system_prompt()
-        msgs.append(SystemMessage(content=full_system_prompt))
+        # Compute task hash for caching/resume
+        task_hash = compute_task_hash(init_msgs)
+        self._current_task_hash = task_hash
+
+        # Initialize cache manager
+        cache_manager = CacheManager(clear_on_success=self._clear_cache_on_success)
+        start_turn = 0
+        resumed = False
+
+        # Try to resume from cache if requested
+        if self._resume:
+            state = _SESSION_STATE.get()
+            cached = cache_manager.load_state(task_hash)
+            if cached:
+                # Restore files to exec env
+                if state.exec_env and state.exec_env.temp_dir:
+                    cache_manager.restore_files(task_hash, state.exec_env.temp_dir)
+
+                # Restore state
+                msgs = cached.msgs
+                full_msg_history = cached.full_msg_history
+                run_metadata = cached.run_metadata
+                start_turn = cached.turn
+                resumed = True
+                self._logger.info(f"Resuming from cached state at turn {start_turn}")
+            else:
+                self._logger.info(f"No cache found for task {task_hash}, starting fresh")

-        if isinstance(init_msgs, str):
-            msgs.append(UserMessage(content=init_msgs))
-        else:
-            msgs.extend(init_msgs)
+        if not resumed:
+            msgs: list[ChatMessage] = []
+
+            # Build the complete system prompt (base + input files + user instructions)
+            full_system_prompt = self._build_system_prompt()
+            msgs.append(SystemMessage(content=full_system_prompt))
+
+            if isinstance(init_msgs, str):
+                msgs.append(UserMessage(content=init_msgs))
+            else:
+                msgs.extend(init_msgs)
+
+            # Local metadata storage - isolated per run() invocation for thread safety
+            run_metadata: dict[str, list[Any]] = {}
+
+            full_msg_history: list[list[ChatMessage]] = []

         # Set logger depth if provided (for sub-agent runs)
         if depth is not None:
             self._logger.depth = depth

-        # Log the task at run start
-        self._logger.task_message(msgs[-1].content)
+        # Log the task at run start (only if not resuming)
+        if not resumed:
+            self._logger.task_message(msgs[-1].content)

         # Show warnings (top-level only, if logger supports it)
         if self._logger.depth == 0 and isinstance(self._logger, AgentLogger):
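One consequence of keying the cache on `compute_task_hash(init_msgs)`: a run can only be resumed by replaying byte-identical initial messages; any edit to the prompt starts a fresh run. A hedged sketch of the assumed invariant (deterministic hashing is implied by cross-process resume):

```python
from stirrup.core.cache import compute_task_hash

assert compute_task_hash("Summarize report.pdf") == compute_task_hash("Summarize report.pdf")
assert compute_task_hash("Summarize report.pdf") != compute_task_hash("Summarize report.docx")
```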
@@ -886,25 +1026,30 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         # Use logger callback if available and not overridden
         step_callback = self._logger.on_step

-        # Local metadata storage - isolated per run() invocation for thread safety
-        run_metadata: dict[str, list[Any]] = {}
-
         full_msg_history: list[list[ChatMessage]] = []
-        finish_params: FinishParams | None = None

         # Cumulative stats for spinner
         total_tool_calls = 0
         total_input_tokens = 0
         total_output_tokens = 0

-        for i in range(self._max_turns):
-            if self._max_turns - i <= 30 and i != 0:
+        for i in range(start_turn, self._max_turns):
+            # Capture current state for potential caching (before any async work)
+            self._current_run_state = CacheState(
+                msgs=list(msgs),
+                full_msg_history=[list(group) for group in full_msg_history],
+                turn=i,
+                run_metadata=dict(run_metadata),
+                task_hash=task_hash,
+                agent_name=self._name,
+            )
+            if self._max_turns - i <= self._turns_remaining_warning_threshold and i != 0:
                 num_turns_remaining_msg = _num_turns_remaining_msg(self._max_turns - i)
                 msgs.append(num_turns_remaining_msg)
                 self._logger.user_message(num_turns_remaining_msg)

             # Pass turn info to step() for real-time logging
-            assistant_message, tool_messages, finish_call = await self.step(
+            assistant_message, tool_messages, finish_params = await self.step(
                 msgs,
                 run_metadata,
                 turn=i + 1,
@@ -930,18 +1075,8 @@ class Agent[FinishParams: BaseModel, FinishMeta]:

             msgs.extend([assistant_message, *tool_messages, *user_messages])

-            if finish_call:
-                try:
-                    finish_arguments = json.loads(finish_call.arguments)
-                    if self._finish_tool.parameters is not None:
-                        finish_params = self._finish_tool.parameters.model_validate(finish_arguments)
-                    break
-                except (json.JSONDecodeError, ValidationError, TypeError):
-                    LOGGER.debug(
-                        "Agent tried to use the finish tool but the tool call is not valid: %r",
-                        finish_call.arguments,
-                    )
-                    # continue until the finish tool call is valid
+            if finish_params:
+                break

             pct_context_used = assistant_message.token_usage.total / self._client.max_tokens
             if pct_context_used >= self._context_summarization_cutoff and i + 1 != self._max_turns:
@@ -956,15 +1091,18 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
         full_msg_history.append(msgs)

         # Add agent's own token usage to run_metadata under "token_usage" key
-        agent_token_usage = _get_total_token_usage(full_msg_history)
-        if "token_usage" not in run_metadata:
-            run_metadata["token_usage"] = []
-        run_metadata["token_usage"].append(agent_token_usage)
+        run_metadata["token_usage"] = _get_total_token_usage(full_msg_history)

         # Store for __aexit__ to access (on instance for this agent)
         self._last_finish_params = finish_params
         self._last_run_metadata = run_metadata

+        # Clear cache on successful completion (finish_params is set)
+        if finish_params is not None and cache_manager.clear_on_success:
+            cache_manager.clear_cache(task_hash)
+            self._current_task_hash = None
+            self._current_run_state = None
+
         return finish_params, full_msg_history, run_metadata

     def to_tool(
@@ -1092,6 +1230,7 @@ class Agent[FinishParams: BaseModel, FinishMeta]:
             )
             return ToolResult(
                 content=f"<sub_agent_result>\n<error>{e!s}</error>\n</sub_agent_result>",
+                success=False,
                 metadata=error_metadata,
             )
         finally: