zrb 1.21.31__py3-none-any.whl → 1.21.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic. Click here for more details.

@@ -129,56 +129,79 @@ class LLMRateLimitter:
129
129
  async def throttle(
130
130
  self,
131
131
  prompt: Any,
132
- throttle_notif_callback: Callable[[str], Any] | None = None,
132
+ throttle_notif_callback: Callable[..., Any] | None = None,
133
133
  ):
134
134
  now = time.time()
135
135
  str_prompt = self._prompt_to_str(prompt)
136
- tokens = self.count_token(str_prompt)
136
+ new_requested_tokens = self.count_token(str_prompt)
137
137
  # Clean up old entries
138
138
  while self.request_times and now - self.request_times[0] > 60:
139
139
  self.request_times.popleft()
140
140
  while self.token_times and now - self.token_times[0][0] > 60:
141
141
  self.token_times.popleft()
142
142
  # Check per-request token limit
143
- if tokens > self.max_tokens_per_request:
143
+ if new_requested_tokens > self.max_tokens_per_request:
144
144
  raise ValueError(
145
145
  (
146
- "Request exceeds max_tokens_per_request "
147
- f"({tokens} > {self.max_tokens_per_request})."
146
+ "New request exceeds max_tokens_per_request "
147
+ f"({new_requested_tokens} > {self.max_tokens_per_request})."
148
148
  )
149
149
  )
150
- if tokens > self.max_tokens_per_minute:
150
+ if new_requested_tokens > self.max_tokens_per_minute:
151
151
  raise ValueError(
152
152
  (
153
- "Request exceeds max_tokens_per_minute "
154
- f"({tokens} > {self.max_tokens_per_minute})."
153
+ "New request exceeds max_tokens_per_minute "
154
+ f"({new_requested_tokens} > {self.max_tokens_per_minute})."
155
155
  )
156
156
  )
157
157
  # Wait if over per-minute request or token limit
158
+ callback_new_line = True
159
+ ever_throttled = False
158
160
  while (
159
161
  len(self.request_times) >= self.max_requests_per_minute
160
- or sum(t for _, t in self.token_times) + tokens > self.max_tokens_per_minute
162
+ or sum(t for _, t in self.token_times) + new_requested_tokens
163
+ > self.max_tokens_per_minute
161
164
  ):
165
+ ever_throttled = True
162
166
  if throttle_notif_callback is not None:
163
167
  if len(self.request_times) >= self.max_requests_per_minute:
168
+ limit = self.max_requests_per_minute
164
169
  rpm = len(self.request_times)
170
+ wait_time = max(0, 60 - (now - self.request_times[0]))
165
171
  throttle_notif_callback(
166
- f"Max request per minute exceeded: {rpm} of {self.max_requests_per_minute}"
172
+ f"Max request per minute exceeded: {rpm} of {limit}. "
173
+ f"Waiting for {wait_time:.2f} seconds.",
174
+ new_line=callback_new_line,
167
175
  )
168
176
  else:
169
- tpm = sum(t for _, t in self.token_times) + tokens
177
+ limit = self.max_tokens_per_minute
178
+ current_tokens = sum(t for _, t in self.token_times)
179
+ tpm = current_tokens + new_requested_tokens
180
+ needed = tpm - self.max_tokens_per_minute
181
+ freed = 0
182
+ wait_time = 0
183
+ for t_time, t_count in self.token_times:
184
+ freed += t_count
185
+ if freed >= needed:
186
+ wait_time = max(0, 60 - (now - t_time))
187
+ break
170
188
  throttle_notif_callback(
171
- f"Max token per minute exceeded: {tpm} of {self.max_tokens_per_minute}"
189
+ f"Max token per minute exceeded: {tpm} of {limit}. "
190
+ f"Waiting for {wait_time:.2f} seconds.",
191
+ new_line=callback_new_line,
172
192
  )
193
+ callback_new_line = False
173
194
  await asyncio.sleep(self.throttle_sleep)
174
195
  now = time.time()
175
196
  while self.request_times and now - self.request_times[0] > 60:
176
197
  self.request_times.popleft()
177
198
  while self.token_times and now - self.token_times[0][0] > 60:
178
199
  self.token_times.popleft()
200
+ if ever_throttled and throttle_notif_callback is not None:
201
+ throttle_notif_callback("", new_line=True)
179
202
  # Record this request
180
203
  self.request_times.append(now)
181
- self.token_times.append((now, tokens))
204
+ self.token_times.append((now, new_requested_tokens))
182
205
 
183
206
  def _prompt_to_str(self, prompt: Any) -> str:
184
207
  try:
zrb/input/option_input.py CHANGED
@@ -1,7 +1,13 @@
1
+ from typing import TYPE_CHECKING
2
+
1
3
  from zrb.attr.type import StrAttr, StrListAttr
2
4
  from zrb.context.any_shared_context import AnySharedContext
3
5
  from zrb.input.base_input import BaseInput
4
6
  from zrb.util.attr import get_str_list_attr
7
+ from zrb.util.match import fuzzy_match
8
+
9
+ if TYPE_CHECKING:
10
+ from prompt_toolkit.completion import Completer
5
11
 
6
12
 
7
13
  class OptionInput(BaseInput):
@@ -58,10 +64,32 @@ class OptionInput(BaseInput):
58
64
  self, shared_ctx: AnySharedContext, prompt_message: str, options: list[str]
59
65
  ) -> str:
60
66
  from prompt_toolkit import PromptSession
61
- from prompt_toolkit.completion import WordCompleter
62
67
 
63
68
  if shared_ctx.is_tty:
64
69
  reader = PromptSession()
65
- option_completer = WordCompleter(options, ignore_case=True)
70
+ option_completer = self._get_option_completer(options)
66
71
  return reader.prompt(f"{prompt_message}: ", completer=option_completer)
67
72
  return input(f"{prompt_message}: ")
73
+
74
+ def _get_option_completer(self, options: list[str]) -> "Completer":
75
+ from prompt_toolkit.completion import CompleteEvent, Completer, Completion
76
+ from prompt_toolkit.document import Document
77
+
78
+ class OptionCompleter(Completer):
79
+ def __init__(self, options: list[str]):
80
+ self._options = options
81
+
82
+ def get_completions(
83
+ self, document: Document, complete_event: CompleteEvent
84
+ ):
85
+ search_pattern = document.get_word_before_cursor(WORD=True)
86
+ candidates = []
87
+ for option in self._options:
88
+ matched, score = fuzzy_match(option, search_pattern)
89
+ if matched:
90
+ candidates.append((score, option))
91
+ candidates.sort(key=lambda x: (x[0], x[1]))
92
+ for _, option in candidates:
93
+ yield Completion(option, start_position=-len(search_pattern))
94
+
95
+ return OptionCompleter(options)
zrb/task/cmd_task.py CHANGED
@@ -48,6 +48,7 @@ class CmdTask(BaseTask):
48
48
  warn_unrecommended_command: bool | None = None,
49
49
  max_output_line: int = 1000,
50
50
  max_error_line: int = 1000,
51
+ execution_timeout: int = 3600,
51
52
  is_interactive: bool = False,
52
53
  execute_condition: BoolAttr = True,
53
54
  retries: int = 2,
@@ -103,6 +104,7 @@ class CmdTask(BaseTask):
103
104
  self._render_cwd = render_cwd
104
105
  self._max_output_line = max_output_line
105
106
  self._max_error_line = max_error_line
107
+ self._execution_timeout = execution_timeout
106
108
  self._should_plain_print = plain_print
107
109
  self._should_warn_unrecommended_command = warn_unrecommended_command
108
110
  self._is_interactive = is_interactive
@@ -142,6 +144,7 @@ class CmdTask(BaseTask):
142
144
  register_pid_method=lambda pid: ctx.xcom.get(xcom_pid_key).push(pid),
143
145
  max_output_line=self._max_output_line,
144
146
  max_error_line=self._max_error_line,
147
+ timeout=self._execution_timeout,
145
148
  is_interactive=self._is_interactive,
146
149
  )
147
150
  # Check for errors
@@ -112,8 +112,12 @@ async def _run_single_agent_iteration(
112
112
 
113
113
 
114
114
  def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[str], None]:
115
- def _print_throttle_notif(reason: str):
116
- ctx.print(stylize_faint(f" ⌛>> Request Throttled: {reason}"), plain=True)
115
+ def _print_throttle_notif(text: str, *args: Any, **kwargs: Any):
116
+ new_line = kwargs.get("new_line", True)
117
+ prefix = "\r" if not new_line else "\n"
118
+ if text != "":
119
+ prefix = f"{prefix} 🐢 Request Throttled: "
120
+ ctx.print(stylize_faint(f"{prefix}{text}"), plain=True, end="")
117
121
 
118
122
  return _print_throttle_notif
119
123
 
@@ -0,0 +1,13 @@
1
+ from typing import Any
2
+
3
+
4
+ def remove_system_prompt_and_instruction(
5
+ history_list: list[dict[str, Any]],
6
+ ) -> list[dict[str, Any]]:
7
+ for msg in history_list:
8
+ if msg.get("instructions"):
9
+ msg["instructions"] = ""
10
+ msg["parts"] = [
11
+ p for p in msg.get("parts", []) if p.get("part_kind") != "system-prompt"
12
+ ]
13
+ return history_list
@@ -88,30 +88,21 @@ def create_summarize_history_processor(
88
88
  messages: list[ModelMessage],
89
89
  ) -> list[ModelMessage]:
90
90
  history_list = json.loads(ModelMessagesTypeAdapter.dump_json(messages))
91
- history_json_str = json.dumps(history_list)
91
+ history_list_str = json.dumps(history_list)
92
92
  # Estimate token usage
93
93
  # Note: Pydantic ai has run context parameter
94
94
  # (https://ai.pydantic.dev/message-history/#runcontext-parameter)
95
95
  # But we cannot use run_ctx.usage.total_tokens because total token keep increasing
96
96
  # even after summariztion.
97
- estimated_token_usage = rate_limitter.count_token(history_json_str)
97
+ estimated_token_usage = rate_limitter.count_token(history_list_str)
98
98
  _print_request_info(
99
99
  ctx, estimated_token_usage, summarization_token_threshold, messages
100
100
  )
101
101
  if estimated_token_usage < summarization_token_threshold or len(messages) == 1:
102
102
  return messages
103
- history_list_without_instruction = [
104
- {
105
- key: obj[key]
106
- for key in obj
107
- if index == len(history_list) - 1 or key != "instructions"
108
- }
109
- for index, obj in enumerate(history_list)
110
- ]
111
- history_json_str_without_instruction = json.dumps(
112
- history_list_without_instruction
103
+ summarization_message = (
104
+ f"Summarize the following conversation: {history_list_str}"
113
105
  )
114
- summarization_message = f"Summarize the following conversation: {history_json_str_without_instruction}"
115
106
  summarization_agent = Agent[None, ConversationSummary](
116
107
  model=summarization_model,
117
108
  output_type=save_conversation_summary,
@@ -24,15 +24,19 @@ async def print_node(
24
24
  )
25
25
 
26
26
  meta = getattr(node, "id", None) or getattr(node, "request_id", None)
27
+ progress_char_list = ["|", "/", "-", "\\"]
28
+ progress_index = 0
27
29
  if Agent.is_user_prompt_node(node):
28
30
  print_func(_format_header("🔠 Receiving input...", log_indent_level))
29
- elif Agent.is_model_request_node(node):
31
+ return
32
+ if Agent.is_model_request_node(node):
30
33
  # A model request node => We can stream tokens from the model's request
31
34
  print_func(_format_header("🧠 Processing...", log_indent_level))
32
35
  # Reference: https://ai.pydantic.dev/agents/#streaming-all-events-and-output
33
36
  try:
34
37
  async with node.stream(agent_run.ctx) as request_stream:
35
38
  is_streaming = False
39
+ is_tool_processing = False
36
40
  async for event in request_stream:
37
41
  if isinstance(event, PartStartEvent) and event.part:
38
42
  if is_streaming:
@@ -40,29 +44,59 @@ async def print_node(
40
44
  content = _get_event_part_content(event)
41
45
  print_func(_format_content(content, log_indent_level), end="")
42
46
  is_streaming = True
43
- elif isinstance(event, PartDeltaEvent):
47
+ is_tool_processing = False
48
+ continue
49
+ if isinstance(event, PartDeltaEvent):
44
50
  if isinstance(event.delta, TextPartDelta):
45
51
  content_delta = event.delta.content_delta
46
52
  print_func(
47
53
  _format_stream_content(content_delta, log_indent_level),
48
54
  end="",
49
55
  )
50
- elif isinstance(event.delta, ThinkingPartDelta):
56
+ is_tool_processing = False
57
+ is_streaming = True
58
+ continue
59
+ if isinstance(event.delta, ThinkingPartDelta):
51
60
  content_delta = event.delta.content_delta
52
61
  print_func(
53
62
  _format_stream_content(content_delta, log_indent_level),
54
63
  end="",
55
64
  )
56
- elif isinstance(event.delta, ToolCallPartDelta):
57
- args_delta = event.delta.args_delta
58
- if isinstance(args_delta, dict):
59
- args_delta = json.dumps(args_delta)
65
+ is_tool_processing = False
66
+ is_streaming = True
67
+ continue
68
+ if isinstance(event.delta, ToolCallPartDelta):
69
+ if CFG.LLM_SHOW_TOOL_CALL_PREPARATION:
70
+ args_delta = event.delta.args_delta
71
+ if isinstance(args_delta, dict):
72
+ args_delta = json.dumps(args_delta)
73
+ print_func(
74
+ _format_stream_content(
75
+ args_delta, log_indent_level
76
+ ),
77
+ end="",
78
+ )
79
+ is_streaming = True
80
+ is_tool_processing = True
81
+ continue
82
+ prefix = "\n" if not is_tool_processing else ""
83
+ progress_char = progress_char_list[progress_index]
60
84
  print_func(
61
- _format_stream_content(args_delta, log_indent_level),
85
+ _format_content(
86
+ f"Preparing Tool Parameters... {progress_char}",
87
+ log_indent_level,
88
+ prefix=f"\r{prefix}",
89
+ ),
62
90
  end="",
63
91
  )
92
+ progress_index += 1
93
+ if progress_index >= len(progress_char_list):
94
+ progress_index = 0
95
+ is_tool_processing = True
96
+ is_streaming = True
97
+ continue
64
98
  is_streaming = True
65
- elif isinstance(event, FinalResultEvent) and event.tool_name:
99
+ if isinstance(event, FinalResultEvent) and event.tool_name:
66
100
  if is_streaming:
67
101
  print_func("")
68
102
  tool_name = event.tool_name
@@ -72,6 +106,7 @@ async def print_node(
72
106
  )
73
107
  )
74
108
  is_streaming = False
109
+ is_tool_processing = False
75
110
  if is_streaming:
76
111
  print_func("")
77
112
  except UnexpectedModelBehavior as e:
@@ -85,7 +120,8 @@ async def print_node(
85
120
  log_indent_level,
86
121
  )
87
122
  )
88
- elif Agent.is_call_tools_node(node):
123
+ return
124
+ if Agent.is_call_tools_node(node):
89
125
  # A handle-response node => The model returned some data, potentially calls a tool
90
126
  print_func(_format_header("🧰 Calling Tool...", log_indent_level))
91
127
  try:
@@ -100,7 +136,8 @@ async def print_node(
100
136
  f"{call_id} | Call {tool_name} {args}", log_indent_level
101
137
  )
102
138
  )
103
- elif (
139
+ continue
140
+ if (
104
141
  isinstance(event, FunctionToolResultEvent)
105
142
  and event.tool_call_id
106
143
  ):
@@ -113,12 +150,10 @@ async def print_node(
113
150
  log_indent_level,
114
151
  )
115
152
  )
116
- else:
117
- print_func(
118
- _format_content(
119
- f"{call_id} | Executed", log_indent_level
120
- )
121
- )
153
+ continue
154
+ print_func(
155
+ _format_content(f"{call_id} | Executed", log_indent_level)
156
+ )
122
157
  except UnexpectedModelBehavior as e:
123
158
  print_func("") # ensure newline consistency
124
159
  print_func(
@@ -130,9 +165,11 @@ async def print_node(
130
165
  log_indent_level,
131
166
  )
132
167
  )
133
- elif Agent.is_end_node(node):
168
+ return
169
+ if Agent.is_end_node(node):
134
170
  # Once an End node is reached, the agent run is complete
135
171
  print_func(_format_header("✅ Completed...", log_indent_level))
172
+ return
136
173
 
137
174
 
138
175
  def _format_header(text: str | None, log_indent_level: int = 0) -> str:
@@ -145,8 +182,10 @@ def _format_header(text: str | None, log_indent_level: int = 0) -> str:
145
182
  )
146
183
 
147
184
 
148
- def _format_content(text: str | None, log_indent_level: int = 0) -> str:
149
- return _format(
185
+ def _format_content(
186
+ text: str | None, log_indent_level: int = 0, prefix: str = ""
187
+ ) -> str:
188
+ return prefix + _format(
150
189
  text,
151
190
  base_indent=2,
152
191
  first_indent=3,
@@ -155,8 +194,10 @@ def _format_content(text: str | None, log_indent_level: int = 0) -> str:
155
194
  )
156
195
 
157
196
 
158
- def _format_stream_content(text: str | None, log_indent_level: int = 0) -> str:
159
- return _format(
197
+ def _format_stream_content(
198
+ text: str | None, log_indent_level: int = 0, prefix: str = ""
199
+ ) -> str:
200
+ return prefix + _format(
160
201
  text,
161
202
  base_indent=2,
162
203
  indent=3,
@@ -207,7 +248,7 @@ def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
207
248
  return {key: _truncate_arg(val) for key, val in kwargs.items()}
208
249
 
209
250
 
210
- def _truncate_arg(arg: str, length: int = 19) -> str:
251
+ def _truncate_arg(arg: str, length: int = 30) -> str:
211
252
  if isinstance(arg, str) and len(arg) > length:
212
253
  return f"{arg[:length-4]} ..."
213
254
  return arg
@@ -136,9 +136,12 @@ def _create_wrapper(
136
136
  result = await run_async(func(*args, **kwargs))
137
137
  _check_tool_call_result_limit(result)
138
138
  if has_ever_edited:
139
+ serializable_kwargs = kwargs.copy()
140
+ if any_context_param_name is not None:
141
+ serializable_kwargs.pop(any_context_param_name, None)
139
142
  return {
140
143
  "tool_call_result": result,
141
- "new_tool_parameters": kwargs,
144
+ "new_tool_parameters": serializable_kwargs,
142
145
  "message": "User correction: Tool was called with user's parameters",
143
146
  }
144
147
  return result
zrb/task/llm/workflow.py CHANGED
@@ -44,26 +44,40 @@ def get_available_workflows() -> dict[str, LLMWorkflow]:
44
44
  workflow_name.strip().lower(): workflow
45
45
  for workflow_name, workflow in llm_context_config.get_workflows().items()
46
46
  }
47
- # Define builtin workflow locations in order of precedence
48
- builtin_workflow_locations = [
49
- os.path.expanduser(additional_builtin_workflow_path)
50
- for additional_builtin_workflow_path in CFG.LLM_BUILTIN_WORKFLOW_PATHS
51
- if os.path.isdir(os.path.expanduser(additional_builtin_workflow_path))
52
- ]
53
- builtin_workflow_locations.append(
54
- os.path.join(os.path.dirname(__file__), "default_workflow")
47
+ workflow_hidden_folder = f".{CFG.ROOT_GROUP_NAME}"
48
+ # Define workflow locations in order of precedence
49
+ default_workflow_locations = (
50
+ [
51
+ # Project specific + user specific workflows
52
+ os.path.join(
53
+ os.path.dirname(__file__), workflow_hidden_folder, "workflows"
54
+ ),
55
+ os.path.join(os.path.dirname(__file__), workflow_hidden_folder, "skills"),
56
+ os.path.join(os.path.dirname(__file__), ".claude", "skills"),
57
+ os.path.join(os.path.expanduser("~"), workflow_hidden_folder, "workflows"),
58
+ os.path.join(os.path.expanduser("~"), workflow_hidden_folder, "skills"),
59
+ os.path.join(os.path.expanduser("~"), ".claude", "skills"),
60
+ ]
61
+ + [
62
+ # User defined builtin workflows
63
+ os.path.expanduser(additional_builtin_workflow_path)
64
+ for additional_builtin_workflow_path in CFG.LLM_BUILTIN_WORKFLOW_PATHS
65
+ if os.path.isdir(os.path.expanduser(additional_builtin_workflow_path))
66
+ ]
67
+ + [
68
+ # Zrb builtin workflows
69
+ os.path.join(os.path.dirname(__file__), "default_workflow"),
70
+ ]
55
71
  )
56
72
  # Load workflows from all locations
57
- for workflow_location in builtin_workflow_locations:
73
+ for workflow_location in default_workflow_locations:
58
74
  if not os.path.isdir(workflow_location):
59
75
  continue
60
76
  for workflow_name in os.listdir(workflow_location):
61
77
  workflow_dir = os.path.join(workflow_location, workflow_name)
62
- workflow_file = os.path.join(workflow_dir, "workflow.md")
63
- if not os.path.isfile(workflow_file):
64
- workflow_file = os.path.join(workflow_dir, "SKILL.md")
65
- if not os.path.isfile(path=workflow_file):
66
- continue
78
+ workflow_file = _get_workflow_file_name(workflow_dir)
79
+ if not workflow_file:
80
+ continue
67
81
  # Only add if not already defined (earlier locations have precedence)
68
82
  if workflow_name not in available_workflows:
69
83
  with open(workflow_file, "r") as f:
@@ -74,3 +88,16 @@ def get_available_workflows() -> dict[str, LLMWorkflow]:
74
88
  content=workflow_content,
75
89
  )
76
90
  return available_workflows
91
+
92
+
93
+ def _get_workflow_file_name(workflow_dir: str) -> str | None:
94
+ workflow_file = os.path.join(workflow_dir, "workflow.md")
95
+ if os.path.isfile(workflow_file):
96
+ return workflow_file
97
+ workflow_file = os.path.join(workflow_dir, "WORKFLOW.md")
98
+ if os.path.isfile(workflow_file):
99
+ return workflow_file
100
+ workflow_file = os.path.join(workflow_dir, "SKILL.md")
101
+ if os.path.isfile(workflow_file):
102
+ return workflow_file
103
+ return None
zrb/task/llm_task.py CHANGED
@@ -22,6 +22,7 @@ from zrb.task.llm.conversation_history import (
22
22
  write_conversation_history,
23
23
  )
24
24
  from zrb.task.llm.conversation_history_model import ConversationHistory
25
+ from zrb.task.llm.history_list import remove_system_prompt_and_instruction
25
26
  from zrb.task.llm.history_summarization import get_history_summarization_token_threshold
26
27
  from zrb.task.llm.prompt import (
27
28
  get_attachments,
@@ -134,9 +135,6 @@ class LLMTask(BaseTask):
134
135
  upstream: list[AnyTask] | AnyTask | None = None,
135
136
  fallback: list[AnyTask] | AnyTask | None = None,
136
137
  successor: list[AnyTask] | AnyTask | None = None,
137
- conversation_context: (
138
- dict[str, Any] | Callable[[AnyContext], dict[str, Any]] | None
139
- ) = None,
140
138
  ):
141
139
  super().__init__(
142
140
  name=name,
@@ -204,7 +202,6 @@ class LLMTask(BaseTask):
204
202
  render_history_summarization_token_threshold
205
203
  )
206
204
  self._max_call_iteration = max_call_iteration
207
- self._conversation_context = conversation_context
208
205
  self._yolo_mode = yolo_mode
209
206
  self._render_yolo_mode = render_yolo_mode
210
207
  self._attachment = attachment
@@ -362,7 +359,9 @@ class LLMTask(BaseTask):
362
359
  rate_limitter=self._rate_limitter,
363
360
  )
364
361
  if agent_run and agent_run.result:
365
- new_history_list = json.loads(agent_run.result.all_messages_json())
362
+ new_history_list = remove_system_prompt_and_instruction(
363
+ json.loads(agent_run.result.all_messages_json())
364
+ )
366
365
  conversation_history.history = new_history_list
367
366
  xcom_usage_key = f"{self.name}-usage"
368
367
  if xcom_usage_key not in ctx.xcom:
zrb/task/rsync_task.py CHANGED
@@ -44,6 +44,7 @@ class RsyncTask(CmdTask):
44
44
  plain_print: bool = False,
45
45
  max_output_line: int = 1000,
46
46
  max_error_line: int = 1000,
47
+ execution_timeout: int = 3600,
47
48
  execute_condition: BoolAttr = True,
48
49
  retries: int = 2,
49
50
  retry_period: float = 0,
@@ -77,6 +78,7 @@ class RsyncTask(CmdTask):
77
78
  plain_print=plain_print,
78
79
  max_output_line=max_output_line,
79
80
  max_error_line=max_error_line,
81
+ execution_timeout=execution_timeout,
80
82
  execute_condition=execute_condition,
81
83
  retries=retries,
82
84
  retry_period=retry_period,
zrb/util/cmd/command.py CHANGED
@@ -5,7 +5,7 @@ import signal
5
5
  import sys
6
6
  from collections import deque
7
7
  from collections.abc import Callable
8
- from typing import TextIO
8
+ from typing import Any, TextIO
9
9
 
10
10
  import psutil
11
11
 
@@ -62,6 +62,8 @@ async def run_command(
62
62
  register_pid_method: Callable[[int], None] | None = None,
63
63
  max_output_line: int = 1000,
64
64
  max_error_line: int = 1000,
65
+ max_display_line: int | None = None,
66
+ timeout: int = 3600,
65
67
  is_interactive: bool = False,
66
68
  ) -> tuple[CmdResult, int]:
67
69
  """
@@ -77,6 +79,8 @@ async def run_command(
77
79
  actual_print_method = print_method if print_method is not None else print
78
80
  if cwd is None:
79
81
  cwd = os.getcwd()
82
+ if max_display_line is None:
83
+ max_display_line = max(max_output_line, max_error_line)
80
84
  # While environment variables alone weren't the fix, they are still
81
85
  # good practice for encouraging simpler output from tools.
82
86
  child_env = (env_map or os.environ).copy()
@@ -95,17 +99,33 @@ async def run_command(
95
99
  if register_pid_method is not None:
96
100
  register_pid_method(cmd_process.pid)
97
101
  # Use the new, simple, and correct stream reader.
102
+ display_lines = deque(maxlen=max_display_line if max_display_line > 0 else 0)
98
103
  stdout_task = asyncio.create_task(
99
- __read_stream(cmd_process.stdout, actual_print_method, max_output_line)
104
+ __read_stream(
105
+ cmd_process.stdout, actual_print_method, max_output_line, display_lines
106
+ )
100
107
  )
101
108
  stderr_task = asyncio.create_task(
102
- __read_stream(cmd_process.stderr, actual_print_method, max_error_line)
109
+ __read_stream(
110
+ cmd_process.stderr, actual_print_method, max_error_line, display_lines
111
+ )
112
+ )
113
+ timeout_task = (
114
+ asyncio.create_task(asyncio.sleep(timeout)) if timeout and timeout > 0 else None
103
115
  )
104
116
  try:
105
- return_code = await cmd_process.wait()
117
+ wait_task = asyncio.create_task(cmd_process.wait())
118
+ done, pending = await asyncio.wait(
119
+ {wait_task, timeout_task} if timeout_task else {wait_task},
120
+ return_when=asyncio.FIRST_COMPLETED,
121
+ )
122
+ if timeout_task and timeout_task in done:
123
+ raise asyncio.TimeoutError()
124
+ return_code = wait_task.result()
106
125
  stdout, stderr = await asyncio.gather(stdout_task, stderr_task)
107
- return CmdResult(stdout, stderr), return_code
108
- except (KeyboardInterrupt, asyncio.CancelledError):
126
+ display = "\r\n".join(display_lines)
127
+ return CmdResult(stdout, stderr, display=display), return_code
128
+ except (KeyboardInterrupt, asyncio.CancelledError, asyncio.TimeoutError):
109
129
  try:
110
130
  os.killpg(cmd_process.pid, signal.SIGINT)
111
131
  await asyncio.wait_for(cmd_process.wait(), timeout=2.0)
@@ -133,13 +153,14 @@ def __get_cmd_stdin(is_interactive: bool) -> int | TextIO:
133
153
  async def __read_stream(
134
154
  stream: asyncio.StreamReader,
135
155
  print_method: Callable[..., None],
136
- max_lines: int,
156
+ max_line: int,
157
+ display_queue: deque[Any],
137
158
  ) -> str:
138
159
  """
139
160
  Reads from the stream using the robust `readline()` and correctly
140
161
  interprets carriage returns (`\r`) as distinct print events.
141
162
  """
142
- captured_lines = deque(maxlen=max_lines if max_lines > 0 else 0)
163
+ captured_lines = deque(maxlen=max_line if max_line > 0 else 0)
143
164
  while True:
144
165
  try:
145
166
  line_bytes = await stream.readline()
@@ -149,8 +170,9 @@ async def __read_stream(
149
170
  # Safety valve for the memory limit.
150
171
  error_msg = "[ERROR] A single line of output was too long to process."
151
172
  print_method(error_msg)
152
- if max_lines > 0:
173
+ if max_line > 0:
153
174
  captured_lines.append(error_msg)
175
+ display_queue.append(error_msg)
154
176
  break
155
177
  except (KeyboardInterrupt, asyncio.CancelledError):
156
178
  raise
@@ -165,8 +187,9 @@ async def __read_stream(
165
187
  print_method(clean_part, end="\r\n")
166
188
  except Exception:
167
189
  print_method(clean_part)
168
- if max_lines > 0:
190
+ if max_line > 0:
169
191
  captured_lines.append(clean_part)
192
+ display_queue.append(clean_part)
170
193
  return "\r\n".join(captured_lines)
171
194
 
172
195