zrb 1.21.37__py3-none-any.whl → 1.21.43__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Potentially problematic release.
This version of zrb might be problematic.
- zrb/builtin/llm/chat_completion.py +46 -0
- zrb/builtin/llm/chat_session.py +89 -29
- zrb/builtin/llm/chat_session_cmd.py +87 -11
- zrb/builtin/llm/chat_trigger.py +92 -5
- zrb/builtin/llm/history.py +14 -7
- zrb/builtin/llm/llm_ask.py +16 -7
- zrb/builtin/llm/tool/file.py +3 -2
- zrb/builtin/llm/tool/search/brave.py +2 -2
- zrb/builtin/llm/tool/search/searxng.py +2 -2
- zrb/builtin/llm/tool/search/serpapi.py +2 -2
- zrb/builtin/llm/xcom_names.py +3 -0
- zrb/callback/callback.py +8 -1
- zrb/config/config.py +1 -1
- zrb/context/context.py +11 -0
- zrb/task/base/context.py +25 -13
- zrb/task/base/execution.py +52 -47
- zrb/task/base/lifecycle.py +1 -1
- zrb/task/base_task.py +31 -45
- zrb/task/base_trigger.py +0 -1
- zrb/task/llm/agent.py +39 -31
- zrb/task/llm/agent_runner.py +59 -1
- zrb/task/llm/default_workflow/researching/workflow.md +2 -0
- zrb/task/llm/print_node.py +15 -2
- zrb/task/llm/prompt.py +70 -40
- zrb/task/llm/workflow.py +13 -1
- zrb/task/llm_task.py +83 -28
- zrb/util/run.py +3 -3
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/METADATA +1 -1
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/RECORD +31 -30
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/WHEEL +0 -0
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/entry_points.txt +0 -0
zrb/builtin/llm/history.py
CHANGED
@@ -17,13 +17,10 @@ def read_chat_conversation(ctx: AnyContext) -> dict[str, Any] | list | None:
         return None  # Indicate no history to load
     previous_session_name = ctx.input.previous_session
     if not previous_session_name:  # Check for empty string or None
-
-        if
-
-
-            return None
-        else:
-            return None  # No previous session specified and no last session found
+        last_session_name = get_last_session_name()
+        if last_session_name is None:
+            return None
+        previous_session_name = last_session_name
     conversation_file_path = os.path.join(
         CFG.LLM_HISTORY_DIR, f"{previous_session_name}.json"
     )
@@ -51,6 +48,16 @@ def read_chat_conversation(ctx: AnyContext) -> dict[str, Any] | list | None:
         return None


+def get_last_session_name() -> str | None:
+    last_session_file_path = os.path.join(CFG.LLM_HISTORY_DIR, "last-session")
+    if not os.path.isfile(last_session_file_path):
+        return None
+    last_session_name = read_file(last_session_file_path).strip()
+    if not last_session_name:  # Handle empty last-session file
+        return None
+    return last_session_name
+
+
 def write_chat_conversation(ctx: AnyContext, history_data: ConversationHistory):
     """Writes the conversation history data (including context) to a session file."""
     os.makedirs(CFG.LLM_HISTORY_DIR, exist_ok=True)
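In practice, this change means that when the `previous_session` input is empty, `read_chat_conversation` now falls back to the session name stored in a `last-session` marker file inside `CFG.LLM_HISTORY_DIR`. A minimal standalone sketch of that resolution logic (the helper name `resolve_session_name` is illustrative, not part of zrb):

```
import os


def resolve_session_name(requested: str | None, history_dir: str) -> str | None:
    # Prefer an explicitly requested session name.
    if requested:
        return requested
    # Otherwise fall back to the "last-session" marker, mirroring get_last_session_name().
    marker = os.path.join(history_dir, "last-session")
    if not os.path.isfile(marker):
        return None
    with open(marker, encoding="utf-8") as f:
        name = f.read().strip()
    return name or None
```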
zrb/builtin/llm/llm_ask.py
CHANGED
@@ -31,6 +31,11 @@ from zrb.builtin.llm.tool.web import (
     create_search_internet_tool,
     open_web_page,
 )
+from zrb.builtin.llm.xcom_names import (
+    LLM_ASK_ERROR_XCOM_NAME,
+    LLM_ASK_RESULT_XCOM_NAME,
+    LLM_ASK_SESSION_XCOM_NAME,
+)
 from zrb.callback.callback import Callback
 from zrb.config.config import CFG
 from zrb.config.llm_config import llm_config
@@ -40,6 +45,7 @@ from zrb.input.bool_input import BoolInput
 from zrb.input.str_input import StrInput
 from zrb.input.text_input import TextInput
 from zrb.task.base_trigger import BaseTrigger
+from zrb.task.llm.workflow import LLM_LOADED_WORKFLOW_XCOM_NAME
 from zrb.task.llm_task import LLMTask
 from zrb.util.string.conversion import to_boolean

@@ -99,6 +105,8 @@ def _get_default_yolo_mode(ctx: AnyContext) -> str:


 def _render_yolo_mode_input(ctx: AnyContext) -> list[str] | bool:
+    if isinstance(ctx.input.yolo, bool):
+        return ctx.input.yolo
     if ctx.input.yolo.strip() == "":
         return []
     elements = [element.strip() for element in ctx.input.yolo.split(",")]
@@ -172,9 +180,9 @@ def _get_inputs(require_message: bool = True) -> list[AnyInput | None]:
             always_prompt=False,
         ),
         TextInput(
-            "
-            description="Workflows",
-            prompt="Workflows",
+            "workflow",
+            description="Workflows (comma separated)",
+            prompt="Workflows (comma separated)",
             default=lambda ctx: ",".join(llm_config.default_workflows),
             allow_positional_parsing=False,
             always_prompt=False,
@@ -237,7 +245,7 @@ llm_ask = LLMTask(
         None if ctx.input.system_prompt.strip() == "" else ctx.input.system_prompt
     ),
     workflows=lambda ctx: (
-        None if ctx.input.
+        None if ctx.input.workflow.strip() == "" else ctx.input.workflow.split(",")
     ),
     attachment=_render_attach_input,
     message="{ctx.input.message}",
@@ -258,9 +266,10 @@ llm_group.add_task(
    callback=Callback(
        task=llm_ask,
        input_mapping=get_llm_ask_input_mapping,
-
-
-
+        xcom_mapping={LLM_LOADED_WORKFLOW_XCOM_NAME: LLM_LOADED_WORKFLOW_XCOM_NAME},
+        result_queue=LLM_ASK_RESULT_XCOM_NAME,
+        error_queue=LLM_ASK_ERROR_XCOM_NAME,
+        session_name_queue=LLM_ASK_SESSION_XCOM_NAME,
    ),
    retries=0,
    cli_only=True,
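The renamed `workflow` input is a comma-separated string, and the `workflows=` lambda above converts it to a list (or `None`, which lets `llm_config.default_workflows` apply). A small standalone sketch of that conversion (the function name `parse_workflows` is illustrative; the real code is the inline lambda):

```
def parse_workflows(raw: str) -> list[str] | None:
    # Empty input means "no override": the task then uses llm_config.default_workflows.
    return None if raw.strip() == "" else raw.split(",")


assert parse_workflows("") is None
assert parse_workflows("coding,researching") == ["coding", "researching"]
```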
zrb/builtin/llm/tool/file.py
CHANGED
@@ -274,8 +274,9 @@ def write_to_file(
     - CORRECT: "content": "He said \"Hello\""
     - WRONG: "content": "He said \\"Hello\\"" <-- This breaks JSON parsing!
     2. **SIZE LIMIT:** Content MUST NOT exceed 4000 characters.
-       -
-       -
+       - **STRICT PROHIBITION:** You are FORBIDDEN from writing more than 4000 characters in a single call.
+       - This is due to LLM output token limits, which will cause truncation and failure.
+       - To write larger files, you MUST split the content into multiple sequential calls (e.g., first 'w', then 'a').

     Examples:
     ```
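The updated docstring tells the model to split anything larger than 4000 characters across sequential calls: one write ('w') followed by appends ('a'). A rough sketch of what such chunked writing looks like on the caller side (the helper itself is hypothetical; only the write/append modes and the 4000-character limit come from the docstring):

```
def write_in_chunks(path: str, content: str, chunk_size: int = 4000) -> None:
    # First chunk truncates/creates the file ("w"), remaining chunks append ("a").
    chunks = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]
    for index, chunk in enumerate(chunks):
        mode = "w" if index == 0 else "a"
        with open(path, mode, encoding="utf-8") as f:
            f.write(chunk)
```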
zrb/builtin/llm/tool/search/brave.py
CHANGED

@@ -1,7 +1,5 @@
 from typing import Any

-import requests
-
 from zrb.config.config import CFG


@@ -36,6 +34,8 @@ def search_internet(
     Returns:
         dict: Summary of search results (titles, links, snippets).
     """
+    import requests
+
     if safe_search is None:
         safe_search = CFG.BRAVE_API_SAFE
     if language is None:
zrb/builtin/llm/tool/search/searxng.py
CHANGED

@@ -1,7 +1,5 @@
 from typing import Any

-import requests
-
 from zrb.config.config import CFG


@@ -36,6 +34,8 @@ def search_internet(
     Returns:
         dict: Summary of search results (titles, links, snippets).
     """
+    import requests
+
     if safe_search is None:
         safe_search = CFG.SEARXNG_SAFE
     if language is None:
zrb/builtin/llm/tool/search/serpapi.py
CHANGED

@@ -1,7 +1,5 @@
 from typing import Any

-import requests
-
 from zrb.config.config import CFG


@@ -36,6 +34,8 @@ def search_internet(
     Returns:
         dict: Summary of search results (titles, links, snippets).
     """
+    import requests
+
     if safe_search is None:
         safe_search = CFG.SERPAPI_SAFE
     if language is None:
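All three search tools make the same change: `import requests` moves from module level into the body of `search_internet`, so importing zrb stays fast and `requests` is only needed when a search actually runs. A generic sketch of this lazy-import pattern (the function and parameters below are placeholders, not zrb code):

```
def fetch_json(url: str, params: dict) -> dict:
    # Deferring the import keeps module import cheap and makes "requests"
    # effectively optional until the tool is actually called.
    import requests

    response = requests.get(url, params=params, timeout=30)
    response.raise_for_status()
    return response.json()
```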
zrb/callback/callback.py
CHANGED
@@ -6,7 +6,6 @@ from zrb.callback.any_callback import AnyCallback
 from zrb.session.any_session import AnySession
 from zrb.task.any_task import AnyTask
 from zrb.util.attr import get_str_dict_attr
-from zrb.util.cli.style import stylize_faint
 from zrb.util.string.conversion import to_snake_case
 from zrb.xcom.xcom import Xcom

@@ -24,6 +23,7 @@ class Callback(AnyCallback):
         task: AnyTask,
         input_mapping: StrDictAttr,
         render_input_mapping: bool = True,
+        xcom_mapping: dict[str, str] | None = None,
         result_queue: str | None = None,
         error_queue: str | None = None,
         session_name_queue: str | None = None,
@@ -36,6 +36,7 @@ class Callback(AnyCallback):
             input_mapping: A dictionary or attribute mapping to prepare inputs for the task.
             render_input_mapping: Whether to render the input mapping using
                 f-string like syntax.
+            xcom_mapping: Map of parent session's xcom names to current session's xcom names
             result_queue: The name of the XCom queue in the parent session
                 to publish the task result.
             result_queue: The name of the Xcom queue in the parent session
@@ -46,6 +47,7 @@ class Callback(AnyCallback):
         self._task = task
         self._input_mapping = input_mapping
         self._render_input_mapping = render_input_mapping
+        self._xcom_mapping = xcom_mapping
         self._result_queue = result_queue
         self._error_queue = error_queue
         self._session_name_queue = session_name_queue
@@ -63,6 +65,11 @@ class Callback(AnyCallback):
         for name, value in inputs.items():
             session.shared_ctx.input[name] = value
             session.shared_ctx.input[to_snake_case(name)] = value
+        # map xcom
+        if self._xcom_mapping is not None:
+            for parent_xcom_name, current_xcom_name in self._xcom_mapping.items():
+                parent_xcom = parent_session.shared_ctx.xcom[parent_xcom_name]
+                session.shared_ctx.xcom[current_xcom_name] = parent_xcom
         # run task and get result
         try:
             result = await self._task.async_run(session)
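The new `xcom_mapping` parameter copies named XCom queues from the parent session into the callback's session before the task runs, which is how `llm_ask` now receives the loaded-workflow queue. A hedged construction sketch (queue names are illustrative, and the assumption that `BaseTask` can be built with only a name is not verified here):

```
from zrb.callback.callback import Callback
from zrb.task.base_task import BaseTask

callback = Callback(
    task=BaseTask(name="notify"),
    input_mapping={"message": "{ctx.input.message}"},
    xcom_mapping={"parent_queue": "child_queue"},  # parent xcom name -> child xcom name
    result_queue="notify_result",
    error_queue="notify_error",
    session_name_queue="notify_session",
)
```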
zrb/config/config.py
CHANGED
zrb/context/context.py
CHANGED
@@ -139,6 +139,17 @@ class Context(AnyContext):
         stylized_prefix = stylize(prefix, color=color)
         print(f"{stylized_prefix} {message}", sep=sep, end=end, file=file, flush=flush)

+    def print_err(
+        self,
+        *values: object,
+        sep: str | None = " ",
+        end: str | None = "\n",
+        file: TextIO | None = sys.stderr,
+        flush: bool = True,
+        plain: bool = False,
+    ):
+        self.print(*values, sep=sep, end=end, file=file, flush=flush, plain=plain)
+
     def log_debug(
         self,
         *values: object,
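`print_err` is a thin convenience wrapper over `Context.print` that targets `sys.stderr` by default. A standalone sketch of the equivalent behavior outside of zrb:

```
import sys


def print_err(*values: object, sep: str = " ", end: str = "\n", flush: bool = True) -> None:
    # Same as print(), but defaulting to stderr, mirroring Context.print_err.
    print(*values, sep=sep, end=end, file=sys.stderr, flush=flush)


print_err("readiness check failed")  # goes to stderr, not stdout
```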
zrb/task/base/context.py
CHANGED
@@ -79,24 +79,36 @@ def combine_inputs(
             input_names.append(task_input.name)  # Update names list


+def combine_envs(
+    existing_envs: list[AnyEnv],
+    new_envs: list[AnyEnv | None] | AnyEnv | None,
+):
+    """
+    Combines new envs into an existing list.
+    Modifies the existing_envs list in place.
+    """
+    if isinstance(new_envs, AnyEnv):
+        existing_envs.append(new_envs)
+    elif new_envs is None:
+        pass
+    else:
+        # new_envs is a list
+        for env in new_envs:
+            if env is not None:
+                existing_envs.append(env)
+
+
 def get_combined_envs(task: "BaseTask") -> list[AnyEnv]:
     """
     Aggregates environment variables from the task and its upstreams.
     """
-    envs = []
+    envs: list[AnyEnv] = []
     for upstream in task.upstreams:
-        envs
-
-
-
-
-            envs.append(task_envs)
-        elif isinstance(task_envs, list):
-            # Filter out None while extending
-            envs.extend(env for env in task_envs if env is not None)
-
-    # Filter out None values efficiently from the combined list
-    return [env for env in envs if env is not None]
+        combine_envs(envs, upstream.envs)
+
+    combine_envs(envs, task._envs)
+
+    return envs


 def get_combined_inputs(task: "BaseTask") -> list[AnyInput]:
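`combine_envs` accepts a single env, a list that may contain `None` entries, or `None`, and appends the usable entries to the existing list in place. A toy illustration of that contract using a stand-in class (not zrb's `AnyEnv`):

```
class FakeEnv:
    def __init__(self, name: str) -> None:
        self.name = name


def combine(existing: list[FakeEnv], new) -> None:
    # Mirrors combine_envs: single item, None, or list with None entries filtered out.
    if isinstance(new, FakeEnv):
        existing.append(new)
    elif new is not None:
        existing.extend(env for env in new if env is not None)


envs: list[FakeEnv] = []
combine(envs, FakeEnv("A"))          # single env
combine(envs, [FakeEnv("B"), None])  # list with a None entry
combine(envs, None)                  # no-op
assert [e.name for e in envs] == ["A", "B"]
```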
zrb/task/base/execution.py
CHANGED
@@ -88,56 +88,61 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
         run_async(execute_action_with_retry(task, session))
     )

-    await asyncio.sleep(readiness_check_delay)
-
-    readiness_check_coros = [
-        run_async(check.exec_chain(session)) for check in readiness_checks
-    ]
-
-    # Wait primarily for readiness checks to complete
-    ctx.log_info("Waiting for readiness checks")
-    readiness_passed = False
     try:
-
-        await asyncio.gather(*readiness_check_coros)
-        # Check if all readiness tasks actually completed successfully
-        all_readiness_completed = all(
-            session.get_task_status(check).is_completed for check in readiness_checks
-        )
-        if all_readiness_completed:
-            ctx.log_info("Readiness checks completed successfully")
-            readiness_passed = True
-            # Mark task as ready only if checks passed and action didn't fail during checks
-            if not session.get_task_status(task).is_failed:
-                ctx.log_info("Marked as ready")
-                session.get_task_status(task).mark_as_ready()
-        else:
-            ctx.log_warning(
-                "One or more readiness checks did not complete successfully."
-            )
+        await asyncio.sleep(readiness_check_delay)

-
-
-
-        # The action_coro might still be running or have failed.
-        # execute_action_with_retry handles marking the main task status.
-
-    # Defer the main action coroutine; it will be awaited later if needed
-    session.defer_action(task, action_coro)
-
-    # Start monitoring only if readiness passed and monitoring is enabled
-    if readiness_passed and monitor_readiness:
-        # Import dynamically to avoid circular dependency if monitoring imports execution
-        from zrb.task.base.monitoring import monitor_task_readiness
+        readiness_check_coros = [
+            run_async(check.exec_chain(session)) for check in readiness_checks
+        ]

-
-
-
-
+        # Wait primarily for readiness checks to complete
+        ctx.log_info("Waiting for readiness checks")
+        readiness_passed = False
+        try:
+            # Gather results, but primarily interested in completion/errors
+            await asyncio.gather(*readiness_check_coros)
+            # Check if all readiness tasks actually completed successfully
+            all_readiness_completed = all(
+                session.get_task_status(check).is_completed
+                for check in readiness_checks
+            )
+            if all_readiness_completed:
+                ctx.log_info("Readiness checks completed successfully")
+                readiness_passed = True
+                # Mark task as ready only if checks passed and action didn't fail during checks
+                if not session.get_task_status(task).is_failed:
+                    ctx.log_info("Marked as ready")
+                    session.get_task_status(task).mark_as_ready()
+            else:
+                ctx.log_warning(
+                    "One or more readiness checks did not complete successfully."
+                )
+
+        except Exception as e:
+            ctx.log_error(f"Readiness check failed with exception: {e}")
+            # If readiness checks fail with an exception, the task is not ready.
+            # The action_coro might still be running or have failed.
+            # execute_action_with_retry handles marking the main task status.
+
+        # Defer the main action coroutine; it will be awaited later if needed
+        session.defer_action(task, action_coro)
+
+        # Start monitoring only if readiness passed and monitoring is enabled
+        if readiness_passed and monitor_readiness:
+            # Import dynamically to avoid circular dependency if monitoring imports execution
+            from zrb.task.base.monitoring import monitor_task_readiness
+
+            monitor_coro = asyncio.create_task(
+                run_async(monitor_task_readiness(task, session, action_coro))
+            )
+            session.defer_monitoring(task, monitor_coro)

-
-
-
+        # The result here is primarily about readiness check completion.
+        # The actual task result is handled by the deferred action_coro.
+        return None
+    except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
+        action_coro.cancel()
+        raise


 async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> Any:
@@ -178,7 +183,7 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> Any:
         await run_async(execute_successors(task, session))
         return result

-    except (asyncio.CancelledError, KeyboardInterrupt):
+    except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
         ctx.log_warning("Task cancelled or interrupted")
         session.get_task_status(task).mark_as_failed()  # Mark as failed on cancel
         # Do not trigger fallbacks/successors on cancellation
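The refactor wraps the whole readiness phase in an outer try so that `CancelledError`, `KeyboardInterrupt`, or `GeneratorExit` cancels the already-started action coroutine before re-raising. A minimal asyncio sketch of that cancellation-propagation pattern (names here are placeholders, not zrb's):

```
import asyncio


async def wait_with_cleanup(background: asyncio.Task) -> None:
    try:
        await asyncio.sleep(30)  # stand-in for the readiness-check phase
    except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
        background.cancel()  # propagate cancellation to the deferred work
        raise


async def main() -> None:
    background = asyncio.create_task(asyncio.sleep(60))
    waiter = asyncio.create_task(wait_with_cleanup(background))
    await asyncio.sleep(0)  # let both tasks reach their first await
    waiter.cancel()
    for task in (waiter, background):
        try:
            await task
        except asyncio.CancelledError:
            pass
    print("background cancelled:", background.cancelled())


asyncio.run(main())
```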
zrb/task/base/lifecycle.py
CHANGED
@@ -176,7 +176,7 @@ async def log_session_state(task: AnyTask, session: AnySession):
     try:
         while not session.is_terminated:
             session.state_logger.write(session.as_state_log())
-            await asyncio.sleep(0
+            await asyncio.sleep(0)  # Log interval
         # Log one final time after termination signal
         session.state_logger.write(session.as_state_log())
     except asyncio.CancelledError:
zrb/task/base_task.py
CHANGED
@@ -3,7 +3,7 @@ import inspect
 from collections.abc import Callable
 from typing import Any

-from zrb.attr.type import
+from zrb.attr.type import fstring
 from zrb.context.any_context import AnyContext
 from zrb.env.any_env import AnyEnv
 from zrb.input.any_input import AnyInput
@@ -55,7 +55,7 @@ class BaseTask(AnyTask):
         input: list[AnyInput | None] | AnyInput | None = None,
         env: list[AnyEnv | None] | AnyEnv | None = None,
         action: fstring | Callable[[AnyContext], Any] | None = None,
-        execute_condition:
+        execute_condition: bool | str | Callable[[AnyContext], bool] = True,
         retries: int = 2,
         retry_period: float = 0,
         readiness_check: list[AnyTask] | AnyTask | None = None,
@@ -68,9 +68,18 @@ class BaseTask(AnyTask):
         fallback: list[AnyTask] | AnyTask | None = None,
         successor: list[AnyTask] | AnyTask | None = None,
     ):
-
-
-
+        # Optimized stack retrieval
+        frame = inspect.currentframe()
+        if frame is not None:
+            caller_frame = frame.f_back
+            self.__decl_file = (
+                caller_frame.f_code.co_filename if caller_frame else "unknown"
+            )
+            self.__decl_line = caller_frame.f_lineno if caller_frame else 0
+        else:
+            self.__decl_file = "unknown"
+            self.__decl_line = 0
+
         self._name = name
         self._color = color
         self._icon = icon
@@ -80,10 +89,10 @@ class BaseTask(AnyTask):
         self._envs = env
         self._retries = retries
         self._retry_period = retry_period
-        self._upstreams = upstream
-        self._fallbacks = fallback
-        self._successors = successor
-        self._readiness_checks = readiness_check
+        self._upstreams = self._ensure_task_list(upstream)
+        self._fallbacks = self._ensure_task_list(fallback)
+        self._successors = self._ensure_task_list(successor)
+        self._readiness_checks = self._ensure_task_list(readiness_check)
         self._readiness_check_delay = readiness_check_delay
         self._readiness_check_period = readiness_check_period
         self._readiness_failure_threshold = readiness_failure_threshold
@@ -92,6 +101,13 @@ class BaseTask(AnyTask):
         self._execute_condition = execute_condition
         self._action = action

+    def _ensure_task_list(self, tasks: AnyTask | list[AnyTask] | None) -> list[AnyTask]:
+        if tasks is None:
+            return []
+        if isinstance(tasks, list):
+            return tasks
+        return [tasks]
+
     def __repr__(self):
         return f"<{self.__class__.__name__} name={self.name}>"

@@ -132,18 +148,10 @@ class BaseTask(AnyTask):
     @property
     def fallbacks(self) -> list[AnyTask]:
         """Returns the list of fallback tasks."""
-
-            return []
-        elif isinstance(self._fallbacks, list):
-            return self._fallbacks
-        return [self._fallbacks]  # Assume single task
+        return self._fallbacks

     def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
         """Appends fallback tasks, ensuring no duplicates."""
-        if self._fallbacks is None:
-            self._fallbacks = []
-        elif not isinstance(self._fallbacks, list):
-            self._fallbacks = [self._fallbacks]
         to_add = fallbacks if isinstance(fallbacks, list) else [fallbacks]
         for fb in to_add:
             if fb not in self._fallbacks:
@@ -152,18 +160,10 @@ class BaseTask(AnyTask):
     @property
     def successors(self) -> list[AnyTask]:
         """Returns the list of successor tasks."""
-
-            return []
-        elif isinstance(self._successors, list):
-            return self._successors
-        return [self._successors]  # Assume single task
+        return self._successors

     def append_successor(self, successors: AnyTask | list[AnyTask]):
         """Appends successor tasks, ensuring no duplicates."""
-        if self._successors is None:
-            self._successors = []
-        elif not isinstance(self._successors, list):
-            self._successors = [self._successors]
         to_add = successors if isinstance(successors, list) else [successors]
         for succ in to_add:
             if succ not in self._successors:
@@ -172,18 +172,10 @@ class BaseTask(AnyTask):
     @property
     def readiness_checks(self) -> list[AnyTask]:
         """Returns the list of readiness check tasks."""
-
-            return []
-        elif isinstance(self._readiness_checks, list):
-            return self._readiness_checks
-        return [self._readiness_checks]  # Assume single task
+        return self._readiness_checks

     def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
         """Appends readiness check tasks, ensuring no duplicates."""
-        if self._readiness_checks is None:
-            self._readiness_checks = []
-        elif not isinstance(self._readiness_checks, list):
-            self._readiness_checks = [self._readiness_checks]
         to_add = (
             readiness_checks
             if isinstance(readiness_checks, list)
@@ -196,18 +188,10 @@ class BaseTask(AnyTask):
     @property
     def upstreams(self) -> list[AnyTask]:
         """Returns the list of upstream tasks."""
-
-            return []
-        elif isinstance(self._upstreams, list):
-            return self._upstreams
-        return [self._upstreams]  # Assume single task
+        return self._upstreams

     def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
         """Appends upstream tasks, ensuring no duplicates."""
-        if self._upstreams is None:
-            self._upstreams = []
-        elif not isinstance(self._upstreams, list):
-            self._upstreams = [self._upstreams]
         to_add = upstreams if isinstance(upstreams, list) else [upstreams]
         for up in to_add:
             if up not in self._upstreams:
@@ -277,6 +261,8 @@ class BaseTask(AnyTask):
         try:
             # Delegate to the helper function for the default behavior
             return await run_default_action(self, ctx)
+        except (KeyboardInterrupt, GeneratorExit):
+            raise
         except BaseException as e:
             additional_error_note = (
                 f"Task: {self.name} ({self.__decl_file}:{self.__decl_line})"
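Besides normalizing task arguments up front with `_ensure_task_list`, the constructor now records where the task was declared using `inspect.currentframe()` instead of building a full stack, which is cheaper. A standalone sketch of that caller-location capture:

```
import inspect


def capture_declaration_site() -> tuple[str, int]:
    # frame.f_back is the caller; this avoids the cost of inspect.stack().
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    if caller is None:
        return "unknown", 0
    return caller.f_code.co_filename, caller.f_lineno


decl_file, decl_line = capture_declaration_site()
print(f"declared at {decl_file}:{decl_line}")
```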
zrb/task/base_trigger.py
CHANGED
@@ -5,7 +5,6 @@ from typing import Any
 from zrb.attr.type import fstring
 from zrb.callback.any_callback import AnyCallback
 from zrb.context.any_context import AnyContext
-from zrb.context.any_shared_context import AnySharedContext
 from zrb.context.shared_context import SharedContext
 from zrb.dot_dict.dot_dict import DotDict
 from zrb.env.any_env import AnyEnv
zrb/task/llm/agent.py
CHANGED
@@ -39,39 +39,10 @@ def create_agent_instance(
     auto_summarize: bool = True,
 ) -> "Agent[None, Any]":
     """Creates a new Agent instance with configured tools and servers."""
-    from pydantic_ai import Agent,
+    from pydantic_ai import Agent, Tool
     from pydantic_ai.tools import GenerateToolJsonSchema
-    from pydantic_ai.toolsets import ToolsetTool, WrapperToolset

-
-    class ConfirmationWrapperToolset(WrapperToolset):
-        ctx: AnyContext
-        yolo_mode: bool | list[str]
-
-        async def call_tool(
-            self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
-        ) -> Any:
-            # The `tool` object is passed in. Use it for inspection.
-            # Define a temporary function that performs the actual tool call.
-            async def execute_delegated_tool_call(**params):
-                # Pass all arguments down the chain.
-                return await self.wrapped.call_tool(name, tool_args, ctx, tool)
-
-            # For the confirmation UI, make our temporary function look like the real one.
-            try:
-                execute_delegated_tool_call.__name__ = name
-                execute_delegated_tool_call.__doc__ = tool.function.__doc__
-                execute_delegated_tool_call.__signature__ = inspect.signature(
-                    tool.function
-                )
-            except (AttributeError, TypeError):
-                pass  # Ignore if we can't inspect the original function
-            # Use the existing wrap_func to get the confirmation logic
-            wrapped_executor = wrap_func(
-                execute_delegated_tool_call, self.ctx, self.yolo_mode
-            )
-            # Call the wrapped executor. This will trigger the confirmation prompt.
-            return await wrapped_executor(**tool_args)
+    ConfirmationWrapperToolset = _get_confirmation_wrapper_toolset_class()

     if yolo_mode is None:
         yolo_mode = False
@@ -132,6 +103,43 @@ def create_agent_instance(
     )


+def _get_confirmation_wrapper_toolset_class():
+    from pydantic_ai import RunContext
+    from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
+
+    @dataclass
+    class ConfirmationWrapperToolset(WrapperToolset):
+        ctx: AnyContext
+        yolo_mode: bool | list[str]
+
+        async def call_tool(
+            self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
+        ) -> Any:
+            # The `tool` object is passed in. Use it for inspection.
+            # Define a temporary function that performs the actual tool call.
+            async def execute_delegated_tool_call(**params):
+                # Pass all arguments down the chain.
+                return await self.wrapped.call_tool(name, tool_args, ctx, tool)
+
+            # For the confirmation UI, make our temporary function look like the real one.
+            try:
+                execute_delegated_tool_call.__name__ = name
+                execute_delegated_tool_call.__doc__ = tool.function.__doc__
+                execute_delegated_tool_call.__signature__ = inspect.signature(
+                    tool.function
+                )
+            except (AttributeError, TypeError):
+                pass  # Ignore if we can't inspect the original function
+            # Use the existing wrap_func to get the confirmation logic
+            wrapped_executor = wrap_func(
+                execute_delegated_tool_call, self.ctx, self.yolo_mode
+            )
+            # Call the wrapped executor. This will trigger the confirmation prompt.
+            return await wrapped_executor(**tool_args)
+
+    return ConfirmationWrapperToolset
+
+
 def get_agent(
     ctx: AnyContext,
     model: "str | Model",
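The wrapper toolset itself is unchanged; it just moves from an inline class inside `create_agent_instance` into the `_get_confirmation_wrapper_toolset_class()` factory, which also keeps the pydantic_ai toolset imports local to the function that needs them. A generic sketch of this "class behind a factory" pattern (the import below stands in for a heavier third-party one):

```
def _get_wrapper_class():
    from dataclasses import dataclass  # stand-in for a heavier, lazily loaded import

    @dataclass
    class Wrapper:
        label: str

        def describe(self) -> str:
            return f"wrapper:{self.label}"

    return Wrapper


Wrapper = _get_wrapper_class()
print(Wrapper(label="demo").describe())
```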