zrb-1.21.37-py3-none-any.whl → zrb-1.21.43-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic. See the original diff page for more details.

@@ -17,13 +17,10 @@ def read_chat_conversation(ctx: AnyContext) -> dict[str, Any] | list | None:
17
17
  return None # Indicate no history to load
18
18
  previous_session_name = ctx.input.previous_session
19
19
  if not previous_session_name: # Check for empty string or None
20
- last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
21
- if os.path.isfile(last_session_file_path):
22
- previous_session_name = read_file(last_session_file_path).strip()
23
- if not previous_session_name: # Handle empty last-session file
24
- return None
25
- else:
26
- return None # No previous session specified and no last session found
20
+ last_session_name = get_last_session_name()
21
+ if last_session_name is None:
22
+ return None
23
+ previous_session_name = last_session_name
27
24
  conversation_file_path = os.path.join(
28
25
  CFG.LLM_HISTORY_DIR, f"{previous_session_name}.json"
29
26
  )
@@ -51,6 +48,16 @@ def read_chat_conversation(ctx: AnyContext) -> dict[str, Any] | list | None:
51
48
  return None
52
49
 
53
50
 
51
+ def get_last_session_name() -> str | None:
52
+ last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
53
+ if not os.path.isfile(last_session_file_path):
54
+ return None
55
+ last_session_name = read_file(last_session_file_path).strip()
56
+ if not last_session_name: # Handle empty last-session file
57
+ return None
58
+ return last_session_name
59
+
60
+
54
61
  def write_chat_conversation(ctx: AnyContext, history_data: ConversationHistory):
55
62
  """Writes the conversation history data (including context) to a session file."""
56
63
  os.makedirs(CFG.LLM_HISTORY_DIR, exist_ok=True)
@@ -31,6 +31,11 @@ from zrb.builtin.llm.tool.web import (
31
31
  create_search_internet_tool,
32
32
  open_web_page,
33
33
  )
34
+ from zrb.builtin.llm.xcom_names import (
35
+ LLM_ASK_ERROR_XCOM_NAME,
36
+ LLM_ASK_RESULT_XCOM_NAME,
37
+ LLM_ASK_SESSION_XCOM_NAME,
38
+ )
34
39
  from zrb.callback.callback import Callback
35
40
  from zrb.config.config import CFG
36
41
  from zrb.config.llm_config import llm_config
@@ -40,6 +45,7 @@ from zrb.input.bool_input import BoolInput
40
45
  from zrb.input.str_input import StrInput
41
46
  from zrb.input.text_input import TextInput
42
47
  from zrb.task.base_trigger import BaseTrigger
48
+ from zrb.task.llm.workflow import LLM_LOADED_WORKFLOW_XCOM_NAME
43
49
  from zrb.task.llm_task import LLMTask
44
50
  from zrb.util.string.conversion import to_boolean
45
51
 
@@ -99,6 +105,8 @@ def _get_default_yolo_mode(ctx: AnyContext) -> str:
99
105
 
100
106
 
101
107
  def _render_yolo_mode_input(ctx: AnyContext) -> list[str] | bool:
108
+ if isinstance(ctx.input.yolo, bool):
109
+ return ctx.input.yolo
102
110
  if ctx.input.yolo.strip() == "":
103
111
  return []
104
112
  elements = [element.strip() for element in ctx.input.yolo.split(",")]
@@ -172,9 +180,9 @@ def _get_inputs(require_message: bool = True) -> list[AnyInput | None]:
172
180
  always_prompt=False,
173
181
  ),
174
182
  TextInput(
175
- "workflows",
176
- description="Workflows",
177
- prompt="Workflows",
183
+ "workflow",
184
+ description="Workflows (comma separated)",
185
+ prompt="Workflows (comma separated)",
178
186
  default=lambda ctx: ",".join(llm_config.default_workflows),
179
187
  allow_positional_parsing=False,
180
188
  always_prompt=False,
@@ -237,7 +245,7 @@ llm_ask = LLMTask(
237
245
  None if ctx.input.system_prompt.strip() == "" else ctx.input.system_prompt
238
246
  ),
239
247
  workflows=lambda ctx: (
240
- None if ctx.input.workflows.strip() == "" else ctx.input.workflows.split(",")
248
+ None if ctx.input.workflow.strip() == "" else ctx.input.workflow.split(",")
241
249
  ),
242
250
  attachment=_render_attach_input,
243
251
  message="{ctx.input.message}",
@@ -258,9 +266,10 @@ llm_group.add_task(
258
266
  callback=Callback(
259
267
  task=llm_ask,
260
268
  input_mapping=get_llm_ask_input_mapping,
261
- result_queue="ask_result",
262
- error_queue="ask_error",
263
- session_name_queue="ask_session_name",
269
+ xcom_mapping={LLM_LOADED_WORKFLOW_XCOM_NAME: LLM_LOADED_WORKFLOW_XCOM_NAME},
270
+ result_queue=LLM_ASK_RESULT_XCOM_NAME,
271
+ error_queue=LLM_ASK_ERROR_XCOM_NAME,
272
+ session_name_queue=LLM_ASK_SESSION_XCOM_NAME,
264
273
  ),
265
274
  retries=0,
266
275
  cli_only=True,
@@ -274,8 +274,9 @@ def write_to_file(
274
274
  - CORRECT: "content": "He said \"Hello\""
275
275
  - WRONG: "content": "He said \\"Hello\\"" <-- This breaks JSON parsing!
276
276
  2. **SIZE LIMIT:** Content MUST NOT exceed 4000 characters.
277
- - Exceeding this causes truncation and EOF errors.
278
- - Split larger content into multiple sequential calls (first 'w', then 'a').
277
+ - **STRICT PROHIBITION:** You are FORBIDDEN from writing more than 4000 characters in a single call.
278
+ - This is due to LLM output token limits, which will cause truncation and failure.
279
+ - To write larger files, you MUST split the content into multiple sequential calls (e.g., first 'w', then 'a').
279
280
 
280
281
  Examples:
281
282
  ```
@@ -1,7 +1,5 @@
1
1
  from typing import Any
2
2
 
3
- import requests
4
-
5
3
  from zrb.config.config import CFG
6
4
 
7
5
 
@@ -36,6 +34,8 @@ def search_internet(
36
34
  Returns:
37
35
  dict: Summary of search results (titles, links, snippets).
38
36
  """
37
+ import requests
38
+
39
39
  if safe_search is None:
40
40
  safe_search = CFG.BRAVE_API_SAFE
41
41
  if language is None:
@@ -1,7 +1,5 @@
1
1
  from typing import Any
2
2
 
3
- import requests
4
-
5
3
  from zrb.config.config import CFG
6
4
 
7
5
 
@@ -36,6 +34,8 @@ def search_internet(
36
34
  Returns:
37
35
  dict: Summary of search results (titles, links, snippets).
38
36
  """
37
+ import requests
38
+
39
39
  if safe_search is None:
40
40
  safe_search = CFG.SEARXNG_SAFE
41
41
  if language is None:
@@ -1,7 +1,5 @@
1
1
  from typing import Any
2
2
 
3
- import requests
4
-
5
3
  from zrb.config.config import CFG
6
4
 
7
5
 
@@ -36,6 +34,8 @@ def search_internet(
36
34
  Returns:
37
35
  dict: Summary of search results (titles, links, snippets).
38
36
  """
37
+ import requests
38
+
39
39
  if safe_search is None:
40
40
  safe_search = CFG.SERPAPI_SAFE
41
41
  if language is None:
@@ -0,0 +1,3 @@
1
+ LLM_ASK_RESULT_XCOM_NAME = "ask_result"
2
+ LLM_ASK_ERROR_XCOM_NAME = "ask_error"
3
+ LLM_ASK_SESSION_XCOM_NAME = "ask_session_name"
zrb/callback/callback.py CHANGED
@@ -6,7 +6,6 @@ from zrb.callback.any_callback import AnyCallback
6
6
  from zrb.session.any_session import AnySession
7
7
  from zrb.task.any_task import AnyTask
8
8
  from zrb.util.attr import get_str_dict_attr
9
- from zrb.util.cli.style import stylize_faint
10
9
  from zrb.util.string.conversion import to_snake_case
11
10
  from zrb.xcom.xcom import Xcom
12
11
 
@@ -24,6 +23,7 @@ class Callback(AnyCallback):
24
23
  task: AnyTask,
25
24
  input_mapping: StrDictAttr,
26
25
  render_input_mapping: bool = True,
26
+ xcom_mapping: dict[str, str] | None = None,
27
27
  result_queue: str | None = None,
28
28
  error_queue: str | None = None,
29
29
  session_name_queue: str | None = None,
@@ -36,6 +36,7 @@ class Callback(AnyCallback):
36
36
  input_mapping: A dictionary or attribute mapping to prepare inputs for the task.
37
37
  render_input_mapping: Whether to render the input mapping using
38
38
  f-string like syntax.
39
+ xcom_mapping: Map of parent session's xcom names to current session's xcom names
39
40
  result_queue: The name of the XCom queue in the parent session
40
41
  to publish the task result.
41
42
  result_queue: The name of the Xcom queue in the parent session
@@ -46,6 +47,7 @@ class Callback(AnyCallback):
46
47
  self._task = task
47
48
  self._input_mapping = input_mapping
48
49
  self._render_input_mapping = render_input_mapping
50
+ self._xcom_mapping = xcom_mapping
49
51
  self._result_queue = result_queue
50
52
  self._error_queue = error_queue
51
53
  self._session_name_queue = session_name_queue
@@ -63,6 +65,11 @@ class Callback(AnyCallback):
63
65
  for name, value in inputs.items():
64
66
  session.shared_ctx.input[name] = value
65
67
  session.shared_ctx.input[to_snake_case(name)] = value
68
+ # map xcom
69
+ if self._xcom_mapping is not None:
70
+ for parent_xcom_name, current_xcom_name in self._xcom_mapping.items():
71
+ parent_xcom = parent_session.shared_ctx.xcom[parent_xcom_name]
72
+ session.shared_ctx.xcom[current_xcom_name] = parent_xcom
66
73
  # run task and get result
67
74
  try:
68
75
  result = await self._task.async_run(session)
zrb/config/config.py CHANGED
@@ -373,7 +373,7 @@ class Config:
373
373
  """
374
374
  return int(
375
375
  self._getenv(
376
- ["LLM_MAX_TOKEN_PER_MINUTE", "LLM_MAX_TOKENS_PER_MINUTE"], "100000"
376
+ ["LLM_MAX_TOKEN_PER_MINUTE", "LLM_MAX_TOKENS_PER_MINUTE"], "120000"
377
377
  )
378
378
  )
379
379
 
zrb/context/context.py CHANGED
@@ -139,6 +139,17 @@ class Context(AnyContext):
139
139
  stylized_prefix = stylize(prefix, color=color)
140
140
  print(f"{stylized_prefix} {message}", sep=sep, end=end, file=file, flush=flush)
141
141
 
142
+ def print_err(
143
+ self,
144
+ *values: object,
145
+ sep: str | None = " ",
146
+ end: str | None = "\n",
147
+ file: TextIO | None = sys.stderr,
148
+ flush: bool = True,
149
+ plain: bool = False,
150
+ ):
151
+ self.print(*values, sep=sep, end=end, file=file, flush=flush, plain=plain)
152
+
142
153
  def log_debug(
143
154
  self,
144
155
  *values: object,
zrb/task/base/context.py CHANGED
@@ -79,24 +79,36 @@ def combine_inputs(
79
79
  input_names.append(task_input.name) # Update names list
80
80
 
81
81
 
82
+ def combine_envs(
83
+ existing_envs: list[AnyEnv],
84
+ new_envs: list[AnyEnv | None] | AnyEnv | None,
85
+ ):
86
+ """
87
+ Combines new envs into an existing list.
88
+ Modifies the existing_envs list in place.
89
+ """
90
+ if isinstance(new_envs, AnyEnv):
91
+ existing_envs.append(new_envs)
92
+ elif new_envs is None:
93
+ pass
94
+ else:
95
+ # new_envs is a list
96
+ for env in new_envs:
97
+ if env is not None:
98
+ existing_envs.append(env)
99
+
100
+
82
101
  def get_combined_envs(task: "BaseTask") -> list[AnyEnv]:
83
102
  """
84
103
  Aggregates environment variables from the task and its upstreams.
85
104
  """
86
- envs = []
105
+ envs: list[AnyEnv] = []
87
106
  for upstream in task.upstreams:
88
- envs.extend(upstream.envs) # Use extend for list concatenation
89
-
90
- # Access _envs directly as task is BaseTask
91
- task_envs: list[AnyEnv | None] | AnyEnv | None = task._envs
92
- if isinstance(task_envs, AnyEnv):
93
- envs.append(task_envs)
94
- elif isinstance(task_envs, list):
95
- # Filter out None while extending
96
- envs.extend(env for env in task_envs if env is not None)
97
-
98
- # Filter out None values efficiently from the combined list
99
- return [env for env in envs if env is not None]
107
+ combine_envs(envs, upstream.envs)
108
+
109
+ combine_envs(envs, task._envs)
110
+
111
+ return envs
100
112
 
101
113
 
102
114
  def get_combined_inputs(task: "BaseTask") -> list[AnyInput]:
@@ -88,56 +88,61 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
88
88
  run_async(execute_action_with_retry(task, session))
89
89
  )
90
90
 
91
- await asyncio.sleep(readiness_check_delay)
92
-
93
- readiness_check_coros = [
94
- run_async(check.exec_chain(session)) for check in readiness_checks
95
- ]
96
-
97
- # Wait primarily for readiness checks to complete
98
- ctx.log_info("Waiting for readiness checks")
99
- readiness_passed = False
100
91
  try:
101
- # Gather results, but primarily interested in completion/errors
102
- await asyncio.gather(*readiness_check_coros)
103
- # Check if all readiness tasks actually completed successfully
104
- all_readiness_completed = all(
105
- session.get_task_status(check).is_completed for check in readiness_checks
106
- )
107
- if all_readiness_completed:
108
- ctx.log_info("Readiness checks completed successfully")
109
- readiness_passed = True
110
- # Mark task as ready only if checks passed and action didn't fail during checks
111
- if not session.get_task_status(task).is_failed:
112
- ctx.log_info("Marked as ready")
113
- session.get_task_status(task).mark_as_ready()
114
- else:
115
- ctx.log_warning(
116
- "One or more readiness checks did not complete successfully."
117
- )
92
+ await asyncio.sleep(readiness_check_delay)
118
93
 
119
- except Exception as e:
120
- ctx.log_error(f"Readiness check failed with exception: {e}")
121
- # If readiness checks fail with an exception, the task is not ready.
122
- # The action_coro might still be running or have failed.
123
- # execute_action_with_retry handles marking the main task status.
124
-
125
- # Defer the main action coroutine; it will be awaited later if needed
126
- session.defer_action(task, action_coro)
127
-
128
- # Start monitoring only if readiness passed and monitoring is enabled
129
- if readiness_passed and monitor_readiness:
130
- # Import dynamically to avoid circular dependency if monitoring imports execution
131
- from zrb.task.base.monitoring import monitor_task_readiness
94
+ readiness_check_coros = [
95
+ run_async(check.exec_chain(session)) for check in readiness_checks
96
+ ]
132
97
 
133
- monitor_coro = asyncio.create_task(
134
- run_async(monitor_task_readiness(task, session, action_coro))
135
- )
136
- session.defer_monitoring(task, monitor_coro)
98
+ # Wait primarily for readiness checks to complete
99
+ ctx.log_info("Waiting for readiness checks")
100
+ readiness_passed = False
101
+ try:
102
+ # Gather results, but primarily interested in completion/errors
103
+ await asyncio.gather(*readiness_check_coros)
104
+ # Check if all readiness tasks actually completed successfully
105
+ all_readiness_completed = all(
106
+ session.get_task_status(check).is_completed
107
+ for check in readiness_checks
108
+ )
109
+ if all_readiness_completed:
110
+ ctx.log_info("Readiness checks completed successfully")
111
+ readiness_passed = True
112
+ # Mark task as ready only if checks passed and action didn't fail during checks
113
+ if not session.get_task_status(task).is_failed:
114
+ ctx.log_info("Marked as ready")
115
+ session.get_task_status(task).mark_as_ready()
116
+ else:
117
+ ctx.log_warning(
118
+ "One or more readiness checks did not complete successfully."
119
+ )
120
+
121
+ except Exception as e:
122
+ ctx.log_error(f"Readiness check failed with exception: {e}")
123
+ # If readiness checks fail with an exception, the task is not ready.
124
+ # The action_coro might still be running or have failed.
125
+ # execute_action_with_retry handles marking the main task status.
126
+
127
+ # Defer the main action coroutine; it will be awaited later if needed
128
+ session.defer_action(task, action_coro)
129
+
130
+ # Start monitoring only if readiness passed and monitoring is enabled
131
+ if readiness_passed and monitor_readiness:
132
+ # Import dynamically to avoid circular dependency if monitoring imports execution
133
+ from zrb.task.base.monitoring import monitor_task_readiness
134
+
135
+ monitor_coro = asyncio.create_task(
136
+ run_async(monitor_task_readiness(task, session, action_coro))
137
+ )
138
+ session.defer_monitoring(task, monitor_coro)
137
139
 
138
- # The result here is primarily about readiness check completion.
139
- # The actual task result is handled by the deferred action_coro.
140
- return None
140
+ # The result here is primarily about readiness check completion.
141
+ # The actual task result is handled by the deferred action_coro.
142
+ return None
143
+ except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
144
+ action_coro.cancel()
145
+ raise
141
146
 
142
147
 
143
148
  async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> Any:
@@ -178,7 +183,7 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
178
183
  await run_async(execute_successors(task, session))
179
184
  return result
180
185
 
181
- except (asyncio.CancelledError, KeyboardInterrupt):
186
+ except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
182
187
  ctx.log_warning("Task cancelled or interrupted")
183
188
  session.get_task_status(task).mark_as_failed() # Mark as failed on cancel
184
189
  # Do not trigger fallbacks/successors on cancellation
@@ -176,7 +176,7 @@ async def log_session_state(task: AnyTask, session: AnySession):
176
176
  try:
177
177
  while not session.is_terminated:
178
178
  session.state_logger.write(session.as_state_log())
179
- await asyncio.sleep(0.1) # Log interval
179
+ await asyncio.sleep(0) # Log interval
180
180
  # Log one final time after termination signal
181
181
  session.state_logger.write(session.as_state_log())
182
182
  except asyncio.CancelledError:
zrb/task/base_task.py CHANGED
@@ -3,7 +3,7 @@ import inspect
3
3
  from collections.abc import Callable
4
4
  from typing import Any
5
5
 
6
- from zrb.attr.type import BoolAttr, fstring
6
+ from zrb.attr.type import fstring
7
7
  from zrb.context.any_context import AnyContext
8
8
  from zrb.env.any_env import AnyEnv
9
9
  from zrb.input.any_input import AnyInput
@@ -55,7 +55,7 @@ class BaseTask(AnyTask):
55
55
  input: list[AnyInput | None] | AnyInput | None = None,
56
56
  env: list[AnyEnv | None] | AnyEnv | None = None,
57
57
  action: fstring | Callable[[AnyContext], Any] | None = None,
58
- execute_condition: BoolAttr = True,
58
+ execute_condition: bool | str | Callable[[AnyContext], bool] = True,
59
59
  retries: int = 2,
60
60
  retry_period: float = 0,
61
61
  readiness_check: list[AnyTask] | AnyTask | None = None,
@@ -68,9 +68,18 @@ class BaseTask(AnyTask):
68
68
  fallback: list[AnyTask] | AnyTask | None = None,
69
69
  successor: list[AnyTask] | AnyTask | None = None,
70
70
  ):
71
- caller_frame = inspect.stack()[1]
72
- self.__decl_file = caller_frame.filename
73
- self.__decl_line = caller_frame.lineno
71
+ # Optimized stack retrieval
72
+ frame = inspect.currentframe()
73
+ if frame is not None:
74
+ caller_frame = frame.f_back
75
+ self.__decl_file = (
76
+ caller_frame.f_code.co_filename if caller_frame else "unknown"
77
+ )
78
+ self.__decl_line = caller_frame.f_lineno if caller_frame else 0
79
+ else:
80
+ self.__decl_file = "unknown"
81
+ self.__decl_line = 0
82
+
74
83
  self._name = name
75
84
  self._color = color
76
85
  self._icon = icon
@@ -80,10 +89,10 @@ class BaseTask(AnyTask):
80
89
  self._envs = env
81
90
  self._retries = retries
82
91
  self._retry_period = retry_period
83
- self._upstreams = upstream
84
- self._fallbacks = fallback
85
- self._successors = successor
86
- self._readiness_checks = readiness_check
92
+ self._upstreams = self._ensure_task_list(upstream)
93
+ self._fallbacks = self._ensure_task_list(fallback)
94
+ self._successors = self._ensure_task_list(successor)
95
+ self._readiness_checks = self._ensure_task_list(readiness_check)
87
96
  self._readiness_check_delay = readiness_check_delay
88
97
  self._readiness_check_period = readiness_check_period
89
98
  self._readiness_failure_threshold = readiness_failure_threshold
@@ -92,6 +101,13 @@ class BaseTask(AnyTask):
92
101
  self._execute_condition = execute_condition
93
102
  self._action = action
94
103
 
104
+ def _ensure_task_list(self, tasks: AnyTask | list[AnyTask] | None) -> list[AnyTask]:
105
+ if tasks is None:
106
+ return []
107
+ if isinstance(tasks, list):
108
+ return tasks
109
+ return [tasks]
110
+
95
111
  def __repr__(self):
96
112
  return f"<{self.__class__.__name__} name={self.name}>"
97
113
 
@@ -132,18 +148,10 @@ class BaseTask(AnyTask):
132
148
  @property
133
149
  def fallbacks(self) -> list[AnyTask]:
134
150
  """Returns the list of fallback tasks."""
135
- if self._fallbacks is None:
136
- return []
137
- elif isinstance(self._fallbacks, list):
138
- return self._fallbacks
139
- return [self._fallbacks] # Assume single task
151
+ return self._fallbacks
140
152
 
141
153
  def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
142
154
  """Appends fallback tasks, ensuring no duplicates."""
143
- if self._fallbacks is None:
144
- self._fallbacks = []
145
- elif not isinstance(self._fallbacks, list):
146
- self._fallbacks = [self._fallbacks]
147
155
  to_add = fallbacks if isinstance(fallbacks, list) else [fallbacks]
148
156
  for fb in to_add:
149
157
  if fb not in self._fallbacks:
@@ -152,18 +160,10 @@ class BaseTask(AnyTask):
152
160
  @property
153
161
  def successors(self) -> list[AnyTask]:
154
162
  """Returns the list of successor tasks."""
155
- if self._successors is None:
156
- return []
157
- elif isinstance(self._successors, list):
158
- return self._successors
159
- return [self._successors] # Assume single task
163
+ return self._successors
160
164
 
161
165
  def append_successor(self, successors: AnyTask | list[AnyTask]):
162
166
  """Appends successor tasks, ensuring no duplicates."""
163
- if self._successors is None:
164
- self._successors = []
165
- elif not isinstance(self._successors, list):
166
- self._successors = [self._successors]
167
167
  to_add = successors if isinstance(successors, list) else [successors]
168
168
  for succ in to_add:
169
169
  if succ not in self._successors:
@@ -172,18 +172,10 @@ class BaseTask(AnyTask):
172
172
  @property
173
173
  def readiness_checks(self) -> list[AnyTask]:
174
174
  """Returns the list of readiness check tasks."""
175
- if self._readiness_checks is None:
176
- return []
177
- elif isinstance(self._readiness_checks, list):
178
- return self._readiness_checks
179
- return [self._readiness_checks] # Assume single task
175
+ return self._readiness_checks
180
176
 
181
177
  def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
182
178
  """Appends readiness check tasks, ensuring no duplicates."""
183
- if self._readiness_checks is None:
184
- self._readiness_checks = []
185
- elif not isinstance(self._readiness_checks, list):
186
- self._readiness_checks = [self._readiness_checks]
187
179
  to_add = (
188
180
  readiness_checks
189
181
  if isinstance(readiness_checks, list)
@@ -196,18 +188,10 @@ class BaseTask(AnyTask):
196
188
  @property
197
189
  def upstreams(self) -> list[AnyTask]:
198
190
  """Returns the list of upstream tasks."""
199
- if self._upstreams is None:
200
- return []
201
- elif isinstance(self._upstreams, list):
202
- return self._upstreams
203
- return [self._upstreams] # Assume single task
191
+ return self._upstreams
204
192
 
205
193
  def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
206
194
  """Appends upstream tasks, ensuring no duplicates."""
207
- if self._upstreams is None:
208
- self._upstreams = []
209
- elif not isinstance(self._upstreams, list):
210
- self._upstreams = [self._upstreams]
211
195
  to_add = upstreams if isinstance(upstreams, list) else [upstreams]
212
196
  for up in to_add:
213
197
  if up not in self._upstreams:
@@ -277,6 +261,8 @@ class BaseTask(AnyTask):
277
261
  try:
278
262
  # Delegate to the helper function for the default behavior
279
263
  return await run_default_action(self, ctx)
264
+ except (KeyboardInterrupt, GeneratorExit):
265
+ raise
280
266
  except BaseException as e:
281
267
  additional_error_note = (
282
268
  f"Task: {self.name} ({self.__decl_file}:{self.__decl_line})"
zrb/task/base_trigger.py CHANGED
@@ -5,7 +5,6 @@ from typing import Any
5
5
  from zrb.attr.type import fstring
6
6
  from zrb.callback.any_callback import AnyCallback
7
7
  from zrb.context.any_context import AnyContext
8
- from zrb.context.any_shared_context import AnySharedContext
9
8
  from zrb.context.shared_context import SharedContext
10
9
  from zrb.dot_dict.dot_dict import DotDict
11
10
  from zrb.env.any_env import AnyEnv
zrb/task/llm/agent.py CHANGED
@@ -39,39 +39,10 @@ def create_agent_instance(
39
39
  auto_summarize: bool = True,
40
40
  ) -> "Agent[None, Any]":
41
41
  """Creates a new Agent instance with configured tools and servers."""
42
- from pydantic_ai import Agent, RunContext, Tool
42
+ from pydantic_ai import Agent, Tool
43
43
  from pydantic_ai.tools import GenerateToolJsonSchema
44
- from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
45
44
 
46
- @dataclass
47
- class ConfirmationWrapperToolset(WrapperToolset):
48
- ctx: AnyContext
49
- yolo_mode: bool | list[str]
50
-
51
- async def call_tool(
52
- self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
53
- ) -> Any:
54
- # The `tool` object is passed in. Use it for inspection.
55
- # Define a temporary function that performs the actual tool call.
56
- async def execute_delegated_tool_call(**params):
57
- # Pass all arguments down the chain.
58
- return await self.wrapped.call_tool(name, tool_args, ctx, tool)
59
-
60
- # For the confirmation UI, make our temporary function look like the real one.
61
- try:
62
- execute_delegated_tool_call.__name__ = name
63
- execute_delegated_tool_call.__doc__ = tool.function.__doc__
64
- execute_delegated_tool_call.__signature__ = inspect.signature(
65
- tool.function
66
- )
67
- except (AttributeError, TypeError):
68
- pass # Ignore if we can't inspect the original function
69
- # Use the existing wrap_func to get the confirmation logic
70
- wrapped_executor = wrap_func(
71
- execute_delegated_tool_call, self.ctx, self.yolo_mode
72
- )
73
- # Call the wrapped executor. This will trigger the confirmation prompt.
74
- return await wrapped_executor(**tool_args)
45
+ ConfirmationWrapperToolset = _get_confirmation_wrapper_toolset_class()
75
46
 
76
47
  if yolo_mode is None:
77
48
  yolo_mode = False
@@ -132,6 +103,43 @@ def create_agent_instance(
132
103
  )
133
104
 
134
105
 
106
+ def _get_confirmation_wrapper_toolset_class():
107
+ from pydantic_ai import RunContext
108
+ from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
109
+
110
+ @dataclass
111
+ class ConfirmationWrapperToolset(WrapperToolset):
112
+ ctx: AnyContext
113
+ yolo_mode: bool | list[str]
114
+
115
+ async def call_tool(
116
+ self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
117
+ ) -> Any:
118
+ # The `tool` object is passed in. Use it for inspection.
119
+ # Define a temporary function that performs the actual tool call.
120
+ async def execute_delegated_tool_call(**params):
121
+ # Pass all arguments down the chain.
122
+ return await self.wrapped.call_tool(name, tool_args, ctx, tool)
123
+
124
+ # For the confirmation UI, make our temporary function look like the real one.
125
+ try:
126
+ execute_delegated_tool_call.__name__ = name
127
+ execute_delegated_tool_call.__doc__ = tool.function.__doc__
128
+ execute_delegated_tool_call.__signature__ = inspect.signature(
129
+ tool.function
130
+ )
131
+ except (AttributeError, TypeError):
132
+ pass # Ignore if we can't inspect the original function
133
+ # Use the existing wrap_func to get the confirmation logic
134
+ wrapped_executor = wrap_func(
135
+ execute_delegated_tool_call, self.ctx, self.yolo_mode
136
+ )
137
+ # Call the wrapped executor. This will trigger the confirmation prompt.
138
+ return await wrapped_executor(**tool_args)
139
+
140
+ return ConfirmationWrapperToolset
141
+
142
+
135
143
  def get_agent(
136
144
  ctx: AnyContext,
137
145
  model: "str | Model",