zrb-1.15.3-py3-none-any.whl → zrb-1.21.29-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic. Click here for more details.
- zrb/__init__.py +2 -6
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +2 -0
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_completion.py +274 -0
- zrb/builtin/llm/chat_session.py +126 -167
- zrb/builtin/llm/chat_session_cmd.py +288 -0
- zrb/builtin/llm/chat_trigger.py +79 -0
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/llm_ask.py +217 -135
- zrb/builtin/llm/tool/api.py +74 -70
- zrb/builtin/llm/tool/cli.py +35 -21
- zrb/builtin/llm/tool/code.py +55 -73
- zrb/builtin/llm/tool/file.py +278 -344
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +27 -34
- zrb/builtin/llm/tool/sub_agent.py +54 -41
- zrb/builtin/llm/tool/web.py +74 -98
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/config/config.py +202 -27
- zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
- zrb/config/default_prompt/interactive_system_prompt.md +24 -30
- zrb/config/default_prompt/persona.md +1 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
- zrb/config/default_prompt/summarization_prompt.md +57 -16
- zrb/config/default_prompt/system_prompt.md +36 -30
- zrb/config/llm_config.py +119 -23
- zrb/config/llm_context/config.py +127 -90
- zrb/config/llm_context/config_parser.py +1 -7
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +100 -47
- zrb/context/any_shared_context.py +7 -1
- zrb/context/context.py +8 -2
- zrb/context/shared_context.py +3 -7
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +3 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +13 -1
- zrb/input/text_input.py +7 -24
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -6
- zrb/session/session.py +39 -18
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/llm/agent.py +128 -167
- zrb/task/llm/agent_runner.py +152 -0
- zrb/task/llm/config.py +39 -20
- zrb/task/llm/conversation_history.py +110 -29
- zrb/task/llm/conversation_history_model.py +4 -179
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_processor.py +206 -0
- zrb/task/llm/history_summarization.py +2 -193
- zrb/task/llm/print_node.py +184 -64
- zrb/task/llm/prompt.py +175 -179
- zrb/task/llm/subagent_conversation_history.py +41 -0
- zrb/task/llm/tool_wrapper.py +226 -85
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +109 -71
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +25 -10
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +54 -39
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +12 -3
- zrb/util/git.py +2 -2
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/METADATA +38 -18
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/RECORD +105 -79
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
zrb/task/base/execution.py
CHANGED
|
@@ -53,7 +53,9 @@ def check_execute_condition(task: "BaseTask", session: AnySession) -> bool:
|
|
|
53
53
|
Evaluates the task's execute_condition attribute.
|
|
54
54
|
"""
|
|
55
55
|
ctx = task.get_ctx(session)
|
|
56
|
-
execute_condition_attr =
|
|
56
|
+
execute_condition_attr = (
|
|
57
|
+
task._execute_condition if task._execute_condition is not None else True
|
|
58
|
+
)
|
|
57
59
|
return get_bool_attr(ctx, execute_condition_attr, True, auto_render=True)
|
|
58
60
|
|
|
59
61
|
|
|
@@ -63,8 +65,12 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
|
|
|
63
65
|
"""
|
|
64
66
|
ctx = task.get_ctx(session)
|
|
65
67
|
readiness_checks = task.readiness_checks
|
|
66
|
-
readiness_check_delay =
|
|
67
|
-
|
|
68
|
+
readiness_check_delay = (
|
|
69
|
+
task._readiness_check_delay if task._readiness_check_delay is not None else 0.5
|
|
70
|
+
)
|
|
71
|
+
monitor_readiness = (
|
|
72
|
+
task._monitor_readiness if task._monitor_readiness is not None else False
|
|
73
|
+
)
|
|
68
74
|
|
|
69
75
|
if not readiness_checks: # Simplified check for empty list
|
|
70
76
|
ctx.log_info("No readiness checks")
|
|
@@ -140,8 +146,8 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
|
|
|
140
146
|
handling success (triggering successors) and failure (triggering fallbacks).
|
|
141
147
|
"""
|
|
142
148
|
ctx = task.get_ctx(session)
|
|
143
|
-
retries =
|
|
144
|
-
retry_period =
|
|
149
|
+
retries = task._retries if task._retries is not None else 2
|
|
150
|
+
retry_period = task._retry_period if task._retry_period is not None else 0
|
|
145
151
|
max_attempt = retries + 1
|
|
146
152
|
ctx.set_max_attempt(max_attempt)
|
|
147
153
|
|
|
@@ -163,8 +169,9 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
|
|
|
163
169
|
session.get_task_status(task).mark_as_completed()
|
|
164
170
|
|
|
165
171
|
# Store result in XCom
|
|
166
|
-
task_xcom: Xcom = ctx.xcom.get(task.name)
|
|
167
|
-
task_xcom
|
|
172
|
+
task_xcom: Xcom | None = ctx.xcom.get(task.name)
|
|
173
|
+
if task_xcom is not None:
|
|
174
|
+
task_xcom.push(result)
|
|
168
175
|
|
|
169
176
|
# Skip fallbacks and execute successors on success
|
|
170
177
|
skip_fallbacks(task, session)
|
|
@@ -201,7 +208,7 @@ async def run_default_action(task: "BaseTask", ctx: AnyContext) -> Any:
|
|
|
201
208
|
This is the default implementation called by BaseTask._exec_action.
|
|
202
209
|
Subclasses like LLMTask override _exec_action with their own logic.
|
|
203
210
|
"""
|
|
204
|
-
action =
|
|
211
|
+
action = task._action
|
|
205
212
|
if action is None:
|
|
206
213
|
ctx.log_debug("No action defined for this task.")
|
|
207
214
|
return None
|
zrb/task/base/lifecycle.py
CHANGED
|
@@ -12,7 +12,8 @@ from zrb.util.run import run_async
|
|
|
12
12
|
async def run_and_cleanup(
|
|
13
13
|
task: AnyTask,
|
|
14
14
|
session: AnySession | None = None,
|
|
15
|
-
str_kwargs: dict[str, str] =
|
|
15
|
+
str_kwargs: dict[str, str] | None = None,
|
|
16
|
+
kwargs: dict[str, Any] | None = None,
|
|
16
17
|
) -> Any:
|
|
17
18
|
"""
|
|
18
19
|
Wrapper for async_run that ensures session termination and cleanup of
|
|
@@ -23,7 +24,9 @@ async def run_and_cleanup(
|
|
|
23
24
|
session = Session(shared_ctx=SharedContext())
|
|
24
25
|
|
|
25
26
|
# Create the main task execution coroutine
|
|
26
|
-
main_task_coro = asyncio.create_task(
|
|
27
|
+
main_task_coro = asyncio.create_task(
|
|
28
|
+
run_task_async(task, session, str_kwargs, kwargs)
|
|
29
|
+
)
|
|
27
30
|
|
|
28
31
|
try:
|
|
29
32
|
result = await main_task_coro
|
|
@@ -67,7 +70,8 @@ async def run_and_cleanup(
|
|
|
67
70
|
async def run_task_async(
|
|
68
71
|
task: AnyTask,
|
|
69
72
|
session: AnySession | None = None,
|
|
70
|
-
str_kwargs: dict[str, str] =
|
|
73
|
+
str_kwargs: dict[str, str] | None = None,
|
|
74
|
+
kwargs: dict[str, Any] | None = None,
|
|
71
75
|
) -> Any:
|
|
72
76
|
"""
|
|
73
77
|
Asynchronous entry point for running a task (`task.async_run()`).
|
|
@@ -77,7 +81,7 @@ async def run_task_async(
|
|
|
77
81
|
session = Session(shared_ctx=SharedContext())
|
|
78
82
|
|
|
79
83
|
# Populate shared context with inputs and environment variables
|
|
80
|
-
fill_shared_context_inputs(
|
|
84
|
+
fill_shared_context_inputs(session.shared_ctx, task, str_kwargs, kwargs)
|
|
81
85
|
fill_shared_context_envs(session.shared_ctx) # Inject OS env vars
|
|
82
86
|
|
|
83
87
|
# Start the execution chain from the root tasks
|
zrb/task/base/monitoring.py
CHANGED
|
@@ -17,9 +17,13 @@ async def monitor_task_readiness(
|
|
|
17
17
|
"""
|
|
18
18
|
ctx = task.get_ctx(session)
|
|
19
19
|
readiness_checks = task.readiness_checks
|
|
20
|
-
readiness_check_period =
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
readiness_check_period = (
|
|
21
|
+
task._readiness_check_period if task._readiness_check_period else 5.0
|
|
22
|
+
)
|
|
23
|
+
readiness_failure_threshold = (
|
|
24
|
+
task._readiness_failure_threshold if task._readiness_failure_threshold else 1
|
|
25
|
+
)
|
|
26
|
+
readiness_timeout = task._readiness_timeout if task._readiness_timeout else 60
|
|
23
27
|
|
|
24
28
|
if not readiness_checks:
|
|
25
29
|
ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
|
|
@@ -41,8 +45,9 @@ async def monitor_task_readiness(
|
|
|
41
45
|
session.get_task_status(check).reset_history()
|
|
42
46
|
session.get_task_status(check).reset()
|
|
43
47
|
# Clear previous XCom data for the check task if needed
|
|
44
|
-
check_xcom: Xcom = ctx.xcom.get(check.name)
|
|
45
|
-
check_xcom
|
|
48
|
+
check_xcom: Xcom | None = ctx.xcom.get(check.name)
|
|
49
|
+
if check_xcom is not None:
|
|
50
|
+
check_xcom.clear()
|
|
46
51
|
|
|
47
52
|
readiness_check_coros = [
|
|
48
53
|
run_async(check.exec_chain(session)) for check in readiness_checks
|
|
@@ -77,7 +82,7 @@ async def monitor_task_readiness(
|
|
|
77
82
|
)
|
|
78
83
|
# Ensure check tasks are marked as failed on timeout
|
|
79
84
|
for check in readiness_checks:
|
|
80
|
-
if not session.get_task_status(check).
|
|
85
|
+
if not session.get_task_status(check).is_ready:
|
|
81
86
|
session.get_task_status(check).mark_as_failed()
|
|
82
87
|
|
|
83
88
|
except (asyncio.CancelledError, KeyboardInterrupt):
|
|
@@ -92,7 +97,7 @@ async def monitor_task_readiness(
|
|
|
92
97
|
)
|
|
93
98
|
# Mark checks as failed
|
|
94
99
|
for check in readiness_checks:
|
|
95
|
-
if not session.get_task_status(check).
|
|
100
|
+
if not session.get_task_status(check).is_ready:
|
|
96
101
|
session.get_task_status(check).mark_as_failed()
|
|
97
102
|
|
|
98
103
|
# If failure threshold is reached
|
zrb/task/base_task.py
CHANGED
|
@@ -21,6 +21,7 @@ from zrb.task.base.execution import (
|
|
|
21
21
|
)
|
|
22
22
|
from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
|
|
23
23
|
from zrb.task.base.operators import handle_lshift, handle_rshift
|
|
24
|
+
from zrb.util.string.conversion import to_snake_case
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class BaseTask(AnyTask):
|
|
@@ -216,7 +217,10 @@ class BaseTask(AnyTask):
|
|
|
216
217
|
return build_task_context(self, session)
|
|
217
218
|
|
|
218
219
|
def run(
|
|
219
|
-
self,
|
|
220
|
+
self,
|
|
221
|
+
session: AnySession | None = None,
|
|
222
|
+
str_kwargs: dict[str, str] | None = None,
|
|
223
|
+
kwargs: dict[str, Any] | None = None,
|
|
220
224
|
) -> Any:
|
|
221
225
|
"""
|
|
222
226
|
Synchronously runs the task and its dependencies, handling async setup and cleanup.
|
|
@@ -235,12 +239,19 @@ class BaseTask(AnyTask):
|
|
|
235
239
|
Any: The final result of the main task execution.
|
|
236
240
|
"""
|
|
237
241
|
# Use asyncio.run() to execute the async cleanup wrapper
|
|
238
|
-
return asyncio.run(
|
|
242
|
+
return asyncio.run(
|
|
243
|
+
run_and_cleanup(self, session=session, str_kwargs=str_kwargs, kwargs=kwargs)
|
|
244
|
+
)
|
|
239
245
|
|
|
240
246
|
async def async_run(
|
|
241
|
-
self,
|
|
247
|
+
self,
|
|
248
|
+
session: AnySession | None = None,
|
|
249
|
+
str_kwargs: dict[str, str] | None = None,
|
|
250
|
+
kwargs: dict[str, Any] | None = None,
|
|
242
251
|
) -> Any:
|
|
243
|
-
return await run_task_async(
|
|
252
|
+
return await run_task_async(
|
|
253
|
+
self, session=session, str_kwargs=str_kwargs, kwargs=kwargs
|
|
254
|
+
)
|
|
244
255
|
|
|
245
256
|
async def exec_root_tasks(self, session: AnySession):
|
|
246
257
|
return await execute_root_tasks(self, session)
|
|
@@ -276,7 +287,60 @@ class BaseTask(AnyTask):
|
|
|
276
287
|
# Add definition location to the error
|
|
277
288
|
if hasattr(e, "add_note"):
|
|
278
289
|
e.add_note(additional_error_note)
|
|
279
|
-
|
|
290
|
+
elif hasattr(e, "__notes__"):
|
|
280
291
|
# fallback: use the __notes__ attribute directly
|
|
281
292
|
e.__notes__ = getattr(e, "__notes__", []) + [additional_error_note]
|
|
282
293
|
raise e
|
|
294
|
+
|
|
295
|
+
def to_function(self) -> Callable[..., Any]:
|
|
296
|
+
from zrb.context.shared_context import SharedContext
|
|
297
|
+
from zrb.session.session import Session
|
|
298
|
+
|
|
299
|
+
def task_runner_fn(**kwargs) -> Any:
|
|
300
|
+
task_kwargs = self._get_func_kwargs(kwargs)
|
|
301
|
+
shared_ctx = SharedContext()
|
|
302
|
+
session = Session(shared_ctx=shared_ctx)
|
|
303
|
+
return self.run(session=session, kwargs=task_kwargs)
|
|
304
|
+
|
|
305
|
+
task_runner_fn.__doc__ = self._create_fn_docstring()
|
|
306
|
+
task_runner_fn.__signature__ = self._create_fn_signature()
|
|
307
|
+
task_runner_fn.__name__ = self.name
|
|
308
|
+
return task_runner_fn
|
|
309
|
+
|
|
310
|
+
def _get_func_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
311
|
+
fn_kwargs = {}
|
|
312
|
+
for inp in self.inputs:
|
|
313
|
+
snake_input_name = to_snake_case(inp.name)
|
|
314
|
+
if snake_input_name in kwargs:
|
|
315
|
+
fn_kwargs[inp.name] = kwargs[snake_input_name]
|
|
316
|
+
return fn_kwargs
|
|
317
|
+
|
|
318
|
+
def _create_fn_docstring(self) -> str:
|
|
319
|
+
from zrb.context.shared_context import SharedContext
|
|
320
|
+
|
|
321
|
+
stub_shared_ctx = SharedContext()
|
|
322
|
+
str_input_default_values = {}
|
|
323
|
+
for inp in self.inputs:
|
|
324
|
+
str_input_default_values[inp.name] = inp.get_default_str(stub_shared_ctx)
|
|
325
|
+
# Create docstring
|
|
326
|
+
doc = f"{self.description}\n\n"
|
|
327
|
+
if len(self.inputs) > 0:
|
|
328
|
+
doc += "Args:\n"
|
|
329
|
+
for inp in self.inputs:
|
|
330
|
+
str_input_default = str_input_default_values.get(inp.name, "")
|
|
331
|
+
doc += (
|
|
332
|
+
f" {inp.name}: {inp.description} (default: {str_input_default})"
|
|
333
|
+
)
|
|
334
|
+
doc += "\n"
|
|
335
|
+
return doc
|
|
336
|
+
|
|
337
|
+
def _create_fn_signature(self) -> inspect.Signature:
|
|
338
|
+
params = []
|
|
339
|
+
for inp in self.inputs:
|
|
340
|
+
params.append(
|
|
341
|
+
inspect.Parameter(
|
|
342
|
+
name=to_snake_case(inp.name),
|
|
343
|
+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
344
|
+
)
|
|
345
|
+
)
|
|
346
|
+
return inspect.Signature(params)
|
zrb/task/base_trigger.py
CHANGED
|
@@ -36,7 +36,7 @@ class BaseTrigger(BaseTask):
|
|
|
36
36
|
input: list[AnyInput | None] | AnyInput | None = None,
|
|
37
37
|
env: list[AnyEnv | None] | AnyEnv | None = None,
|
|
38
38
|
action: fstring | Callable[[AnyContext], Any] | None = None,
|
|
39
|
-
execute_condition: bool | str | Callable[[
|
|
39
|
+
execute_condition: bool | str | Callable[[AnyContext], bool] = True,
|
|
40
40
|
queue_name: fstring | None = None,
|
|
41
41
|
callback: list[AnyCallback] | AnyCallback = [],
|
|
42
42
|
retries: int = 2,
|
|
@@ -127,7 +127,7 @@ class BaseTrigger(BaseTask):
|
|
|
127
127
|
return self._callbacks
|
|
128
128
|
|
|
129
129
|
async def exec_root_tasks(self, session: AnySession):
|
|
130
|
-
exchange_xcom = self.
|
|
130
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
131
131
|
exchange_xcom.add_push_callback(lambda: self._exchange_push_callback(session))
|
|
132
132
|
return await super().exec_root_tasks(session)
|
|
133
133
|
|
|
@@ -136,8 +136,7 @@ class BaseTrigger(BaseTask):
|
|
|
136
136
|
session.defer_coro(coro)
|
|
137
137
|
|
|
138
138
|
async def _fanout_and_trigger_callback(self, session: AnySession):
|
|
139
|
-
|
|
140
|
-
data = exchange_xcom.pop()
|
|
139
|
+
data = self.pop_exchange_xcom(session)
|
|
141
140
|
coros = []
|
|
142
141
|
for callback in self.callbacks:
|
|
143
142
|
xcom_dict = DotDict({self.queue_name: Xcom([data])})
|
|
@@ -156,8 +155,16 @@ class BaseTrigger(BaseTask):
|
|
|
156
155
|
)
|
|
157
156
|
await asyncio.gather(*coros)
|
|
158
157
|
|
|
159
|
-
def
|
|
158
|
+
def _get_exchange_xcom(self, session: AnySession) -> Xcom:
|
|
160
159
|
shared_ctx = session.shared_ctx
|
|
161
160
|
if self.queue_name not in shared_ctx.xcom:
|
|
162
161
|
shared_ctx.xcom[self.queue_name] = Xcom()
|
|
163
162
|
return shared_ctx.xcom[self.queue_name]
|
|
163
|
+
|
|
164
|
+
def push_exchange_xcom(self, session: AnySession, data: Any):
|
|
165
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
166
|
+
exchange_xcom.push(data)
|
|
167
|
+
|
|
168
|
+
def pop_exchange_xcom(self, session: AnySession) -> Any:
|
|
169
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
170
|
+
return exchange_xcom.pop()
|
zrb/task/llm/agent.py
CHANGED
|
@@ -1,20 +1,18 @@
|
|
|
1
|
-
import
|
|
1
|
+
import inspect
|
|
2
2
|
from collections.abc import Callable
|
|
3
|
+
from dataclasses import dataclass
|
|
3
4
|
from typing import TYPE_CHECKING, Any
|
|
4
5
|
|
|
5
|
-
from zrb.config.llm_rate_limitter import
|
|
6
|
+
from zrb.config.llm_rate_limitter import LLMRateLimitter
|
|
6
7
|
from zrb.context.any_context import AnyContext
|
|
7
|
-
from zrb.
|
|
8
|
-
from zrb.task.llm.error import extract_api_error_details
|
|
9
|
-
from zrb.task.llm.print_node import print_node
|
|
8
|
+
from zrb.task.llm.history_processor import create_summarize_history_processor
|
|
10
9
|
from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
|
|
11
|
-
from zrb.task.llm.typing import ListOfDict
|
|
12
10
|
|
|
13
11
|
if TYPE_CHECKING:
|
|
14
12
|
from pydantic_ai import Agent, Tool
|
|
15
|
-
from pydantic_ai.
|
|
16
|
-
from pydantic_ai.messages import UserContent
|
|
13
|
+
from pydantic_ai._agent_graph import HistoryProcessor
|
|
17
14
|
from pydantic_ai.models import Model
|
|
15
|
+
from pydantic_ai.output import OutputDataT, OutputSpec
|
|
18
16
|
from pydantic_ai.settings import ModelSettings
|
|
19
17
|
from pydantic_ai.toolsets import AbstractToolset
|
|
20
18
|
|
|
@@ -24,19 +22,59 @@ if TYPE_CHECKING:
|
|
|
24
22
|
def create_agent_instance(
|
|
25
23
|
ctx: AnyContext,
|
|
26
24
|
model: "str | Model",
|
|
25
|
+
rate_limitter: LLMRateLimitter | None = None,
|
|
26
|
+
output_type: "OutputSpec[OutputDataT]" = str,
|
|
27
27
|
system_prompt: str = "",
|
|
28
28
|
model_settings: "ModelSettings | None" = None,
|
|
29
|
-
tools:
|
|
30
|
-
toolsets: list["AbstractToolset[
|
|
29
|
+
tools: list["ToolOrCallable"] = [],
|
|
30
|
+
toolsets: list["AbstractToolset[None]"] = [],
|
|
31
31
|
retries: int = 3,
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
yolo_mode: bool | list[str] | None = None,
|
|
33
|
+
summarization_model: "Model | str | None" = None,
|
|
34
|
+
summarization_model_settings: "ModelSettings | None" = None,
|
|
35
|
+
summarization_system_prompt: str | None = None,
|
|
36
|
+
summarization_retries: int = 2,
|
|
37
|
+
summarization_token_threshold: int | None = None,
|
|
38
|
+
history_processors: list["HistoryProcessor"] | None = None,
|
|
39
|
+
auto_summarize: bool = True,
|
|
40
|
+
) -> "Agent[None, Any]":
|
|
34
41
|
"""Creates a new Agent instance with configured tools and servers."""
|
|
35
|
-
from pydantic_ai import Agent, Tool
|
|
42
|
+
from pydantic_ai import Agent, RunContext, Tool
|
|
36
43
|
from pydantic_ai.tools import GenerateToolJsonSchema
|
|
44
|
+
from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
|
|
45
|
+
|
|
46
|
+
@dataclass
|
|
47
|
+
class ConfirmationWrapperToolset(WrapperToolset):
|
|
48
|
+
ctx: AnyContext
|
|
49
|
+
yolo_mode: bool | list[str]
|
|
50
|
+
|
|
51
|
+
async def call_tool(
|
|
52
|
+
self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
|
|
53
|
+
) -> Any:
|
|
54
|
+
# The `tool` object is passed in. Use it for inspection.
|
|
55
|
+
# Define a temporary function that performs the actual tool call.
|
|
56
|
+
async def execute_delegated_tool_call(**params):
|
|
57
|
+
# Pass all arguments down the chain.
|
|
58
|
+
return await self.wrapped.call_tool(name, tool_args, ctx, tool)
|
|
59
|
+
|
|
60
|
+
# For the confirmation UI, make our temporary function look like the real one.
|
|
61
|
+
try:
|
|
62
|
+
execute_delegated_tool_call.__name__ = name
|
|
63
|
+
execute_delegated_tool_call.__doc__ = tool.function.__doc__
|
|
64
|
+
execute_delegated_tool_call.__signature__ = inspect.signature(
|
|
65
|
+
tool.function
|
|
66
|
+
)
|
|
67
|
+
except (AttributeError, TypeError):
|
|
68
|
+
pass # Ignore if we can't inspect the original function
|
|
69
|
+
# Use the existing wrap_func to get the confirmation logic
|
|
70
|
+
wrapped_executor = wrap_func(
|
|
71
|
+
execute_delegated_tool_call, self.ctx, self.yolo_mode
|
|
72
|
+
)
|
|
73
|
+
# Call the wrapped executor. This will trigger the confirmation prompt.
|
|
74
|
+
return await wrapped_executor(**tool_args)
|
|
37
75
|
|
|
38
|
-
if
|
|
39
|
-
|
|
76
|
+
if yolo_mode is None:
|
|
77
|
+
yolo_mode = False
|
|
40
78
|
# Normalize tools
|
|
41
79
|
tool_list = []
|
|
42
80
|
for tool_or_callable in tools:
|
|
@@ -46,7 +84,7 @@ def create_agent_instance(
|
|
|
46
84
|
tool = tool_or_callable
|
|
47
85
|
tool_list.append(
|
|
48
86
|
Tool(
|
|
49
|
-
function=wrap_func(tool.function, ctx,
|
|
87
|
+
function=wrap_func(tool.function, ctx, yolo_mode),
|
|
50
88
|
takes_ctx=tool.takes_ctx,
|
|
51
89
|
max_retries=tool.max_retries,
|
|
52
90
|
name=tool.name,
|
|
@@ -60,184 +98,107 @@ def create_agent_instance(
|
|
|
60
98
|
)
|
|
61
99
|
else:
|
|
62
100
|
# Turn function into tool
|
|
63
|
-
tool_list.append(wrap_tool(tool_or_callable, ctx,
|
|
101
|
+
tool_list.append(wrap_tool(tool_or_callable, ctx, yolo_mode))
|
|
102
|
+
# Wrap toolsets
|
|
103
|
+
wrapped_toolsets = [
|
|
104
|
+
ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
|
|
105
|
+
for toolset in toolsets
|
|
106
|
+
]
|
|
107
|
+
# Create History processor with summarizer
|
|
108
|
+
history_processors = [] if history_processors is None else history_processors
|
|
109
|
+
if auto_summarize:
|
|
110
|
+
history_processors += [
|
|
111
|
+
create_summarize_history_processor(
|
|
112
|
+
ctx=ctx,
|
|
113
|
+
system_prompt=system_prompt,
|
|
114
|
+
rate_limitter=rate_limitter,
|
|
115
|
+
summarization_model=summarization_model,
|
|
116
|
+
summarization_model_settings=summarization_model_settings,
|
|
117
|
+
summarization_system_prompt=summarization_system_prompt,
|
|
118
|
+
summarization_token_threshold=summarization_token_threshold,
|
|
119
|
+
summarization_retries=summarization_retries,
|
|
120
|
+
)
|
|
121
|
+
]
|
|
64
122
|
# Return Agent
|
|
65
|
-
return Agent(
|
|
123
|
+
return Agent[None, Any](
|
|
66
124
|
model=model,
|
|
67
|
-
|
|
125
|
+
output_type=output_type,
|
|
126
|
+
instructions=system_prompt,
|
|
68
127
|
tools=tool_list,
|
|
69
|
-
toolsets=
|
|
128
|
+
toolsets=wrapped_toolsets,
|
|
70
129
|
model_settings=model_settings,
|
|
71
130
|
retries=retries,
|
|
131
|
+
history_processors=history_processors,
|
|
72
132
|
)
|
|
73
133
|
|
|
74
134
|
|
|
75
135
|
def get_agent(
|
|
76
136
|
ctx: AnyContext,
|
|
77
|
-
agent_attr: "Agent | Callable[[AnySharedContext], Agent] | None",
|
|
78
137
|
model: "str | Model",
|
|
79
|
-
|
|
80
|
-
|
|
138
|
+
rate_limitter: LLMRateLimitter | None = None,
|
|
139
|
+
output_type: "OutputSpec[OutputDataT]" = str,
|
|
140
|
+
system_prompt: str = "",
|
|
141
|
+
model_settings: "ModelSettings | None" = None,
|
|
81
142
|
tools_attr: (
|
|
82
|
-
"list[ToolOrCallable] | Callable[[
|
|
83
|
-
),
|
|
84
|
-
additional_tools: "list[ToolOrCallable]",
|
|
85
|
-
toolsets_attr: "list[AbstractToolset[
|
|
86
|
-
additional_toolsets: "list[AbstractToolset[
|
|
143
|
+
"list[ToolOrCallable] | Callable[[AnyContext], list[ToolOrCallable]]"
|
|
144
|
+
) = [],
|
|
145
|
+
additional_tools: "list[ToolOrCallable]" = [],
|
|
146
|
+
toolsets_attr: "list[AbstractToolset[None] | str] | Callable[[AnyContext], list[AbstractToolset[None] | str]]" = [], # noqa
|
|
147
|
+
additional_toolsets: "list[AbstractToolset[None] | str]" = [],
|
|
87
148
|
retries: int = 3,
|
|
88
|
-
|
|
149
|
+
yolo_mode: bool | list[str] | None = None,
|
|
150
|
+
summarization_model: "Model | str | None" = None,
|
|
151
|
+
summarization_model_settings: "ModelSettings | None" = None,
|
|
152
|
+
summarization_system_prompt: str | None = None,
|
|
153
|
+
summarization_retries: int = 2,
|
|
154
|
+
summarization_token_threshold: int | None = None,
|
|
155
|
+
history_processors: list["HistoryProcessor"] | None = None,
|
|
89
156
|
) -> "Agent":
|
|
90
157
|
"""Retrieves the configured Agent instance or creates one if necessary."""
|
|
91
|
-
from pydantic_ai import Agent
|
|
92
|
-
|
|
93
|
-
# Render agent instance and return if agent_attr is already an agent
|
|
94
|
-
if isinstance(agent_attr, Agent):
|
|
95
|
-
return agent_attr
|
|
96
|
-
if callable(agent_attr):
|
|
97
|
-
agent_instance = agent_attr(ctx)
|
|
98
|
-
if not isinstance(agent_instance, Agent):
|
|
99
|
-
err_msg = (
|
|
100
|
-
"Callable agent factory did not return an Agent instance, "
|
|
101
|
-
f"got: {type(agent_instance)}"
|
|
102
|
-
)
|
|
103
|
-
raise TypeError(err_msg)
|
|
104
|
-
return agent_instance
|
|
105
158
|
# Get tools for agent
|
|
106
159
|
tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
|
|
107
160
|
tools.extend(additional_tools)
|
|
108
161
|
# Get Toolsets for agent
|
|
109
|
-
|
|
110
|
-
|
|
162
|
+
toolset_or_str_list = list(
|
|
163
|
+
toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr
|
|
164
|
+
)
|
|
165
|
+
toolset_or_str_list.extend(additional_toolsets)
|
|
166
|
+
toolsets = _render_toolset_or_str_list(ctx, toolset_or_str_list)
|
|
111
167
|
# If no agent provided, create one using the configuration
|
|
112
168
|
return create_agent_instance(
|
|
113
169
|
ctx=ctx,
|
|
114
170
|
model=model,
|
|
171
|
+
rate_limitter=rate_limitter,
|
|
172
|
+
output_type=output_type,
|
|
115
173
|
system_prompt=system_prompt,
|
|
116
174
|
tools=tools,
|
|
117
|
-
toolsets=
|
|
175
|
+
toolsets=toolsets,
|
|
118
176
|
model_settings=model_settings,
|
|
119
177
|
retries=retries,
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
user_prompt: str,
|
|
128
|
-
attachments: "list[UserContent] | None" = None,
|
|
129
|
-
history_list: ListOfDict | None = None,
|
|
130
|
-
rate_limitter: LLMRateLimiter | None = None,
|
|
131
|
-
max_retry: int = 2,
|
|
132
|
-
) -> "AgentRun":
|
|
133
|
-
"""
|
|
134
|
-
Runs a single iteration of the agent execution loop.
|
|
135
|
-
|
|
136
|
-
Args:
|
|
137
|
-
ctx: The task context.
|
|
138
|
-
agent: The Pydantic AI agent instance.
|
|
139
|
-
user_prompt: The user's input prompt.
|
|
140
|
-
history_list: The current conversation history.
|
|
141
|
-
|
|
142
|
-
Returns:
|
|
143
|
-
The agent run result object.
|
|
144
|
-
|
|
145
|
-
Raises:
|
|
146
|
-
Exception: If any error occurs during agent execution.
|
|
147
|
-
"""
|
|
148
|
-
if max_retry < 0:
|
|
149
|
-
raise ValueError("Max retry cannot be less than 0")
|
|
150
|
-
attempt = 0
|
|
151
|
-
while attempt < max_retry:
|
|
152
|
-
try:
|
|
153
|
-
return await _run_single_agent_iteration(
|
|
154
|
-
ctx=ctx,
|
|
155
|
-
agent=agent,
|
|
156
|
-
user_prompt=user_prompt,
|
|
157
|
-
attachments=[] if attachments is None else attachments,
|
|
158
|
-
history_list=[] if history_list is None else history_list,
|
|
159
|
-
rate_limitter=(
|
|
160
|
-
llm_rate_limitter if rate_limitter is None else rate_limitter
|
|
161
|
-
),
|
|
162
|
-
)
|
|
163
|
-
except BaseException:
|
|
164
|
-
attempt += 1
|
|
165
|
-
if attempt == max_retry:
|
|
166
|
-
raise
|
|
167
|
-
raise Exception("Max retry exceeded")
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
async def _run_single_agent_iteration(
|
|
171
|
-
ctx: AnyContext,
|
|
172
|
-
agent: "Agent",
|
|
173
|
-
user_prompt: str,
|
|
174
|
-
attachments: "list[UserContent]",
|
|
175
|
-
history_list: ListOfDict,
|
|
176
|
-
rate_limitter: LLMRateLimiter,
|
|
177
|
-
) -> "AgentRun":
|
|
178
|
-
from openai import APIError
|
|
179
|
-
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
|
180
|
-
|
|
181
|
-
agent_payload = _estimate_request_payload(
|
|
182
|
-
agent, user_prompt, attachments, history_list
|
|
178
|
+
yolo_mode=yolo_mode,
|
|
179
|
+
summarization_model=summarization_model,
|
|
180
|
+
summarization_model_settings=summarization_model_settings,
|
|
181
|
+
summarization_system_prompt=summarization_system_prompt,
|
|
182
|
+
summarization_retries=summarization_retries,
|
|
183
|
+
summarization_token_threshold=summarization_token_threshold,
|
|
184
|
+
history_processors=history_processors,
|
|
183
185
|
)
|
|
184
|
-
if rate_limitter:
|
|
185
|
-
await rate_limitter.throttle(agent_payload)
|
|
186
|
-
else:
|
|
187
|
-
await llm_rate_limitter.throttle(agent_payload)
|
|
188
|
-
|
|
189
|
-
user_prompt_with_attachments = [user_prompt] + attachments
|
|
190
|
-
async with agent:
|
|
191
|
-
async with agent.iter(
|
|
192
|
-
user_prompt=user_prompt_with_attachments,
|
|
193
|
-
message_history=ModelMessagesTypeAdapter.validate_python(history_list),
|
|
194
|
-
) as agent_run:
|
|
195
|
-
async for node in agent_run:
|
|
196
|
-
# Each node represents a step in the agent's execution
|
|
197
|
-
# Reference: https://ai.pydantic.dev/agents/#streaming
|
|
198
|
-
try:
|
|
199
|
-
await print_node(_get_plain_printer(ctx), agent_run, node)
|
|
200
|
-
except APIError as e:
|
|
201
|
-
# Extract detailed error information from the response
|
|
202
|
-
error_details = extract_api_error_details(e)
|
|
203
|
-
ctx.log_error(f"API Error: {error_details}")
|
|
204
|
-
raise
|
|
205
|
-
except Exception as e:
|
|
206
|
-
ctx.log_error(f"Error processing node: {str(e)}")
|
|
207
|
-
ctx.log_error(f"Error type: {type(e).__name__}")
|
|
208
|
-
raise
|
|
209
|
-
return agent_run
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
def _estimate_request_payload(
|
|
213
|
-
agent: "Agent",
|
|
214
|
-
user_prompt: str,
|
|
215
|
-
attachments: "list[UserContent]",
|
|
216
|
-
history_list: ListOfDict,
|
|
217
|
-
) -> str:
|
|
218
|
-
system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
|
|
219
|
-
return json.dumps(
|
|
220
|
-
[
|
|
221
|
-
{"role": "system", "content": "\n".join(system_prompts)},
|
|
222
|
-
*history_list,
|
|
223
|
-
{"role": "user", "content": user_prompt},
|
|
224
|
-
*[_estimate_attachment_payload(attachment) for attachment in attachments],
|
|
225
|
-
]
|
|
226
|
-
)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def _estimate_attachment_payload(attachment: "UserContent") -> Any:
|
|
230
|
-
if hasattr(attachment, "url"):
|
|
231
|
-
return {"role": "user", "content": attachment.url}
|
|
232
|
-
if hasattr(attachment, "data"):
|
|
233
|
-
return {"role": "user", "content": "x" * len(attachment.data)}
|
|
234
|
-
return ""
|
|
235
|
-
|
|
236
186
|
|
|
237
|
-
def _get_plain_printer(ctx: AnyContext):
|
|
238
|
-
def printer(*args, **kwargs):
|
|
239
|
-
if "plain" not in kwargs:
|
|
240
|
-
kwargs["plain"] = True
|
|
241
|
-
return ctx.print(*args, **kwargs)
|
|
242
187
|
|
|
243
|
-
|
|
188
|
+
def _render_toolset_or_str_list(
|
|
189
|
+
ctx: AnyContext, toolset_or_str_list: list["AbstractToolset[None] | str"]
|
|
190
|
+
) -> list["AbstractToolset[None]"]:
|
|
191
|
+
from pydantic_ai.mcp import load_mcp_servers
|
|
192
|
+
|
|
193
|
+
toolsets = []
|
|
194
|
+
for toolset_or_str in toolset_or_str_list:
|
|
195
|
+
if isinstance(toolset_or_str, str):
|
|
196
|
+
try:
|
|
197
|
+
servers = load_mcp_servers(toolset_or_str)
|
|
198
|
+
for server in servers:
|
|
199
|
+
toolsets.append(server)
|
|
200
|
+
except Exception as e:
|
|
201
|
+
ctx.log_error(f"Invalid MCP Config {toolset_or_str}: {e}")
|
|
202
|
+
continue
|
|
203
|
+
toolsets.append(toolset_or_str)
|
|
204
|
+
return toolsets
|