zrb 1.13.1__py3-none-any.whl → 1.21.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__init__.py +2 -6
- zrb/attr/type.py +8 -8
- zrb/builtin/__init__.py +2 -0
- zrb/builtin/group.py +31 -15
- zrb/builtin/http.py +7 -8
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_session.py +130 -144
- zrb/builtin/llm/chat_session_cmd.py +226 -0
- zrb/builtin/llm/chat_trigger.py +73 -0
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/llm_ask.py +218 -110
- zrb/builtin/llm/tool/api.py +74 -62
- zrb/builtin/llm/tool/cli.py +35 -16
- zrb/builtin/llm/tool/code.py +49 -47
- zrb/builtin/llm/tool/file.py +262 -251
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +25 -18
- zrb/builtin/llm/tool/sub_agent.py +29 -22
- zrb/builtin/llm/tool/web.py +135 -143
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/setup/latex/ubuntu.py +1 -0
- zrb/builtin/setup/ubuntu.py +1 -1
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/config/config.py +255 -78
- zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
- zrb/config/default_prompt/interactive_system_prompt.md +24 -30
- zrb/config/default_prompt/persona.md +1 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
- zrb/config/default_prompt/summarization_prompt.md +8 -13
- zrb/config/default_prompt/system_prompt.md +36 -30
- zrb/config/llm_config.py +129 -24
- zrb/config/llm_context/config.py +127 -90
- zrb/config/llm_context/config_parser.py +1 -7
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +89 -45
- zrb/context/any_shared_context.py +7 -1
- zrb/context/context.py +8 -2
- zrb/context/shared_context.py +6 -8
- zrb/group/any_group.py +12 -5
- zrb/group/group.py +67 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/text_input.py +7 -24
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_route/task_session_api_route.py +1 -4
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -6
- zrb/session/session.py +39 -18
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/llm/agent.py +138 -52
- zrb/task/llm/config.py +45 -13
- zrb/task/llm/conversation_history.py +76 -6
- zrb/task/llm/conversation_history_model.py +0 -168
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_summarization.py +22 -35
- zrb/task/llm/history_summarization_tool.py +24 -0
- zrb/task/llm/print_node.py +182 -63
- zrb/task/llm/prompt.py +213 -153
- zrb/task/llm/tool_wrapper.py +210 -53
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +98 -47
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +25 -10
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +50 -40
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +27 -11
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/METADATA +40 -20
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/RECORD +102 -79
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/WHEEL +1 -1
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/entry_points.txt +0 -0
zrb/task/base_task.py
CHANGED
|
@@ -21,6 +21,7 @@ from zrb.task.base.execution import (
|
|
|
21
21
|
)
|
|
22
22
|
from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
|
|
23
23
|
from zrb.task.base.operators import handle_lshift, handle_rshift
|
|
24
|
+
from zrb.util.string.conversion import to_snake_case
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class BaseTask(AnyTask):
|
|
@@ -216,7 +217,10 @@ class BaseTask(AnyTask):
|
|
|
216
217
|
return build_task_context(self, session)
|
|
217
218
|
|
|
218
219
|
def run(
|
|
219
|
-
self,
|
|
220
|
+
self,
|
|
221
|
+
session: AnySession | None = None,
|
|
222
|
+
str_kwargs: dict[str, str] | None = None,
|
|
223
|
+
kwargs: dict[str, Any] | None = None,
|
|
220
224
|
) -> Any:
|
|
221
225
|
"""
|
|
222
226
|
Synchronously runs the task and its dependencies, handling async setup and cleanup.
|
|
@@ -235,12 +239,19 @@ class BaseTask(AnyTask):
|
|
|
235
239
|
Any: The final result of the main task execution.
|
|
236
240
|
"""
|
|
237
241
|
# Use asyncio.run() to execute the async cleanup wrapper
|
|
238
|
-
return asyncio.run(
|
|
242
|
+
return asyncio.run(
|
|
243
|
+
run_and_cleanup(self, session=session, str_kwargs=str_kwargs, kwargs=kwargs)
|
|
244
|
+
)
|
|
239
245
|
|
|
240
246
|
async def async_run(
|
|
241
|
-
self,
|
|
247
|
+
self,
|
|
248
|
+
session: AnySession | None = None,
|
|
249
|
+
str_kwargs: dict[str, str] | None = None,
|
|
250
|
+
kwargs: dict[str, Any] | None = None,
|
|
242
251
|
) -> Any:
|
|
243
|
-
return await run_task_async(
|
|
252
|
+
return await run_task_async(
|
|
253
|
+
self, session=session, str_kwargs=str_kwargs, kwargs=kwargs
|
|
254
|
+
)
|
|
244
255
|
|
|
245
256
|
async def exec_root_tasks(self, session: AnySession):
|
|
246
257
|
return await execute_root_tasks(self, session)
|
|
@@ -276,7 +287,60 @@ class BaseTask(AnyTask):
|
|
|
276
287
|
# Add definition location to the error
|
|
277
288
|
if hasattr(e, "add_note"):
|
|
278
289
|
e.add_note(additional_error_note)
|
|
279
|
-
|
|
290
|
+
elif hasattr(e, "__notes__"):
|
|
280
291
|
# fallback: use the __notes__ attribute directly
|
|
281
292
|
e.__notes__ = getattr(e, "__notes__", []) + [additional_error_note]
|
|
282
293
|
raise e
|
|
294
|
+
|
|
295
|
+
def to_function(self) -> Callable[..., Any]:
|
|
296
|
+
from zrb.context.shared_context import SharedContext
|
|
297
|
+
from zrb.session.session import Session
|
|
298
|
+
|
|
299
|
+
def task_runner_fn(**kwargs) -> Any:
|
|
300
|
+
task_kwargs = self._get_func_kwargs(kwargs)
|
|
301
|
+
shared_ctx = SharedContext()
|
|
302
|
+
session = Session(shared_ctx=shared_ctx)
|
|
303
|
+
return self.run(session=session, kwargs=task_kwargs)
|
|
304
|
+
|
|
305
|
+
task_runner_fn.__doc__ = self._create_fn_docstring()
|
|
306
|
+
task_runner_fn.__signature__ = self._create_fn_signature()
|
|
307
|
+
task_runner_fn.__name__ = self.name
|
|
308
|
+
return task_runner_fn
|
|
309
|
+
|
|
310
|
+
def _get_func_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
311
|
+
fn_kwargs = {}
|
|
312
|
+
for inp in self.inputs:
|
|
313
|
+
snake_input_name = to_snake_case(inp.name)
|
|
314
|
+
if snake_input_name in kwargs:
|
|
315
|
+
fn_kwargs[inp.name] = kwargs[snake_input_name]
|
|
316
|
+
return fn_kwargs
|
|
317
|
+
|
|
318
|
+
def _create_fn_docstring(self) -> str:
|
|
319
|
+
from zrb.context.shared_context import SharedContext
|
|
320
|
+
|
|
321
|
+
stub_shared_ctx = SharedContext()
|
|
322
|
+
str_input_default_values = {}
|
|
323
|
+
for inp in self.inputs:
|
|
324
|
+
str_input_default_values[inp.name] = inp.get_default_str(stub_shared_ctx)
|
|
325
|
+
# Create docstring
|
|
326
|
+
doc = f"{self.description}\n\n"
|
|
327
|
+
if len(self.inputs) > 0:
|
|
328
|
+
doc += "Args:\n"
|
|
329
|
+
for inp in self.inputs:
|
|
330
|
+
str_input_default = str_input_default_values.get(inp.name, "")
|
|
331
|
+
doc += (
|
|
332
|
+
f" {inp.name}: {inp.description} (default: {str_input_default})"
|
|
333
|
+
)
|
|
334
|
+
doc += "\n"
|
|
335
|
+
return doc
|
|
336
|
+
|
|
337
|
+
def _create_fn_signature(self) -> inspect.Signature:
|
|
338
|
+
params = []
|
|
339
|
+
for inp in self.inputs:
|
|
340
|
+
params.append(
|
|
341
|
+
inspect.Parameter(
|
|
342
|
+
name=to_snake_case(inp.name),
|
|
343
|
+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
344
|
+
)
|
|
345
|
+
)
|
|
346
|
+
return inspect.Signature(params)
|
zrb/task/base_trigger.py
CHANGED
|
@@ -36,7 +36,7 @@ class BaseTrigger(BaseTask):
|
|
|
36
36
|
input: list[AnyInput | None] | AnyInput | None = None,
|
|
37
37
|
env: list[AnyEnv | None] | AnyEnv | None = None,
|
|
38
38
|
action: fstring | Callable[[AnyContext], Any] | None = None,
|
|
39
|
-
execute_condition: bool | str | Callable[[
|
|
39
|
+
execute_condition: bool | str | Callable[[AnyContext], bool] = True,
|
|
40
40
|
queue_name: fstring | None = None,
|
|
41
41
|
callback: list[AnyCallback] | AnyCallback = [],
|
|
42
42
|
retries: int = 2,
|
|
@@ -127,7 +127,7 @@ class BaseTrigger(BaseTask):
|
|
|
127
127
|
return self._callbacks
|
|
128
128
|
|
|
129
129
|
async def exec_root_tasks(self, session: AnySession):
|
|
130
|
-
exchange_xcom = self.
|
|
130
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
131
131
|
exchange_xcom.add_push_callback(lambda: self._exchange_push_callback(session))
|
|
132
132
|
return await super().exec_root_tasks(session)
|
|
133
133
|
|
|
@@ -136,8 +136,7 @@ class BaseTrigger(BaseTask):
|
|
|
136
136
|
session.defer_coro(coro)
|
|
137
137
|
|
|
138
138
|
async def _fanout_and_trigger_callback(self, session: AnySession):
|
|
139
|
-
|
|
140
|
-
data = exchange_xcom.pop()
|
|
139
|
+
data = self.pop_exchange_xcom(session)
|
|
141
140
|
coros = []
|
|
142
141
|
for callback in self.callbacks:
|
|
143
142
|
xcom_dict = DotDict({self.queue_name: Xcom([data])})
|
|
@@ -156,8 +155,16 @@ class BaseTrigger(BaseTask):
|
|
|
156
155
|
)
|
|
157
156
|
await asyncio.gather(*coros)
|
|
158
157
|
|
|
159
|
-
def
|
|
158
|
+
def _get_exchange_xcom(self, session: AnySession) -> Xcom:
|
|
160
159
|
shared_ctx = session.shared_ctx
|
|
161
160
|
if self.queue_name not in shared_ctx.xcom:
|
|
162
161
|
shared_ctx.xcom[self.queue_name] = Xcom()
|
|
163
162
|
return shared_ctx.xcom[self.queue_name]
|
|
163
|
+
|
|
164
|
+
def push_exchange_xcom(self, session: AnySession, data: Any):
|
|
165
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
166
|
+
exchange_xcom.push(data)
|
|
167
|
+
|
|
168
|
+
def pop_exchange_xcom(self, session: AnySession) -> Any:
|
|
169
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
170
|
+
return exchange_xcom.pop()
|
zrb/task/llm/agent.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
import json
|
|
2
3
|
from collections.abc import Callable
|
|
4
|
+
from dataclasses import dataclass
|
|
3
5
|
from typing import TYPE_CHECKING, Any
|
|
4
6
|
|
|
5
7
|
from zrb.config.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
|
|
@@ -9,32 +11,68 @@ from zrb.task.llm.error import extract_api_error_details
|
|
|
9
11
|
from zrb.task.llm.print_node import print_node
|
|
10
12
|
from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
|
|
11
13
|
from zrb.task.llm.typing import ListOfDict
|
|
14
|
+
from zrb.util.cli.style import stylize_faint
|
|
12
15
|
|
|
13
16
|
if TYPE_CHECKING:
|
|
14
17
|
from pydantic_ai import Agent, Tool
|
|
15
18
|
from pydantic_ai.agent import AgentRun
|
|
19
|
+
from pydantic_ai.messages import UserContent
|
|
16
20
|
from pydantic_ai.models import Model
|
|
21
|
+
from pydantic_ai.output import OutputDataT, OutputSpec
|
|
17
22
|
from pydantic_ai.settings import ModelSettings
|
|
18
23
|
from pydantic_ai.toolsets import AbstractToolset
|
|
19
24
|
|
|
20
25
|
ToolOrCallable = Tool | Callable
|
|
21
|
-
else:
|
|
22
|
-
ToolOrCallable = Any
|
|
23
26
|
|
|
24
27
|
|
|
25
28
|
def create_agent_instance(
|
|
26
29
|
ctx: AnyContext,
|
|
27
|
-
model: "str | Model
|
|
30
|
+
model: "str | Model",
|
|
31
|
+
output_type: "OutputSpec[OutputDataT]" = str,
|
|
28
32
|
system_prompt: str = "",
|
|
29
33
|
model_settings: "ModelSettings | None" = None,
|
|
30
|
-
tools: list[ToolOrCallable] = [],
|
|
31
|
-
toolsets: list["AbstractToolset[
|
|
34
|
+
tools: "list[ToolOrCallable]" = [],
|
|
35
|
+
toolsets: list["AbstractToolset[None]"] = [],
|
|
32
36
|
retries: int = 3,
|
|
33
|
-
|
|
37
|
+
yolo_mode: bool | list[str] | None = None,
|
|
38
|
+
) -> "Agent[None, Any]":
|
|
34
39
|
"""Creates a new Agent instance with configured tools and servers."""
|
|
35
|
-
from pydantic_ai import Agent, Tool
|
|
40
|
+
from pydantic_ai import Agent, RunContext, Tool
|
|
36
41
|
from pydantic_ai.tools import GenerateToolJsonSchema
|
|
42
|
+
from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class ConfirmationWrapperToolset(WrapperToolset):
|
|
46
|
+
ctx: AnyContext
|
|
47
|
+
yolo_mode: bool | list[str]
|
|
48
|
+
|
|
49
|
+
async def call_tool(
|
|
50
|
+
self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
|
|
51
|
+
) -> Any:
|
|
52
|
+
# The `tool` object is passed in. Use it for inspection.
|
|
53
|
+
# Define a temporary function that performs the actual tool call.
|
|
54
|
+
async def execute_delegated_tool_call(**params):
|
|
55
|
+
# Pass all arguments down the chain.
|
|
56
|
+
return await self.wrapped.call_tool(name, tool_args, ctx, tool)
|
|
37
57
|
|
|
58
|
+
# For the confirmation UI, make our temporary function look like the real one.
|
|
59
|
+
try:
|
|
60
|
+
execute_delegated_tool_call.__name__ = name
|
|
61
|
+
execute_delegated_tool_call.__doc__ = tool.function.__doc__
|
|
62
|
+
execute_delegated_tool_call.__signature__ = inspect.signature(
|
|
63
|
+
tool.function
|
|
64
|
+
)
|
|
65
|
+
except (AttributeError, TypeError):
|
|
66
|
+
pass # Ignore if we can't inspect the original function
|
|
67
|
+
# Use the existing wrap_func to get the confirmation logic
|
|
68
|
+
wrapped_executor = wrap_func(
|
|
69
|
+
execute_delegated_tool_call, self.ctx, self.yolo_mode
|
|
70
|
+
)
|
|
71
|
+
# Call the wrapped executor. This will trigger the confirmation prompt.
|
|
72
|
+
return await wrapped_executor(**tool_args)
|
|
73
|
+
|
|
74
|
+
if yolo_mode is None:
|
|
75
|
+
yolo_mode = False
|
|
38
76
|
# Normalize tools
|
|
39
77
|
tool_list = []
|
|
40
78
|
for tool_or_callable in tools:
|
|
@@ -44,7 +82,7 @@ def create_agent_instance(
|
|
|
44
82
|
tool = tool_or_callable
|
|
45
83
|
tool_list.append(
|
|
46
84
|
Tool(
|
|
47
|
-
function=wrap_func(tool.function),
|
|
85
|
+
function=wrap_func(tool.function, ctx, yolo_mode),
|
|
48
86
|
takes_ctx=tool.takes_ctx,
|
|
49
87
|
max_retries=tool.max_retries,
|
|
50
88
|
name=tool.name,
|
|
@@ -58,13 +96,19 @@ def create_agent_instance(
|
|
|
58
96
|
)
|
|
59
97
|
else:
|
|
60
98
|
# Turn function into tool
|
|
61
|
-
tool_list.append(wrap_tool(tool_or_callable, ctx))
|
|
99
|
+
tool_list.append(wrap_tool(tool_or_callable, ctx, yolo_mode))
|
|
100
|
+
# Wrap toolsets
|
|
101
|
+
wrapped_toolsets = [
|
|
102
|
+
ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
|
|
103
|
+
for toolset in toolsets
|
|
104
|
+
]
|
|
62
105
|
# Return Agent
|
|
63
|
-
return Agent(
|
|
106
|
+
return Agent[None, Any](
|
|
64
107
|
model=model,
|
|
65
|
-
|
|
108
|
+
output_type=output_type,
|
|
109
|
+
instructions=system_prompt,
|
|
66
110
|
tools=tool_list,
|
|
67
|
-
toolsets=
|
|
111
|
+
toolsets=wrapped_toolsets,
|
|
68
112
|
model_settings=model_settings,
|
|
69
113
|
retries=retries,
|
|
70
114
|
)
|
|
@@ -72,58 +116,71 @@ def create_agent_instance(
|
|
|
72
116
|
|
|
73
117
|
def get_agent(
|
|
74
118
|
ctx: AnyContext,
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
system_prompt: str,
|
|
78
|
-
model_settings: "ModelSettings | None",
|
|
119
|
+
model: "str | Model",
|
|
120
|
+
output_type: "OutputSpec[OutputDataT]" = str,
|
|
121
|
+
system_prompt: str = "",
|
|
122
|
+
model_settings: "ModelSettings | None" = None,
|
|
79
123
|
tools_attr: (
|
|
80
|
-
list[ToolOrCallable] | Callable[[
|
|
81
|
-
),
|
|
82
|
-
additional_tools: list[ToolOrCallable],
|
|
83
|
-
toolsets_attr: "list[AbstractToolset[
|
|
84
|
-
additional_toolsets: "list[AbstractToolset[
|
|
124
|
+
"list[ToolOrCallable] | Callable[[AnyContext], list[ToolOrCallable]]"
|
|
125
|
+
) = [],
|
|
126
|
+
additional_tools: "list[ToolOrCallable]" = [],
|
|
127
|
+
toolsets_attr: "list[AbstractToolset[None] | str] | Callable[[AnyContext], list[AbstractToolset[None] | str]]" = [], # noqa
|
|
128
|
+
additional_toolsets: "list[AbstractToolset[None] | str]" = [],
|
|
85
129
|
retries: int = 3,
|
|
130
|
+
yolo_mode: bool | list[str] | None = None,
|
|
86
131
|
) -> "Agent":
|
|
87
132
|
"""Retrieves the configured Agent instance or creates one if necessary."""
|
|
88
|
-
from pydantic_ai import Agent
|
|
89
|
-
|
|
90
|
-
# Render agent instance and return if agent_attr is already an agent
|
|
91
|
-
if isinstance(agent_attr, Agent):
|
|
92
|
-
return agent_attr
|
|
93
|
-
if callable(agent_attr):
|
|
94
|
-
agent_instance = agent_attr(ctx)
|
|
95
|
-
if not isinstance(agent_instance, Agent):
|
|
96
|
-
err_msg = (
|
|
97
|
-
"Callable agent factory did not return an Agent instance, "
|
|
98
|
-
f"got: {type(agent_instance)}"
|
|
99
|
-
)
|
|
100
|
-
raise TypeError(err_msg)
|
|
101
|
-
return agent_instance
|
|
102
133
|
# Get tools for agent
|
|
103
134
|
tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
|
|
104
135
|
tools.extend(additional_tools)
|
|
105
136
|
# Get Toolsets for agent
|
|
106
|
-
|
|
107
|
-
|
|
137
|
+
toolset_or_str_list = list(
|
|
138
|
+
toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr
|
|
139
|
+
)
|
|
140
|
+
toolset_or_str_list.extend(additional_toolsets)
|
|
141
|
+
toolsets = _render_toolset_or_str_list(ctx, toolset_or_str_list)
|
|
108
142
|
# If no agent provided, create one using the configuration
|
|
109
143
|
return create_agent_instance(
|
|
110
144
|
ctx=ctx,
|
|
111
145
|
model=model,
|
|
146
|
+
output_type=output_type,
|
|
112
147
|
system_prompt=system_prompt,
|
|
113
148
|
tools=tools,
|
|
114
|
-
toolsets=
|
|
149
|
+
toolsets=toolsets,
|
|
115
150
|
model_settings=model_settings,
|
|
116
151
|
retries=retries,
|
|
152
|
+
yolo_mode=yolo_mode,
|
|
117
153
|
)
|
|
118
154
|
|
|
119
155
|
|
|
156
|
+
def _render_toolset_or_str_list(
|
|
157
|
+
ctx: AnyContext, toolset_or_str_list: list["AbstractToolset[None] | str"]
|
|
158
|
+
) -> list["AbstractToolset[None]"]:
|
|
159
|
+
from pydantic_ai.mcp import load_mcp_servers
|
|
160
|
+
|
|
161
|
+
toolsets = []
|
|
162
|
+
for toolset_or_str in toolset_or_str_list:
|
|
163
|
+
if isinstance(toolset_or_str, str):
|
|
164
|
+
try:
|
|
165
|
+
servers = load_mcp_servers(toolset_or_str)
|
|
166
|
+
for server in servers:
|
|
167
|
+
toolsets.append(server)
|
|
168
|
+
except Exception as e:
|
|
169
|
+
ctx.log_error(f"Invalid MCP Config {toolset_or_str}: {e}")
|
|
170
|
+
continue
|
|
171
|
+
toolsets.append(toolset_or_str)
|
|
172
|
+
return toolsets
|
|
173
|
+
|
|
174
|
+
|
|
120
175
|
async def run_agent_iteration(
|
|
121
176
|
ctx: AnyContext,
|
|
122
|
-
agent: "Agent",
|
|
177
|
+
agent: "Agent[None, Any]",
|
|
123
178
|
user_prompt: str,
|
|
124
|
-
|
|
179
|
+
attachments: "list[UserContent] | None" = None,
|
|
180
|
+
history_list: ListOfDict | None = None,
|
|
125
181
|
rate_limitter: LLMRateLimiter | None = None,
|
|
126
182
|
max_retry: int = 2,
|
|
183
|
+
log_indent_level: int = 0,
|
|
127
184
|
) -> "AgentRun":
|
|
128
185
|
"""
|
|
129
186
|
Runs a single iteration of the agent execution loop.
|
|
@@ -149,8 +206,12 @@ async def run_agent_iteration(
|
|
|
149
206
|
ctx=ctx,
|
|
150
207
|
agent=agent,
|
|
151
208
|
user_prompt=user_prompt,
|
|
152
|
-
|
|
153
|
-
|
|
209
|
+
attachments=[] if attachments is None else attachments,
|
|
210
|
+
history_list=[] if history_list is None else history_list,
|
|
211
|
+
rate_limitter=(
|
|
212
|
+
llm_rate_limitter if rate_limitter is None else rate_limitter
|
|
213
|
+
),
|
|
214
|
+
log_indent_level=log_indent_level,
|
|
154
215
|
)
|
|
155
216
|
except BaseException:
|
|
156
217
|
attempt += 1
|
|
@@ -163,28 +224,34 @@ async def _run_single_agent_iteration(
|
|
|
163
224
|
ctx: AnyContext,
|
|
164
225
|
agent: "Agent",
|
|
165
226
|
user_prompt: str,
|
|
227
|
+
attachments: "list[UserContent]",
|
|
166
228
|
history_list: ListOfDict,
|
|
167
|
-
rate_limitter: LLMRateLimiter
|
|
229
|
+
rate_limitter: LLMRateLimiter,
|
|
230
|
+
log_indent_level: int,
|
|
168
231
|
) -> "AgentRun":
|
|
169
232
|
from openai import APIError
|
|
170
233
|
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
|
171
234
|
|
|
172
|
-
agent_payload =
|
|
235
|
+
agent_payload = _estimate_request_payload(
|
|
236
|
+
agent, user_prompt, attachments, history_list
|
|
237
|
+
)
|
|
238
|
+
callback = _create_print_throttle_notif(ctx)
|
|
173
239
|
if rate_limitter:
|
|
174
|
-
await rate_limitter.throttle(agent_payload)
|
|
240
|
+
await rate_limitter.throttle(agent_payload, callback)
|
|
175
241
|
else:
|
|
176
|
-
await llm_rate_limitter.throttle(agent_payload)
|
|
177
|
-
|
|
242
|
+
await llm_rate_limitter.throttle(agent_payload, callback)
|
|
243
|
+
user_prompt_with_attachments = [user_prompt] + attachments
|
|
178
244
|
async with agent:
|
|
179
245
|
async with agent.iter(
|
|
180
|
-
user_prompt=
|
|
246
|
+
user_prompt=user_prompt_with_attachments,
|
|
181
247
|
message_history=ModelMessagesTypeAdapter.validate_python(history_list),
|
|
182
248
|
) as agent_run:
|
|
183
249
|
async for node in agent_run:
|
|
184
250
|
# Each node represents a step in the agent's execution
|
|
185
|
-
# Reference: https://ai.pydantic.dev/agents/#streaming
|
|
186
251
|
try:
|
|
187
|
-
await print_node(
|
|
252
|
+
await print_node(
|
|
253
|
+
_get_plain_printer(ctx), agent_run, node, log_indent_level
|
|
254
|
+
)
|
|
188
255
|
except APIError as e:
|
|
189
256
|
# Extract detailed error information from the response
|
|
190
257
|
error_details = extract_api_error_details(e)
|
|
@@ -197,8 +264,18 @@ async def _run_single_agent_iteration(
|
|
|
197
264
|
return agent_run
|
|
198
265
|
|
|
199
266
|
|
|
200
|
-
def
|
|
201
|
-
|
|
267
|
+
def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[], None]:
|
|
268
|
+
def _print_throttle_notif():
|
|
269
|
+
ctx.print(stylize_faint(" ⌛>> Request Throttled"), plain=True)
|
|
270
|
+
|
|
271
|
+
return _print_throttle_notif
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _estimate_request_payload(
|
|
275
|
+
agent: "Agent",
|
|
276
|
+
user_prompt: str,
|
|
277
|
+
attachments: "list[UserContent]",
|
|
278
|
+
history_list: ListOfDict,
|
|
202
279
|
) -> str:
|
|
203
280
|
system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
|
|
204
281
|
return json.dumps(
|
|
@@ -206,10 +283,19 @@ def estimate_request_payload(
|
|
|
206
283
|
{"role": "system", "content": "\n".join(system_prompts)},
|
|
207
284
|
*history_list,
|
|
208
285
|
{"role": "user", "content": user_prompt},
|
|
286
|
+
*[_estimate_attachment_payload(attachment) for attachment in attachments],
|
|
209
287
|
]
|
|
210
288
|
)
|
|
211
289
|
|
|
212
290
|
|
|
291
|
+
def _estimate_attachment_payload(attachment: "UserContent") -> Any:
|
|
292
|
+
if hasattr(attachment, "url"):
|
|
293
|
+
return {"role": "user", "content": attachment.url}
|
|
294
|
+
if hasattr(attachment, "data"):
|
|
295
|
+
return {"role": "user", "content": "x" * len(attachment.data)}
|
|
296
|
+
return ""
|
|
297
|
+
|
|
298
|
+
|
|
213
299
|
def _get_plain_printer(ctx: AnyContext):
|
|
214
300
|
def printer(*args, **kwargs):
|
|
215
301
|
if "plain" not in kwargs:
|
zrb/task/llm/config.py
CHANGED
|
@@ -1,20 +1,43 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING,
|
|
1
|
+
from typing import TYPE_CHECKING, Callable
|
|
2
2
|
|
|
3
3
|
if TYPE_CHECKING:
|
|
4
4
|
from pydantic_ai.models import Model
|
|
5
5
|
from pydantic_ai.settings import ModelSettings
|
|
6
6
|
|
|
7
|
-
from zrb.attr.type import StrAttr,
|
|
7
|
+
from zrb.attr.type import BoolAttr, StrAttr, StrListAttr
|
|
8
8
|
from zrb.config.llm_config import LLMConfig, llm_config
|
|
9
9
|
from zrb.context.any_context import AnyContext
|
|
10
|
-
from zrb.
|
|
11
|
-
|
|
10
|
+
from zrb.util.attr import get_attr, get_bool_attr, get_str_list_attr
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_yolo_mode(
|
|
14
|
+
ctx: AnyContext,
|
|
15
|
+
yolo_mode_attr: (
|
|
16
|
+
Callable[[AnyContext], list[str] | bool | None] | StrListAttr | BoolAttr | None
|
|
17
|
+
) = None,
|
|
18
|
+
render_yolo_mode: bool = True,
|
|
19
|
+
) -> bool | list[str]:
|
|
20
|
+
if yolo_mode_attr is None:
|
|
21
|
+
return llm_config.default_yolo_mode
|
|
22
|
+
try:
|
|
23
|
+
return get_bool_attr(
|
|
24
|
+
ctx,
|
|
25
|
+
yolo_mode_attr,
|
|
26
|
+
False,
|
|
27
|
+
auto_render=render_yolo_mode,
|
|
28
|
+
)
|
|
29
|
+
except Exception:
|
|
30
|
+
return get_str_list_attr(
|
|
31
|
+
ctx,
|
|
32
|
+
yolo_mode_attr,
|
|
33
|
+
auto_render=render_yolo_mode,
|
|
34
|
+
)
|
|
12
35
|
|
|
13
36
|
|
|
14
37
|
def get_model_settings(
|
|
15
38
|
ctx: AnyContext,
|
|
16
39
|
model_settings_attr: (
|
|
17
|
-
"ModelSettings | Callable[[
|
|
40
|
+
"ModelSettings | Callable[[AnyContext], ModelSettings] | None"
|
|
18
41
|
) = None,
|
|
19
42
|
) -> "ModelSettings | None":
|
|
20
43
|
"""Gets the model settings, resolving callables if necessary."""
|
|
@@ -47,7 +70,7 @@ def get_model_api_key(
|
|
|
47
70
|
) -> str | None:
|
|
48
71
|
"""Gets the model API key, rendering if configured."""
|
|
49
72
|
api_key = get_attr(ctx, model_api_key_attr, None, auto_render=render_model_api_key)
|
|
50
|
-
if api_key is None and llm_config.
|
|
73
|
+
if api_key is None and llm_config.default_model_api_key is not None:
|
|
51
74
|
return llm_config.default_model_api_key
|
|
52
75
|
if isinstance(api_key, str) or api_key is None:
|
|
53
76
|
return api_key
|
|
@@ -56,18 +79,21 @@ def get_model_api_key(
|
|
|
56
79
|
|
|
57
80
|
def get_model(
|
|
58
81
|
ctx: AnyContext,
|
|
59
|
-
model_attr: "Callable[[
|
|
82
|
+
model_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None",
|
|
60
83
|
render_model: bool,
|
|
61
|
-
model_base_url_attr:
|
|
84
|
+
model_base_url_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None",
|
|
62
85
|
render_model_base_url: bool = True,
|
|
63
|
-
model_api_key_attr:
|
|
86
|
+
model_api_key_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None" = None,
|
|
64
87
|
render_model_api_key: bool = True,
|
|
65
|
-
|
|
88
|
+
is_small_model: bool = False,
|
|
89
|
+
) -> "str | Model":
|
|
66
90
|
"""Gets the model instance or name, handling defaults and configuration."""
|
|
67
91
|
from pydantic_ai.models import Model
|
|
68
92
|
|
|
69
93
|
model = get_attr(ctx, model_attr, None, auto_render=render_model)
|
|
70
94
|
if model is None:
|
|
95
|
+
if is_small_model:
|
|
96
|
+
return llm_config.default_small_model
|
|
71
97
|
return llm_config.default_model
|
|
72
98
|
if isinstance(model, str):
|
|
73
99
|
model_base_url = get_model_base_url(
|
|
@@ -76,11 +102,11 @@ def get_model(
|
|
|
76
102
|
model_api_key = get_model_api_key(ctx, model_api_key_attr, render_model_api_key)
|
|
77
103
|
new_llm_config = LLMConfig(
|
|
78
104
|
default_model_name=model,
|
|
79
|
-
|
|
80
|
-
|
|
105
|
+
default_model_base_url=model_base_url,
|
|
106
|
+
default_model_api_key=model_api_key,
|
|
81
107
|
)
|
|
82
108
|
if model_base_url is None and model_api_key is None:
|
|
83
|
-
default_model_provider =
|
|
109
|
+
default_model_provider = _get_default_model_provider(is_small_model)
|
|
84
110
|
if default_model_provider is not None:
|
|
85
111
|
new_llm_config.set_default_model_provider(default_model_provider)
|
|
86
112
|
return new_llm_config.default_model
|
|
@@ -88,3 +114,9 @@ def get_model(
|
|
|
88
114
|
if isinstance(model, Model):
|
|
89
115
|
return model
|
|
90
116
|
raise ValueError(f"Invalid model type resolved: {type(model)}, value: {model}")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _get_default_model_provider(is_small_model: bool = False):
|
|
120
|
+
if is_small_model:
|
|
121
|
+
return llm_config.default_small_model_provider
|
|
122
|
+
return llm_config.default_model_provider
|