zrb 1.8.10__py3-none-any.whl → 1.21.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic. Click here for more details.
- zrb/__init__.py +126 -113
- zrb/__main__.py +1 -1
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +2 -50
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/http.py +7 -8
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_completion.py +274 -0
- zrb/builtin/llm/chat_session.py +152 -85
- zrb/builtin/llm/chat_session_cmd.py +288 -0
- zrb/builtin/llm/chat_trigger.py +79 -0
- zrb/builtin/llm/history.py +7 -9
- zrb/builtin/llm/llm_ask.py +221 -98
- zrb/builtin/llm/tool/api.py +74 -52
- zrb/builtin/llm/tool/cli.py +46 -17
- zrb/builtin/llm/tool/code.py +71 -90
- zrb/builtin/llm/tool/file.py +301 -241
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +38 -8
- zrb/builtin/llm/tool/sub_agent.py +67 -50
- zrb/builtin/llm/tool/web.py +146 -122
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/setup/latex/ubuntu.py +1 -0
- zrb/builtin/setup/ubuntu.py +1 -1
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/builtin/todo.py +13 -2
- zrb/config/config.py +614 -0
- zrb/config/default_prompt/file_extractor_system_prompt.md +112 -0
- zrb/config/default_prompt/interactive_system_prompt.md +29 -0
- zrb/config/default_prompt/persona.md +1 -0
- zrb/config/default_prompt/repo_extractor_system_prompt.md +112 -0
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +29 -0
- zrb/config/default_prompt/summarization_prompt.md +57 -0
- zrb/config/default_prompt/system_prompt.md +38 -0
- zrb/config/llm_config.py +339 -0
- zrb/config/llm_context/config.py +166 -0
- zrb/config/llm_context/config_parser.py +40 -0
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +190 -0
- zrb/{runner → config}/web_auth_config.py +17 -22
- zrb/context/any_shared_context.py +17 -1
- zrb/context/context.py +16 -2
- zrb/context/shared_context.py +18 -8
- zrb/group/any_group.py +12 -5
- zrb/group/group.py +67 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +13 -1
- zrb/input/text_input.py +8 -25
- zrb/runner/cli.py +25 -23
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_app.py +3 -3
- zrb/runner/web_route/docs_route.py +1 -1
- zrb/runner/web_route/error_page/serve_default_404.py +1 -1
- zrb/runner/web_route/error_page/show_error_page.py +1 -1
- zrb/runner/web_route/home_page/home_page_route.py +2 -2
- zrb/runner/web_route/login_api_route.py +1 -1
- zrb/runner/web_route/login_page/login_page_route.py +2 -2
- zrb/runner/web_route/logout_api_route.py +1 -1
- zrb/runner/web_route/logout_page/logout_page_route.py +2 -2
- zrb/runner/web_route/node_page/group/show_group_page.py +1 -1
- zrb/runner/web_route/node_page/node_page_route.py +1 -1
- zrb/runner/web_route/node_page/task/show_task_page.py +1 -1
- zrb/runner/web_route/refresh_token_api_route.py +1 -1
- zrb/runner/web_route/static/static_route.py +1 -1
- zrb/runner/web_route/task_input_api_route.py +6 -6
- zrb/runner/web_route/task_session_api_route.py +20 -12
- zrb/runner/web_util/cookie.py +1 -1
- zrb/runner/web_util/token.py +1 -1
- zrb/runner/web_util/user.py +8 -4
- zrb/session/any_session.py +24 -17
- zrb/session/session.py +50 -25
- zrb/session_state_logger/any_session_state_logger.py +9 -4
- zrb/session_state_logger/file_session_state_logger.py +16 -6
- zrb/session_state_logger/session_state_logger_factory.py +1 -1
- zrb/task/any_task.py +30 -9
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/cmd_task.py +1 -1
- zrb/task/llm/agent.py +154 -161
- zrb/task/llm/agent_runner.py +152 -0
- zrb/task/llm/config.py +47 -18
- zrb/task/llm/conversation_history.py +209 -0
- zrb/task/llm/conversation_history_model.py +67 -0
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/error.py +24 -10
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_processor.py +206 -0
- zrb/task/llm/history_summarization.py +11 -166
- zrb/task/llm/print_node.py +193 -69
- zrb/task/llm/prompt.py +242 -45
- zrb/task/llm/subagent_conversation_history.py +41 -0
- zrb/task/llm/tool_wrapper.py +260 -57
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +182 -171
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +26 -11
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +54 -39
- zrb/util/callable.py +23 -0
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +29 -11
- zrb/util/git.py +8 -11
- zrb/util/git_diff_model.py +10 -0
- zrb/util/git_subtree.py +9 -14
- zrb/util/git_subtree_model.py +32 -0
- zrb/util/init_path.py +1 -1
- zrb/util/markdown.py +62 -0
- zrb/util/string/conversion.py +2 -2
- zrb/util/todo.py +17 -50
- zrb/util/todo_model.py +46 -0
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- zrb-1.21.29.dist-info/METADATA +270 -0
- {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/RECORD +140 -98
- {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
- zrb/config.py +0 -335
- zrb/llm_config.py +0 -411
- zrb/llm_rate_limitter.py +0 -125
- zrb/task/llm/context.py +0 -102
- zrb/task/llm/context_enrichment.py +0 -199
- zrb/task/llm/history.py +0 -211
- zrb-1.8.10.dist-info/METADATA +0 -264
- {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
zrb/task/base/monitoring.py
CHANGED
|
@@ -17,9 +17,13 @@ async def monitor_task_readiness(
|
|
|
17
17
|
"""
|
|
18
18
|
ctx = task.get_ctx(session)
|
|
19
19
|
readiness_checks = task.readiness_checks
|
|
20
|
-
readiness_check_period =
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
readiness_check_period = (
|
|
21
|
+
task._readiness_check_period if task._readiness_check_period else 5.0
|
|
22
|
+
)
|
|
23
|
+
readiness_failure_threshold = (
|
|
24
|
+
task._readiness_failure_threshold if task._readiness_failure_threshold else 1
|
|
25
|
+
)
|
|
26
|
+
readiness_timeout = task._readiness_timeout if task._readiness_timeout else 60
|
|
23
27
|
|
|
24
28
|
if not readiness_checks:
|
|
25
29
|
ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
|
|
@@ -41,8 +45,9 @@ async def monitor_task_readiness(
|
|
|
41
45
|
session.get_task_status(check).reset_history()
|
|
42
46
|
session.get_task_status(check).reset()
|
|
43
47
|
# Clear previous XCom data for the check task if needed
|
|
44
|
-
check_xcom: Xcom = ctx.xcom.get(check.name)
|
|
45
|
-
check_xcom
|
|
48
|
+
check_xcom: Xcom | None = ctx.xcom.get(check.name)
|
|
49
|
+
if check_xcom is not None:
|
|
50
|
+
check_xcom.clear()
|
|
46
51
|
|
|
47
52
|
readiness_check_coros = [
|
|
48
53
|
run_async(check.exec_chain(session)) for check in readiness_checks
|
|
@@ -77,7 +82,7 @@ async def monitor_task_readiness(
|
|
|
77
82
|
)
|
|
78
83
|
# Ensure check tasks are marked as failed on timeout
|
|
79
84
|
for check in readiness_checks:
|
|
80
|
-
if not session.get_task_status(check).
|
|
85
|
+
if not session.get_task_status(check).is_ready:
|
|
81
86
|
session.get_task_status(check).mark_as_failed()
|
|
82
87
|
|
|
83
88
|
except (asyncio.CancelledError, KeyboardInterrupt):
|
|
@@ -92,7 +97,7 @@ async def monitor_task_readiness(
|
|
|
92
97
|
)
|
|
93
98
|
# Mark checks as failed
|
|
94
99
|
for check in readiness_checks:
|
|
95
|
-
if not session.get_task_status(check).
|
|
100
|
+
if not session.get_task_status(check).is_ready:
|
|
96
101
|
session.get_task_status(check).mark_as_failed()
|
|
97
102
|
|
|
98
103
|
# If failure threshold is reached
|
zrb/task/base_task.py
CHANGED
|
@@ -21,6 +21,7 @@ from zrb.task.base.execution import (
|
|
|
21
21
|
)
|
|
22
22
|
from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
|
|
23
23
|
from zrb.task.base.operators import handle_lshift, handle_rshift
|
|
24
|
+
from zrb.util.string.conversion import to_snake_case
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class BaseTask(AnyTask):
|
|
@@ -216,7 +217,10 @@ class BaseTask(AnyTask):
|
|
|
216
217
|
return build_task_context(self, session)
|
|
217
218
|
|
|
218
219
|
def run(
|
|
219
|
-
self,
|
|
220
|
+
self,
|
|
221
|
+
session: AnySession | None = None,
|
|
222
|
+
str_kwargs: dict[str, str] | None = None,
|
|
223
|
+
kwargs: dict[str, Any] | None = None,
|
|
220
224
|
) -> Any:
|
|
221
225
|
"""
|
|
222
226
|
Synchronously runs the task and its dependencies, handling async setup and cleanup.
|
|
@@ -235,12 +239,19 @@ class BaseTask(AnyTask):
|
|
|
235
239
|
Any: The final result of the main task execution.
|
|
236
240
|
"""
|
|
237
241
|
# Use asyncio.run() to execute the async cleanup wrapper
|
|
238
|
-
return asyncio.run(
|
|
242
|
+
return asyncio.run(
|
|
243
|
+
run_and_cleanup(self, session=session, str_kwargs=str_kwargs, kwargs=kwargs)
|
|
244
|
+
)
|
|
239
245
|
|
|
240
246
|
async def async_run(
|
|
241
|
-
self,
|
|
247
|
+
self,
|
|
248
|
+
session: AnySession | None = None,
|
|
249
|
+
str_kwargs: dict[str, str] | None = None,
|
|
250
|
+
kwargs: dict[str, Any] | None = None,
|
|
242
251
|
) -> Any:
|
|
243
|
-
return await run_task_async(
|
|
252
|
+
return await run_task_async(
|
|
253
|
+
self, session=session, str_kwargs=str_kwargs, kwargs=kwargs
|
|
254
|
+
)
|
|
244
255
|
|
|
245
256
|
async def exec_root_tasks(self, session: AnySession):
|
|
246
257
|
return await execute_root_tasks(self, session)
|
|
@@ -276,7 +287,60 @@ class BaseTask(AnyTask):
|
|
|
276
287
|
# Add definition location to the error
|
|
277
288
|
if hasattr(e, "add_note"):
|
|
278
289
|
e.add_note(additional_error_note)
|
|
279
|
-
|
|
290
|
+
elif hasattr(e, "__notes__"):
|
|
280
291
|
# fallback: use the __notes__ attribute directly
|
|
281
292
|
e.__notes__ = getattr(e, "__notes__", []) + [additional_error_note]
|
|
282
293
|
raise e
|
|
294
|
+
|
|
295
|
+
def to_function(self) -> Callable[..., Any]:
|
|
296
|
+
from zrb.context.shared_context import SharedContext
|
|
297
|
+
from zrb.session.session import Session
|
|
298
|
+
|
|
299
|
+
def task_runner_fn(**kwargs) -> Any:
|
|
300
|
+
task_kwargs = self._get_func_kwargs(kwargs)
|
|
301
|
+
shared_ctx = SharedContext()
|
|
302
|
+
session = Session(shared_ctx=shared_ctx)
|
|
303
|
+
return self.run(session=session, kwargs=task_kwargs)
|
|
304
|
+
|
|
305
|
+
task_runner_fn.__doc__ = self._create_fn_docstring()
|
|
306
|
+
task_runner_fn.__signature__ = self._create_fn_signature()
|
|
307
|
+
task_runner_fn.__name__ = self.name
|
|
308
|
+
return task_runner_fn
|
|
309
|
+
|
|
310
|
+
def _get_func_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
311
|
+
fn_kwargs = {}
|
|
312
|
+
for inp in self.inputs:
|
|
313
|
+
snake_input_name = to_snake_case(inp.name)
|
|
314
|
+
if snake_input_name in kwargs:
|
|
315
|
+
fn_kwargs[inp.name] = kwargs[snake_input_name]
|
|
316
|
+
return fn_kwargs
|
|
317
|
+
|
|
318
|
+
def _create_fn_docstring(self) -> str:
|
|
319
|
+
from zrb.context.shared_context import SharedContext
|
|
320
|
+
|
|
321
|
+
stub_shared_ctx = SharedContext()
|
|
322
|
+
str_input_default_values = {}
|
|
323
|
+
for inp in self.inputs:
|
|
324
|
+
str_input_default_values[inp.name] = inp.get_default_str(stub_shared_ctx)
|
|
325
|
+
# Create docstring
|
|
326
|
+
doc = f"{self.description}\n\n"
|
|
327
|
+
if len(self.inputs) > 0:
|
|
328
|
+
doc += "Args:\n"
|
|
329
|
+
for inp in self.inputs:
|
|
330
|
+
str_input_default = str_input_default_values.get(inp.name, "")
|
|
331
|
+
doc += (
|
|
332
|
+
f" {inp.name}: {inp.description} (default: {str_input_default})"
|
|
333
|
+
)
|
|
334
|
+
doc += "\n"
|
|
335
|
+
return doc
|
|
336
|
+
|
|
337
|
+
def _create_fn_signature(self) -> inspect.Signature:
|
|
338
|
+
params = []
|
|
339
|
+
for inp in self.inputs:
|
|
340
|
+
params.append(
|
|
341
|
+
inspect.Parameter(
|
|
342
|
+
name=to_snake_case(inp.name),
|
|
343
|
+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
344
|
+
)
|
|
345
|
+
)
|
|
346
|
+
return inspect.Signature(params)
|
zrb/task/base_trigger.py
CHANGED
|
@@ -36,7 +36,7 @@ class BaseTrigger(BaseTask):
|
|
|
36
36
|
input: list[AnyInput | None] | AnyInput | None = None,
|
|
37
37
|
env: list[AnyEnv | None] | AnyEnv | None = None,
|
|
38
38
|
action: fstring | Callable[[AnyContext], Any] | None = None,
|
|
39
|
-
execute_condition: bool | str | Callable[[
|
|
39
|
+
execute_condition: bool | str | Callable[[AnyContext], bool] = True,
|
|
40
40
|
queue_name: fstring | None = None,
|
|
41
41
|
callback: list[AnyCallback] | AnyCallback = [],
|
|
42
42
|
retries: int = 2,
|
|
@@ -127,7 +127,7 @@ class BaseTrigger(BaseTask):
|
|
|
127
127
|
return self._callbacks
|
|
128
128
|
|
|
129
129
|
async def exec_root_tasks(self, session: AnySession):
|
|
130
|
-
exchange_xcom = self.
|
|
130
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
131
131
|
exchange_xcom.add_push_callback(lambda: self._exchange_push_callback(session))
|
|
132
132
|
return await super().exec_root_tasks(session)
|
|
133
133
|
|
|
@@ -136,8 +136,7 @@ class BaseTrigger(BaseTask):
|
|
|
136
136
|
session.defer_coro(coro)
|
|
137
137
|
|
|
138
138
|
async def _fanout_and_trigger_callback(self, session: AnySession):
|
|
139
|
-
|
|
140
|
-
data = exchange_xcom.pop()
|
|
139
|
+
data = self.pop_exchange_xcom(session)
|
|
141
140
|
coros = []
|
|
142
141
|
for callback in self.callbacks:
|
|
143
142
|
xcom_dict = DotDict({self.queue_name: Xcom([data])})
|
|
@@ -156,8 +155,16 @@ class BaseTrigger(BaseTask):
|
|
|
156
155
|
)
|
|
157
156
|
await asyncio.gather(*coros)
|
|
158
157
|
|
|
159
|
-
def
|
|
158
|
+
def _get_exchange_xcom(self, session: AnySession) -> Xcom:
|
|
160
159
|
shared_ctx = session.shared_ctx
|
|
161
160
|
if self.queue_name not in shared_ctx.xcom:
|
|
162
161
|
shared_ctx.xcom[self.queue_name] = Xcom()
|
|
163
162
|
return shared_ctx.xcom[self.queue_name]
|
|
163
|
+
|
|
164
|
+
def push_exchange_xcom(self, session: AnySession, data: Any):
|
|
165
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
166
|
+
exchange_xcom.push(data)
|
|
167
|
+
|
|
168
|
+
def pop_exchange_xcom(self, session: AnySession) -> Any:
|
|
169
|
+
exchange_xcom = self._get_exchange_xcom(session)
|
|
170
|
+
return exchange_xcom.pop()
|
zrb/task/cmd_task.py
CHANGED
|
@@ -4,7 +4,7 @@ from functools import partial
|
|
|
4
4
|
from zrb.attr.type import BoolAttr, IntAttr, StrAttr
|
|
5
5
|
from zrb.cmd.cmd_result import CmdResult
|
|
6
6
|
from zrb.cmd.cmd_val import AnyCmdVal, CmdVal, SingleCmdVal
|
|
7
|
-
from zrb.config import CFG
|
|
7
|
+
from zrb.config.config import CFG
|
|
8
8
|
from zrb.context.any_context import AnyContext
|
|
9
9
|
from zrb.env.any_env import AnyEnv
|
|
10
10
|
from zrb.input.any_input import AnyInput
|
zrb/task/llm/agent.py
CHANGED
|
@@ -1,211 +1,204 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
from collections.abc import Callable
|
|
3
|
+
from dataclasses import dataclass
|
|
2
4
|
from typing import TYPE_CHECKING, Any
|
|
3
5
|
|
|
6
|
+
from zrb.config.llm_rate_limitter import LLMRateLimitter
|
|
7
|
+
from zrb.context.any_context import AnyContext
|
|
8
|
+
from zrb.task.llm.history_processor import create_summarize_history_processor
|
|
9
|
+
from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
|
|
10
|
+
|
|
4
11
|
if TYPE_CHECKING:
|
|
5
12
|
from pydantic_ai import Agent, Tool
|
|
6
|
-
from pydantic_ai.
|
|
7
|
-
from pydantic_ai.mcp import MCPServer
|
|
13
|
+
from pydantic_ai._agent_graph import HistoryProcessor
|
|
8
14
|
from pydantic_ai.models import Model
|
|
15
|
+
from pydantic_ai.output import OutputDataT, OutputSpec
|
|
9
16
|
from pydantic_ai.settings import ModelSettings
|
|
10
|
-
|
|
11
|
-
Agent = Any
|
|
12
|
-
Tool = Any
|
|
13
|
-
AgentRun = Any
|
|
14
|
-
MCPServer = Any
|
|
15
|
-
ModelMessagesTypeAdapter = Any
|
|
16
|
-
Model = Any
|
|
17
|
-
ModelSettings = Any
|
|
18
|
-
|
|
19
|
-
import json
|
|
20
|
-
|
|
21
|
-
from zrb.context.any_context import AnyContext
|
|
22
|
-
from zrb.context.any_shared_context import AnySharedContext
|
|
23
|
-
from zrb.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
|
|
24
|
-
from zrb.task.llm.error import extract_api_error_details
|
|
25
|
-
from zrb.task.llm.print_node import print_node
|
|
26
|
-
from zrb.task.llm.tool_wrapper import wrap_tool
|
|
27
|
-
from zrb.task.llm.typing import ListOfDict
|
|
17
|
+
from pydantic_ai.toolsets import AbstractToolset
|
|
28
18
|
|
|
29
|
-
ToolOrCallable = Tool | Callable
|
|
19
|
+
ToolOrCallable = Tool | Callable
|
|
30
20
|
|
|
31
21
|
|
|
32
22
|
def create_agent_instance(
|
|
33
23
|
ctx: AnyContext,
|
|
34
|
-
model: str | Model
|
|
24
|
+
model: "str | Model",
|
|
25
|
+
rate_limitter: LLMRateLimitter | None = None,
|
|
26
|
+
output_type: "OutputSpec[OutputDataT]" = str,
|
|
35
27
|
system_prompt: str = "",
|
|
36
|
-
model_settings: ModelSettings | None = None,
|
|
37
|
-
tools: list[ToolOrCallable] = [],
|
|
38
|
-
|
|
28
|
+
model_settings: "ModelSettings | None" = None,
|
|
29
|
+
tools: list["ToolOrCallable"] = [],
|
|
30
|
+
toolsets: list["AbstractToolset[None]"] = [],
|
|
39
31
|
retries: int = 3,
|
|
40
|
-
|
|
32
|
+
yolo_mode: bool | list[str] | None = None,
|
|
33
|
+
summarization_model: "Model | str | None" = None,
|
|
34
|
+
summarization_model_settings: "ModelSettings | None" = None,
|
|
35
|
+
summarization_system_prompt: str | None = None,
|
|
36
|
+
summarization_retries: int = 2,
|
|
37
|
+
summarization_token_threshold: int | None = None,
|
|
38
|
+
history_processors: list["HistoryProcessor"] | None = None,
|
|
39
|
+
auto_summarize: bool = True,
|
|
40
|
+
) -> "Agent[None, Any]":
|
|
41
41
|
"""Creates a new Agent instance with configured tools and servers."""
|
|
42
|
-
from pydantic_ai import Agent, Tool
|
|
42
|
+
from pydantic_ai import Agent, RunContext, Tool
|
|
43
|
+
from pydantic_ai.tools import GenerateToolJsonSchema
|
|
44
|
+
from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
|
|
45
|
+
|
|
46
|
+
@dataclass
|
|
47
|
+
class ConfirmationWrapperToolset(WrapperToolset):
|
|
48
|
+
ctx: AnyContext
|
|
49
|
+
yolo_mode: bool | list[str]
|
|
50
|
+
|
|
51
|
+
async def call_tool(
|
|
52
|
+
self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
|
|
53
|
+
) -> Any:
|
|
54
|
+
# The `tool` object is passed in. Use it for inspection.
|
|
55
|
+
# Define a temporary function that performs the actual tool call.
|
|
56
|
+
async def execute_delegated_tool_call(**params):
|
|
57
|
+
# Pass all arguments down the chain.
|
|
58
|
+
return await self.wrapped.call_tool(name, tool_args, ctx, tool)
|
|
59
|
+
|
|
60
|
+
# For the confirmation UI, make our temporary function look like the real one.
|
|
61
|
+
try:
|
|
62
|
+
execute_delegated_tool_call.__name__ = name
|
|
63
|
+
execute_delegated_tool_call.__doc__ = tool.function.__doc__
|
|
64
|
+
execute_delegated_tool_call.__signature__ = inspect.signature(
|
|
65
|
+
tool.function
|
|
66
|
+
)
|
|
67
|
+
except (AttributeError, TypeError):
|
|
68
|
+
pass # Ignore if we can't inspect the original function
|
|
69
|
+
# Use the existing wrap_func to get the confirmation logic
|
|
70
|
+
wrapped_executor = wrap_func(
|
|
71
|
+
execute_delegated_tool_call, self.ctx, self.yolo_mode
|
|
72
|
+
)
|
|
73
|
+
# Call the wrapped executor. This will trigger the confirmation prompt.
|
|
74
|
+
return await wrapped_executor(**tool_args)
|
|
43
75
|
|
|
76
|
+
if yolo_mode is None:
|
|
77
|
+
yolo_mode = False
|
|
44
78
|
# Normalize tools
|
|
45
79
|
tool_list = []
|
|
46
80
|
for tool_or_callable in tools:
|
|
47
81
|
if isinstance(tool_or_callable, Tool):
|
|
48
82
|
tool_list.append(tool_or_callable)
|
|
83
|
+
# Update tool's function
|
|
84
|
+
tool = tool_or_callable
|
|
85
|
+
tool_list.append(
|
|
86
|
+
Tool(
|
|
87
|
+
function=wrap_func(tool.function, ctx, yolo_mode),
|
|
88
|
+
takes_ctx=tool.takes_ctx,
|
|
89
|
+
max_retries=tool.max_retries,
|
|
90
|
+
name=tool.name,
|
|
91
|
+
description=tool.description,
|
|
92
|
+
prepare=tool.prepare,
|
|
93
|
+
docstring_format=tool.docstring_format,
|
|
94
|
+
require_parameter_descriptions=tool.require_parameter_descriptions,
|
|
95
|
+
schema_generator=GenerateToolJsonSchema,
|
|
96
|
+
strict=tool.strict,
|
|
97
|
+
)
|
|
98
|
+
)
|
|
49
99
|
else:
|
|
50
|
-
#
|
|
51
|
-
tool_list.append(wrap_tool(tool_or_callable, ctx))
|
|
100
|
+
# Turn function into tool
|
|
101
|
+
tool_list.append(wrap_tool(tool_or_callable, ctx, yolo_mode))
|
|
102
|
+
# Wrap toolsets
|
|
103
|
+
wrapped_toolsets = [
|
|
104
|
+
ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
|
|
105
|
+
for toolset in toolsets
|
|
106
|
+
]
|
|
107
|
+
# Create History processor with summarizer
|
|
108
|
+
history_processors = [] if history_processors is None else history_processors
|
|
109
|
+
if auto_summarize:
|
|
110
|
+
history_processors += [
|
|
111
|
+
create_summarize_history_processor(
|
|
112
|
+
ctx=ctx,
|
|
113
|
+
system_prompt=system_prompt,
|
|
114
|
+
rate_limitter=rate_limitter,
|
|
115
|
+
summarization_model=summarization_model,
|
|
116
|
+
summarization_model_settings=summarization_model_settings,
|
|
117
|
+
summarization_system_prompt=summarization_system_prompt,
|
|
118
|
+
summarization_token_threshold=summarization_token_threshold,
|
|
119
|
+
summarization_retries=summarization_retries,
|
|
120
|
+
)
|
|
121
|
+
]
|
|
52
122
|
# Return Agent
|
|
53
|
-
return Agent(
|
|
123
|
+
return Agent[None, Any](
|
|
54
124
|
model=model,
|
|
55
|
-
|
|
125
|
+
output_type=output_type,
|
|
126
|
+
instructions=system_prompt,
|
|
56
127
|
tools=tool_list,
|
|
57
|
-
|
|
128
|
+
toolsets=wrapped_toolsets,
|
|
58
129
|
model_settings=model_settings,
|
|
59
130
|
retries=retries,
|
|
131
|
+
history_processors=history_processors,
|
|
60
132
|
)
|
|
61
133
|
|
|
62
134
|
|
|
63
135
|
def get_agent(
|
|
64
136
|
ctx: AnyContext,
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
137
|
+
model: "str | Model",
|
|
138
|
+
rate_limitter: LLMRateLimitter | None = None,
|
|
139
|
+
output_type: "OutputSpec[OutputDataT]" = str,
|
|
140
|
+
system_prompt: str = "",
|
|
141
|
+
model_settings: "ModelSettings | None" = None,
|
|
69
142
|
tools_attr: (
|
|
70
|
-
list[ToolOrCallable] | Callable[[
|
|
71
|
-
),
|
|
72
|
-
additional_tools: list[ToolOrCallable],
|
|
73
|
-
|
|
74
|
-
|
|
143
|
+
"list[ToolOrCallable] | Callable[[AnyContext], list[ToolOrCallable]]"
|
|
144
|
+
) = [],
|
|
145
|
+
additional_tools: "list[ToolOrCallable]" = [],
|
|
146
|
+
toolsets_attr: "list[AbstractToolset[None] | str] | Callable[[AnyContext], list[AbstractToolset[None] | str]]" = [], # noqa
|
|
147
|
+
additional_toolsets: "list[AbstractToolset[None] | str]" = [],
|
|
75
148
|
retries: int = 3,
|
|
76
|
-
|
|
149
|
+
yolo_mode: bool | list[str] | None = None,
|
|
150
|
+
summarization_model: "Model | str | None" = None,
|
|
151
|
+
summarization_model_settings: "ModelSettings | None" = None,
|
|
152
|
+
summarization_system_prompt: str | None = None,
|
|
153
|
+
summarization_retries: int = 2,
|
|
154
|
+
summarization_token_threshold: int | None = None,
|
|
155
|
+
history_processors: list["HistoryProcessor"] | None = None,
|
|
156
|
+
) -> "Agent":
|
|
77
157
|
"""Retrieves the configured Agent instance or creates one if necessary."""
|
|
78
|
-
from pydantic_ai import Agent
|
|
79
|
-
|
|
80
|
-
# Render agent instance and return if agent_attr is already an agent
|
|
81
|
-
if isinstance(agent_attr, Agent):
|
|
82
|
-
return agent_attr
|
|
83
|
-
if callable(agent_attr):
|
|
84
|
-
agent_instance = agent_attr(ctx)
|
|
85
|
-
if not isinstance(agent_instance, Agent):
|
|
86
|
-
err_msg = (
|
|
87
|
-
"Callable agent factory did not return an Agent instance, "
|
|
88
|
-
f"got: {type(agent_instance)}"
|
|
89
|
-
)
|
|
90
|
-
raise TypeError(err_msg)
|
|
91
|
-
return agent_instance
|
|
92
158
|
# Get tools for agent
|
|
93
159
|
tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
|
|
94
160
|
tools.extend(additional_tools)
|
|
95
|
-
# Get
|
|
96
|
-
|
|
97
|
-
|
|
161
|
+
# Get Toolsets for agent
|
|
162
|
+
toolset_or_str_list = list(
|
|
163
|
+
toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr
|
|
98
164
|
)
|
|
99
|
-
|
|
165
|
+
toolset_or_str_list.extend(additional_toolsets)
|
|
166
|
+
toolsets = _render_toolset_or_str_list(ctx, toolset_or_str_list)
|
|
100
167
|
# If no agent provided, create one using the configuration
|
|
101
168
|
return create_agent_instance(
|
|
102
169
|
ctx=ctx,
|
|
103
170
|
model=model,
|
|
171
|
+
rate_limitter=rate_limitter,
|
|
172
|
+
output_type=output_type,
|
|
104
173
|
system_prompt=system_prompt,
|
|
105
174
|
tools=tools,
|
|
106
|
-
|
|
175
|
+
toolsets=toolsets,
|
|
107
176
|
model_settings=model_settings,
|
|
108
177
|
retries=retries,
|
|
178
|
+
yolo_mode=yolo_mode,
|
|
179
|
+
summarization_model=summarization_model,
|
|
180
|
+
summarization_model_settings=summarization_model_settings,
|
|
181
|
+
summarization_system_prompt=summarization_system_prompt,
|
|
182
|
+
summarization_retries=summarization_retries,
|
|
183
|
+
summarization_token_threshold=summarization_token_threshold,
|
|
184
|
+
history_processors=history_processors,
|
|
109
185
|
)
|
|
110
186
|
|
|
111
187
|
|
|
112
|
-
|
|
113
|
-
ctx: AnyContext,
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
Returns:
|
|
130
|
-
The agent run result object.
|
|
131
|
-
|
|
132
|
-
Raises:
|
|
133
|
-
Exception: If any error occurs during agent execution.
|
|
134
|
-
"""
|
|
135
|
-
if max_retry < 0:
|
|
136
|
-
raise ValueError("Max retry cannot be less than 0")
|
|
137
|
-
attempt = 0
|
|
138
|
-
while attempt < max_retry:
|
|
139
|
-
try:
|
|
140
|
-
return await _run_single_agent_iteration(
|
|
141
|
-
ctx=ctx,
|
|
142
|
-
agent=agent,
|
|
143
|
-
user_prompt=user_prompt,
|
|
144
|
-
history_list=history_list,
|
|
145
|
-
rate_limitter=rate_limitter,
|
|
146
|
-
)
|
|
147
|
-
except BaseException:
|
|
148
|
-
attempt += 1
|
|
149
|
-
if attempt == max_retry:
|
|
150
|
-
raise
|
|
151
|
-
raise Exception("Max retry exceeded")
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
async def _run_single_agent_iteration(
|
|
155
|
-
ctx: AnyContext,
|
|
156
|
-
agent: Agent,
|
|
157
|
-
user_prompt: str,
|
|
158
|
-
history_list: ListOfDict,
|
|
159
|
-
rate_limitter: LLMRateLimiter | None = None,
|
|
160
|
-
) -> AgentRun:
|
|
161
|
-
from openai import APIError
|
|
162
|
-
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
|
163
|
-
|
|
164
|
-
agent_payload = estimate_request_payload(agent, user_prompt, history_list)
|
|
165
|
-
if rate_limitter:
|
|
166
|
-
await rate_limitter.throttle(agent_payload)
|
|
167
|
-
else:
|
|
168
|
-
await llm_rate_limitter.throttle(agent_payload)
|
|
169
|
-
|
|
170
|
-
async with agent.run_mcp_servers():
|
|
171
|
-
async with agent.iter(
|
|
172
|
-
user_prompt=user_prompt,
|
|
173
|
-
message_history=ModelMessagesTypeAdapter.validate_python(history_list),
|
|
174
|
-
) as agent_run:
|
|
175
|
-
async for node in agent_run:
|
|
176
|
-
# Each node represents a step in the agent's execution
|
|
177
|
-
# Reference: https://ai.pydantic.dev/agents/#streaming
|
|
178
|
-
try:
|
|
179
|
-
await print_node(_get_plain_printer(ctx), agent_run, node)
|
|
180
|
-
except APIError as e:
|
|
181
|
-
# Extract detailed error information from the response
|
|
182
|
-
error_details = extract_api_error_details(e)
|
|
183
|
-
ctx.log_error(f"API Error: {error_details}")
|
|
184
|
-
raise
|
|
185
|
-
except Exception as e:
|
|
186
|
-
ctx.log_error(f"Error processing node: {str(e)}")
|
|
187
|
-
ctx.log_error(f"Error type: {type(e).__name__}")
|
|
188
|
-
raise
|
|
189
|
-
return agent_run
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def estimate_request_payload(
|
|
193
|
-
agent: Agent, user_prompt: str, history_list: ListOfDict
|
|
194
|
-
) -> str:
|
|
195
|
-
system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
|
|
196
|
-
return json.dumps(
|
|
197
|
-
[
|
|
198
|
-
{"role": "system", "content": "\n".join(system_prompts)},
|
|
199
|
-
*history_list,
|
|
200
|
-
{"role": "user", "content": user_prompt},
|
|
201
|
-
]
|
|
202
|
-
)
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
def _get_plain_printer(ctx: AnyContext):
|
|
206
|
-
def printer(*args, **kwargs):
|
|
207
|
-
if "plain" not in kwargs:
|
|
208
|
-
kwargs["plain"] = True
|
|
209
|
-
return ctx.print(*args, **kwargs)
|
|
210
|
-
|
|
211
|
-
return printer
|
|
188
|
+
def _render_toolset_or_str_list(
|
|
189
|
+
ctx: AnyContext, toolset_or_str_list: list["AbstractToolset[None] | str"]
|
|
190
|
+
) -> list["AbstractToolset[None]"]:
|
|
191
|
+
from pydantic_ai.mcp import load_mcp_servers
|
|
192
|
+
|
|
193
|
+
toolsets = []
|
|
194
|
+
for toolset_or_str in toolset_or_str_list:
|
|
195
|
+
if isinstance(toolset_or_str, str):
|
|
196
|
+
try:
|
|
197
|
+
servers = load_mcp_servers(toolset_or_str)
|
|
198
|
+
for server in servers:
|
|
199
|
+
toolsets.append(server)
|
|
200
|
+
except Exception as e:
|
|
201
|
+
ctx.log_error(f"Invalid MCP Config {toolset_or_str}: {e}")
|
|
202
|
+
continue
|
|
203
|
+
toolsets.append(toolset_or_str)
|
|
204
|
+
return toolsets
|