zrb 1.8.10__py3-none-any.whl → 1.21.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic; see the package registry's advisory for more details.

Files changed (147)
  1. zrb/__init__.py +126 -113
  2. zrb/__main__.py +1 -1
  3. zrb/attr/type.py +10 -7
  4. zrb/builtin/__init__.py +2 -50
  5. zrb/builtin/git.py +12 -1
  6. zrb/builtin/group.py +31 -15
  7. zrb/builtin/http.py +7 -8
  8. zrb/builtin/llm/attachment.py +40 -0
  9. zrb/builtin/llm/chat_completion.py +274 -0
  10. zrb/builtin/llm/chat_session.py +152 -85
  11. zrb/builtin/llm/chat_session_cmd.py +288 -0
  12. zrb/builtin/llm/chat_trigger.py +79 -0
  13. zrb/builtin/llm/history.py +7 -9
  14. zrb/builtin/llm/llm_ask.py +221 -98
  15. zrb/builtin/llm/tool/api.py +74 -52
  16. zrb/builtin/llm/tool/cli.py +46 -17
  17. zrb/builtin/llm/tool/code.py +71 -90
  18. zrb/builtin/llm/tool/file.py +301 -241
  19. zrb/builtin/llm/tool/note.py +84 -0
  20. zrb/builtin/llm/tool/rag.py +38 -8
  21. zrb/builtin/llm/tool/sub_agent.py +67 -50
  22. zrb/builtin/llm/tool/web.py +146 -122
  23. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  24. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  25. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  26. zrb/builtin/searxng/config/settings.yml +5671 -0
  27. zrb/builtin/searxng/start.py +21 -0
  28. zrb/builtin/setup/latex/ubuntu.py +1 -0
  29. zrb/builtin/setup/ubuntu.py +1 -1
  30. zrb/builtin/shell/autocomplete/bash.py +4 -3
  31. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  32. zrb/builtin/todo.py +13 -2
  33. zrb/config/config.py +614 -0
  34. zrb/config/default_prompt/file_extractor_system_prompt.md +112 -0
  35. zrb/config/default_prompt/interactive_system_prompt.md +29 -0
  36. zrb/config/default_prompt/persona.md +1 -0
  37. zrb/config/default_prompt/repo_extractor_system_prompt.md +112 -0
  38. zrb/config/default_prompt/repo_summarizer_system_prompt.md +29 -0
  39. zrb/config/default_prompt/summarization_prompt.md +57 -0
  40. zrb/config/default_prompt/system_prompt.md +38 -0
  41. zrb/config/llm_config.py +339 -0
  42. zrb/config/llm_context/config.py +166 -0
  43. zrb/config/llm_context/config_parser.py +40 -0
  44. zrb/config/llm_context/workflow.py +81 -0
  45. zrb/config/llm_rate_limitter.py +190 -0
  46. zrb/{runner → config}/web_auth_config.py +17 -22
  47. zrb/context/any_shared_context.py +17 -1
  48. zrb/context/context.py +16 -2
  49. zrb/context/shared_context.py +18 -8
  50. zrb/group/any_group.py +12 -5
  51. zrb/group/group.py +67 -3
  52. zrb/input/any_input.py +5 -1
  53. zrb/input/base_input.py +18 -6
  54. zrb/input/option_input.py +13 -1
  55. zrb/input/text_input.py +8 -25
  56. zrb/runner/cli.py +25 -23
  57. zrb/runner/common_util.py +24 -19
  58. zrb/runner/web_app.py +3 -3
  59. zrb/runner/web_route/docs_route.py +1 -1
  60. zrb/runner/web_route/error_page/serve_default_404.py +1 -1
  61. zrb/runner/web_route/error_page/show_error_page.py +1 -1
  62. zrb/runner/web_route/home_page/home_page_route.py +2 -2
  63. zrb/runner/web_route/login_api_route.py +1 -1
  64. zrb/runner/web_route/login_page/login_page_route.py +2 -2
  65. zrb/runner/web_route/logout_api_route.py +1 -1
  66. zrb/runner/web_route/logout_page/logout_page_route.py +2 -2
  67. zrb/runner/web_route/node_page/group/show_group_page.py +1 -1
  68. zrb/runner/web_route/node_page/node_page_route.py +1 -1
  69. zrb/runner/web_route/node_page/task/show_task_page.py +1 -1
  70. zrb/runner/web_route/refresh_token_api_route.py +1 -1
  71. zrb/runner/web_route/static/static_route.py +1 -1
  72. zrb/runner/web_route/task_input_api_route.py +6 -6
  73. zrb/runner/web_route/task_session_api_route.py +20 -12
  74. zrb/runner/web_util/cookie.py +1 -1
  75. zrb/runner/web_util/token.py +1 -1
  76. zrb/runner/web_util/user.py +8 -4
  77. zrb/session/any_session.py +24 -17
  78. zrb/session/session.py +50 -25
  79. zrb/session_state_logger/any_session_state_logger.py +9 -4
  80. zrb/session_state_logger/file_session_state_logger.py +16 -6
  81. zrb/session_state_logger/session_state_logger_factory.py +1 -1
  82. zrb/task/any_task.py +30 -9
  83. zrb/task/base/context.py +17 -9
  84. zrb/task/base/execution.py +15 -8
  85. zrb/task/base/lifecycle.py +8 -4
  86. zrb/task/base/monitoring.py +12 -7
  87. zrb/task/base_task.py +69 -5
  88. zrb/task/base_trigger.py +12 -5
  89. zrb/task/cmd_task.py +1 -1
  90. zrb/task/llm/agent.py +154 -161
  91. zrb/task/llm/agent_runner.py +152 -0
  92. zrb/task/llm/config.py +47 -18
  93. zrb/task/llm/conversation_history.py +209 -0
  94. zrb/task/llm/conversation_history_model.py +67 -0
  95. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  96. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  97. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  98. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  99. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  100. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  101. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  102. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  103. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  104. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  105. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  106. zrb/task/llm/error.py +24 -10
  107. zrb/task/llm/file_replacement.py +206 -0
  108. zrb/task/llm/file_tool_model.py +57 -0
  109. zrb/task/llm/history_processor.py +206 -0
  110. zrb/task/llm/history_summarization.py +11 -166
  111. zrb/task/llm/print_node.py +193 -69
  112. zrb/task/llm/prompt.py +242 -45
  113. zrb/task/llm/subagent_conversation_history.py +41 -0
  114. zrb/task/llm/tool_wrapper.py +260 -57
  115. zrb/task/llm/workflow.py +76 -0
  116. zrb/task/llm_task.py +182 -171
  117. zrb/task/make_task.py +2 -3
  118. zrb/task/rsync_task.py +26 -11
  119. zrb/task/scheduler.py +4 -4
  120. zrb/util/attr.py +54 -39
  121. zrb/util/callable.py +23 -0
  122. zrb/util/cli/markdown.py +12 -0
  123. zrb/util/cli/text.py +30 -0
  124. zrb/util/file.py +29 -11
  125. zrb/util/git.py +8 -11
  126. zrb/util/git_diff_model.py +10 -0
  127. zrb/util/git_subtree.py +9 -14
  128. zrb/util/git_subtree_model.py +32 -0
  129. zrb/util/init_path.py +1 -1
  130. zrb/util/markdown.py +62 -0
  131. zrb/util/string/conversion.py +2 -2
  132. zrb/util/todo.py +17 -50
  133. zrb/util/todo_model.py +46 -0
  134. zrb/util/truncate.py +23 -0
  135. zrb/util/yaml.py +204 -0
  136. zrb/xcom/xcom.py +10 -0
  137. zrb-1.21.29.dist-info/METADATA +270 -0
  138. {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/RECORD +140 -98
  139. {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
  140. zrb/config.py +0 -335
  141. zrb/llm_config.py +0 -411
  142. zrb/llm_rate_limitter.py +0 -125
  143. zrb/task/llm/context.py +0 -102
  144. zrb/task/llm/context_enrichment.py +0 -199
  145. zrb/task/llm/history.py +0 -211
  146. zrb-1.8.10.dist-info/METADATA +0 -264
  147. {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
@@ -17,9 +17,13 @@ async def monitor_task_readiness(
17
17
  """
18
18
  ctx = task.get_ctx(session)
19
19
  readiness_checks = task.readiness_checks
20
- readiness_check_period = getattr(task, "_readiness_check_period", 5.0)
21
- readiness_failure_threshold = getattr(task, "_readiness_failure_threshold", 1)
22
- readiness_timeout = getattr(task, "_readiness_timeout", 60)
20
+ readiness_check_period = (
21
+ task._readiness_check_period if task._readiness_check_period else 5.0
22
+ )
23
+ readiness_failure_threshold = (
24
+ task._readiness_failure_threshold if task._readiness_failure_threshold else 1
25
+ )
26
+ readiness_timeout = task._readiness_timeout if task._readiness_timeout else 60
23
27
 
24
28
  if not readiness_checks:
25
29
  ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
@@ -41,8 +45,9 @@ async def monitor_task_readiness(
41
45
  session.get_task_status(check).reset_history()
42
46
  session.get_task_status(check).reset()
43
47
  # Clear previous XCom data for the check task if needed
44
- check_xcom: Xcom = ctx.xcom.get(check.name)
45
- check_xcom.clear()
48
+ check_xcom: Xcom | None = ctx.xcom.get(check.name)
49
+ if check_xcom is not None:
50
+ check_xcom.clear()
46
51
 
47
52
  readiness_check_coros = [
48
53
  run_async(check.exec_chain(session)) for check in readiness_checks
@@ -77,7 +82,7 @@ async def monitor_task_readiness(
77
82
  )
78
83
  # Ensure check tasks are marked as failed on timeout
79
84
  for check in readiness_checks:
80
- if not session.get_task_status(check).is_finished:
85
+ if not session.get_task_status(check).is_ready:
81
86
  session.get_task_status(check).mark_as_failed()
82
87
 
83
88
  except (asyncio.CancelledError, KeyboardInterrupt):
@@ -92,7 +97,7 @@ async def monitor_task_readiness(
92
97
  )
93
98
  # Mark checks as failed
94
99
  for check in readiness_checks:
95
- if not session.get_task_status(check).is_finished:
100
+ if not session.get_task_status(check).is_ready:
96
101
  session.get_task_status(check).mark_as_failed()
97
102
 
98
103
  # If failure threshold is reached
zrb/task/base_task.py CHANGED
@@ -21,6 +21,7 @@ from zrb.task.base.execution import (
21
21
  )
22
22
  from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
23
23
  from zrb.task.base.operators import handle_lshift, handle_rshift
24
+ from zrb.util.string.conversion import to_snake_case
24
25
 
25
26
 
26
27
  class BaseTask(AnyTask):
@@ -216,7 +217,10 @@ class BaseTask(AnyTask):
216
217
  return build_task_context(self, session)
217
218
 
218
219
  def run(
219
- self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
220
+ self,
221
+ session: AnySession | None = None,
222
+ str_kwargs: dict[str, str] | None = None,
223
+ kwargs: dict[str, Any] | None = None,
220
224
  ) -> Any:
221
225
  """
222
226
  Synchronously runs the task and its dependencies, handling async setup and cleanup.
@@ -235,12 +239,19 @@ class BaseTask(AnyTask):
235
239
  Any: The final result of the main task execution.
236
240
  """
237
241
  # Use asyncio.run() to execute the async cleanup wrapper
238
- return asyncio.run(run_and_cleanup(self, session, str_kwargs))
242
+ return asyncio.run(
243
+ run_and_cleanup(self, session=session, str_kwargs=str_kwargs, kwargs=kwargs)
244
+ )
239
245
 
async def async_run(
    self,
    session: AnySession | None = None,
    str_kwargs: dict[str, str] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> Any:
    """Asynchronously run the task (and its dependencies).

    All arguments are forwarded by keyword to ``run_task_async``.

    Args:
        session: Session to run in; forwarded as-is (may be ``None``).
        str_kwargs: String-valued input overrides; forwarded as-is.
        kwargs: Python-valued input overrides; forwarded as-is.

    Returns:
        Whatever ``run_task_async`` resolves to.
    """
    pending = run_task_async(
        self,
        session=session,
        str_kwargs=str_kwargs,
        kwargs=kwargs,
    )
    return await pending
244
255
 
async def exec_root_tasks(self, session: AnySession):
    """Run this task's root tasks within ``session``.

    Thin delegate to ``execute_root_tasks`` (``zrb.task.base.lifecycle``).
    """
    outcome = await execute_root_tasks(self, session)
    return outcome
@@ -276,7 +287,60 @@ class BaseTask(AnyTask):
276
287
  # Add definition location to the error
277
288
  if hasattr(e, "add_note"):
278
289
  e.add_note(additional_error_note)
279
- else:
290
+ elif hasattr(e, "__notes__"):
280
291
  # fallback: use the __notes__ attribute directly
281
292
  e.__notes__ = getattr(e, "__notes__", []) + [additional_error_note]
282
293
  raise e
294
+
295
def to_function(self) -> Callable[..., Any]:
    """Expose this task as a plain synchronous Python function.

    The returned callable takes the task inputs as snake_case arguments,
    either positionally (in input order, matching ``__signature__``) or by
    keyword. Each call runs the task via ``self.run`` in a brand-new
    ``Session`` backed by a fresh ``SharedContext`` and returns the result.
    ``__name__``, ``__doc__`` and ``__signature__`` are populated so code
    that introspects functions sees the task's name, description and inputs.

    Returns:
        A callable that runs this task.
    """
    from zrb.context.shared_context import SharedContext
    from zrb.session.session import Session

    def task_runner_fn(*args: Any, **kwargs: Any) -> Any:
        # __signature__ declares POSITIONAL_OR_KEYWORD parameters, so honor
        # positional arguments by mapping them onto the inputs in order.
        param_names = [to_snake_case(inp.name) for inp in self.inputs]
        if len(args) > len(param_names):
            raise TypeError(
                f"{self.name}() takes {len(param_names)} positional "
                f"argument(s) but {len(args)} were given"
            )
        for param_name, value in zip(param_names, args):
            if param_name in kwargs:
                raise TypeError(
                    f"{self.name}() got multiple values for argument "
                    f"'{param_name}'"
                )
            kwargs[param_name] = value
        task_kwargs = self._get_func_kwargs(kwargs)
        session = Session(shared_ctx=SharedContext())
        return self.run(session=session, kwargs=task_kwargs)

    task_runner_fn.__doc__ = self._create_fn_docstring()
    task_runner_fn.__signature__ = self._create_fn_signature()
    task_runner_fn.__name__ = self.name
    return task_runner_fn
309
+
310
+ def _get_func_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
311
+ fn_kwargs = {}
312
+ for inp in self.inputs:
313
+ snake_input_name = to_snake_case(inp.name)
314
+ if snake_input_name in kwargs:
315
+ fn_kwargs[inp.name] = kwargs[snake_input_name]
316
+ return fn_kwargs
317
+
318
+ def _create_fn_docstring(self) -> str:
319
+ from zrb.context.shared_context import SharedContext
320
+
321
+ stub_shared_ctx = SharedContext()
322
+ str_input_default_values = {}
323
+ for inp in self.inputs:
324
+ str_input_default_values[inp.name] = inp.get_default_str(stub_shared_ctx)
325
+ # Create docstring
326
+ doc = f"{self.description}\n\n"
327
+ if len(self.inputs) > 0:
328
+ doc += "Args:\n"
329
+ for inp in self.inputs:
330
+ str_input_default = str_input_default_values.get(inp.name, "")
331
+ doc += (
332
+ f" {inp.name}: {inp.description} (default: {str_input_default})"
333
+ )
334
+ doc += "\n"
335
+ return doc
336
+
337
+ def _create_fn_signature(self) -> inspect.Signature:
338
+ params = []
339
+ for inp in self.inputs:
340
+ params.append(
341
+ inspect.Parameter(
342
+ name=to_snake_case(inp.name),
343
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
344
+ )
345
+ )
346
+ return inspect.Signature(params)
zrb/task/base_trigger.py CHANGED
@@ -36,7 +36,7 @@ class BaseTrigger(BaseTask):
36
36
  input: list[AnyInput | None] | AnyInput | None = None,
37
37
  env: list[AnyEnv | None] | AnyEnv | None = None,
38
38
  action: fstring | Callable[[AnyContext], Any] | None = None,
39
- execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
39
+ execute_condition: bool | str | Callable[[AnyContext], bool] = True,
40
40
  queue_name: fstring | None = None,
41
41
  callback: list[AnyCallback] | AnyCallback = [],
42
42
  retries: int = 2,
@@ -127,7 +127,7 @@ class BaseTrigger(BaseTask):
127
127
  return self._callbacks
128
128
 
async def exec_root_tasks(self, session: AnySession):
    """Hook the exchange queue up to this trigger, then run the root tasks.

    A push callback that calls ``self._exchange_push_callback(session)`` is
    registered on the exchange xcom before delegating to the base
    implementation.
    """
    def on_push():
        return self._exchange_push_callback(session)

    self._get_exchange_xcom(session).add_push_callback(on_push)
    return await super().exec_root_tasks(session)
133
133
 
@@ -136,8 +136,7 @@ class BaseTrigger(BaseTask):
136
136
  session.defer_coro(coro)
137
137
 
138
138
  async def _fanout_and_trigger_callback(self, session: AnySession):
139
- exchange_xcom = self.get_exchange_xcom(session)
140
- data = exchange_xcom.pop()
139
+ data = self.pop_exchange_xcom(session)
141
140
  coros = []
142
141
  for callback in self.callbacks:
143
142
  xcom_dict = DotDict({self.queue_name: Xcom([data])})
@@ -156,8 +155,16 @@ class BaseTrigger(BaseTask):
156
155
  )
157
156
  await asyncio.gather(*coros)
158
157
 
159
def _get_exchange_xcom(self, session: AnySession) -> Xcom:
    """Return this trigger's exchange queue, creating it on first access.

    The queue is stored in ``session.shared_ctx.xcom`` under
    ``self.queue_name``.
    """
    xcom_map = session.shared_ctx.xcom
    key = self.queue_name
    if key not in xcom_map:
        xcom_map[key] = Xcom()
    return xcom_map[key]
163
+
164
def push_exchange_xcom(self, session: AnySession, data: Any):
    """Push ``data`` onto this trigger's exchange queue in ``session``."""
    self._get_exchange_xcom(session).push(data)
167
+
168
def pop_exchange_xcom(self, session: AnySession) -> Any:
    """Pop an item from this trigger's exchange queue and return it.

    The return value is whatever ``Xcom.pop`` yields for the queue stored
    in ``session``.
    """
    return self._get_exchange_xcom(session).pop()
zrb/task/cmd_task.py CHANGED
@@ -4,7 +4,7 @@ from functools import partial
4
4
  from zrb.attr.type import BoolAttr, IntAttr, StrAttr
5
5
  from zrb.cmd.cmd_result import CmdResult
6
6
  from zrb.cmd.cmd_val import AnyCmdVal, CmdVal, SingleCmdVal
7
- from zrb.config import CFG
7
+ from zrb.config.config import CFG
8
8
  from zrb.context.any_context import AnyContext
9
9
  from zrb.env.any_env import AnyEnv
10
10
  from zrb.input.any_input import AnyInput
zrb/task/llm/agent.py CHANGED
@@ -1,211 +1,204 @@
1
+ import inspect
1
2
  from collections.abc import Callable
3
+ from dataclasses import dataclass
2
4
  from typing import TYPE_CHECKING, Any
3
5
 
6
+ from zrb.config.llm_rate_limitter import LLMRateLimitter
7
+ from zrb.context.any_context import AnyContext
8
+ from zrb.task.llm.history_processor import create_summarize_history_processor
9
+ from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
10
+
4
11
  if TYPE_CHECKING:
5
12
  from pydantic_ai import Agent, Tool
6
- from pydantic_ai.agent import AgentRun
7
- from pydantic_ai.mcp import MCPServer
13
+ from pydantic_ai._agent_graph import HistoryProcessor
8
14
  from pydantic_ai.models import Model
15
+ from pydantic_ai.output import OutputDataT, OutputSpec
9
16
  from pydantic_ai.settings import ModelSettings
10
- else:
11
- Agent = Any
12
- Tool = Any
13
- AgentRun = Any
14
- MCPServer = Any
15
- ModelMessagesTypeAdapter = Any
16
- Model = Any
17
- ModelSettings = Any
18
-
19
- import json
20
-
21
- from zrb.context.any_context import AnyContext
22
- from zrb.context.any_shared_context import AnySharedContext
23
- from zrb.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
24
- from zrb.task.llm.error import extract_api_error_details
25
- from zrb.task.llm.print_node import print_node
26
- from zrb.task.llm.tool_wrapper import wrap_tool
27
- from zrb.task.llm.typing import ListOfDict
17
+ from pydantic_ai.toolsets import AbstractToolset
28
18
 
29
- ToolOrCallable = Tool | Callable
19
+ ToolOrCallable = Tool | Callable
30
20
 
31
21
 
def create_agent_instance(
    ctx: AnyContext,
    model: "str | Model",
    rate_limitter: LLMRateLimitter | None = None,
    output_type: "OutputSpec[OutputDataT]" = str,
    system_prompt: str = "",
    model_settings: "ModelSettings | None" = None,
    tools: list["ToolOrCallable"] = [],
    toolsets: list["AbstractToolset[None]"] = [],
    retries: int = 3,
    yolo_mode: bool | list[str] | None = None,
    summarization_model: "Model | str | None" = None,
    summarization_model_settings: "ModelSettings | None" = None,
    summarization_system_prompt: str | None = None,
    summarization_retries: int = 2,
    summarization_token_threshold: int | None = None,
    history_processors: list["HistoryProcessor"] | None = None,
    auto_summarize: bool = True,
) -> "Agent[None, Any]":
    """Creates a new Agent instance with configured tools and servers.

    Args:
        ctx: Task context, handed to the tool/toolset wrappers and to the
            summarization history processor.
        model: Model name or model instance for the agent.
        rate_limitter: Forwarded to the summarization history processor.
        output_type: Output type specification for the agent.
        system_prompt: Passed to the agent as ``instructions`` and to the
            summarization history processor.
        model_settings: Model settings for the agent.
        tools: ``Tool`` instances or plain callables. Each ``Tool`` is
            rebuilt around ``wrap_func(tool.function, ...)`` (only the
            wrapped copy is registered). Each callable is converted with
            ``wrap_tool``.
        toolsets: Toolsets whose tool calls are routed through ``wrap_func``
            via ``ConfirmationWrapperToolset``.
        retries: Agent retry count.
        yolo_mode: Forwarded to the wrappers; ``None`` is treated as ``False``.
        summarization_*: Forwarded to ``create_summarize_history_processor``.
        history_processors: Extra history processors. The caller's list is
            copied, never mutated.
        auto_summarize: When true, a summarization history processor is
            appended after ``history_processors``.

    Returns:
        A newly constructed ``Agent``.
    """
    from pydantic_ai import Agent, RunContext, Tool
    from pydantic_ai.tools import GenerateToolJsonSchema
    from pydantic_ai.toolsets import ToolsetTool, WrapperToolset

    @dataclass
    class ConfirmationWrapperToolset(WrapperToolset):
        ctx: AnyContext
        yolo_mode: bool | list[str]

        async def call_tool(
            self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
        ) -> Any:
            # Temporary function that performs the actual delegated tool call.
            async def execute_delegated_tool_call(**params):
                return await self.wrapped.call_tool(name, tool_args, ctx, tool)

            # Make the temporary function look like the real one so the
            # confirmation UI can describe it.
            try:
                execute_delegated_tool_call.__name__ = name
                execute_delegated_tool_call.__doc__ = tool.function.__doc__
                execute_delegated_tool_call.__signature__ = inspect.signature(
                    tool.function
                )
            except (AttributeError, TypeError):
                pass  # Best effort: the original function may not be inspectable.
            # Reuse wrap_func for the confirmation logic.
            wrapped_executor = wrap_func(
                execute_delegated_tool_call, self.ctx, self.yolo_mode
            )
            return await wrapped_executor(**tool_args)

    if yolo_mode is None:
        yolo_mode = False
    # Normalize tools
    tool_list = []
    for tool_or_callable in tools:
        if isinstance(tool_or_callable, Tool):
            # Register ONLY the wrapped copy. Also appending the original would
            # duplicate the tool name and expose an unwrapped (unconfirmed)
            # version of the same tool.
            tool = tool_or_callable
            tool_list.append(
                Tool(
                    function=wrap_func(tool.function, ctx, yolo_mode),
                    takes_ctx=tool.takes_ctx,
                    max_retries=tool.max_retries,
                    name=tool.name,
                    description=tool.description,
                    prepare=tool.prepare,
                    docstring_format=tool.docstring_format,
                    require_parameter_descriptions=tool.require_parameter_descriptions,
                    schema_generator=GenerateToolJsonSchema,
                    strict=tool.strict,
                )
            )
        else:
            # Turn function into tool
            tool_list.append(wrap_tool(tool_or_callable, ctx, yolo_mode))
    # Wrap toolsets
    wrapped_toolsets = [
        ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
        for toolset in toolsets
    ]
    # Copy so that appending the summarizer never mutates the caller's list
    # (otherwise repeated calls would stack up summarizers).
    processors = list(history_processors) if history_processors is not None else []
    if auto_summarize:
        processors.append(
            create_summarize_history_processor(
                ctx=ctx,
                system_prompt=system_prompt,
                rate_limitter=rate_limitter,
                summarization_model=summarization_model,
                summarization_model_settings=summarization_model_settings,
                summarization_system_prompt=summarization_system_prompt,
                summarization_token_threshold=summarization_token_threshold,
                summarization_retries=summarization_retries,
            )
        )
    return Agent[None, Any](
        model=model,
        output_type=output_type,
        instructions=system_prompt,
        tools=tool_list,
        toolsets=wrapped_toolsets,
        model_settings=model_settings,
        retries=retries,
        history_processors=processors,
    )
61
133
 
62
134
 
def get_agent(
    ctx: AnyContext,
    model: "str | Model",
    rate_limitter: LLMRateLimitter | None = None,
    output_type: "OutputSpec[OutputDataT]" = str,
    system_prompt: str = "",
    model_settings: "ModelSettings | None" = None,
    tools_attr: (
        "list[ToolOrCallable] | Callable[[AnyContext], list[ToolOrCallable]]"
    ) = [],
    additional_tools: "list[ToolOrCallable]" = [],
    toolsets_attr: "list[AbstractToolset[None] | str] | Callable[[AnyContext], list[AbstractToolset[None] | str]]" = [],  # noqa
    additional_toolsets: "list[AbstractToolset[None] | str]" = [],
    retries: int = 3,
    yolo_mode: bool | list[str] | None = None,
    summarization_model: "Model | str | None" = None,
    summarization_model_settings: "ModelSettings | None" = None,
    summarization_system_prompt: str | None = None,
    summarization_retries: int = 2,
    summarization_token_threshold: int | None = None,
    history_processors: list["HistoryProcessor"] | None = None,
) -> "Agent":
    """Build a new Agent from resolved tool and toolset attributes.

    ``tools_attr`` and ``toolsets_attr`` may each be a list or a callable
    that receives ``ctx`` and returns a list. The resolved values are copied
    into new lists and extended with ``additional_tools`` /
    ``additional_toolsets``. The toolset list is then passed through
    ``_render_toolset_or_str_list``. Everything else is forwarded unchanged
    to ``create_agent_instance``, so every call builds a fresh agent.
    """
    def resolve(attr):
        return attr(ctx) if callable(attr) else attr

    agent_tools = [*resolve(tools_attr), *additional_tools]
    raw_toolsets = [*resolve(toolsets_attr), *additional_toolsets]
    agent_toolsets = _render_toolset_or_str_list(ctx, raw_toolsets)
    return create_agent_instance(
        ctx=ctx,
        model=model,
        rate_limitter=rate_limitter,
        output_type=output_type,
        system_prompt=system_prompt,
        tools=agent_tools,
        toolsets=agent_toolsets,
        model_settings=model_settings,
        retries=retries,
        yolo_mode=yolo_mode,
        summarization_model=summarization_model,
        summarization_model_settings=summarization_model_settings,
        summarization_system_prompt=summarization_system_prompt,
        summarization_retries=summarization_retries,
        summarization_token_threshold=summarization_token_threshold,
        history_processors=history_processors,
    )
110
186
 
111
187
 
112
- async def run_agent_iteration(
113
- ctx: AnyContext,
114
- agent: Agent,
115
- user_prompt: str,
116
- history_list: ListOfDict,
117
- rate_limitter: LLMRateLimiter | None = None,
118
- max_retry: int = 2,
119
- ) -> AgentRun:
120
- """
121
- Runs a single iteration of the agent execution loop.
122
-
123
- Args:
124
- ctx: The task context.
125
- agent: The Pydantic AI agent instance.
126
- user_prompt: The user's input prompt.
127
- history_list: The current conversation history.
128
-
129
- Returns:
130
- The agent run result object.
131
-
132
- Raises:
133
- Exception: If any error occurs during agent execution.
134
- """
135
- if max_retry < 0:
136
- raise ValueError("Max retry cannot be less than 0")
137
- attempt = 0
138
- while attempt < max_retry:
139
- try:
140
- return await _run_single_agent_iteration(
141
- ctx=ctx,
142
- agent=agent,
143
- user_prompt=user_prompt,
144
- history_list=history_list,
145
- rate_limitter=rate_limitter,
146
- )
147
- except BaseException:
148
- attempt += 1
149
- if attempt == max_retry:
150
- raise
151
- raise Exception("Max retry exceeded")
152
-
153
-
154
- async def _run_single_agent_iteration(
155
- ctx: AnyContext,
156
- agent: Agent,
157
- user_prompt: str,
158
- history_list: ListOfDict,
159
- rate_limitter: LLMRateLimiter | None = None,
160
- ) -> AgentRun:
161
- from openai import APIError
162
- from pydantic_ai.messages import ModelMessagesTypeAdapter
163
-
164
- agent_payload = estimate_request_payload(agent, user_prompt, history_list)
165
- if rate_limitter:
166
- await rate_limitter.throttle(agent_payload)
167
- else:
168
- await llm_rate_limitter.throttle(agent_payload)
169
-
170
- async with agent.run_mcp_servers():
171
- async with agent.iter(
172
- user_prompt=user_prompt,
173
- message_history=ModelMessagesTypeAdapter.validate_python(history_list),
174
- ) as agent_run:
175
- async for node in agent_run:
176
- # Each node represents a step in the agent's execution
177
- # Reference: https://ai.pydantic.dev/agents/#streaming
178
- try:
179
- await print_node(_get_plain_printer(ctx), agent_run, node)
180
- except APIError as e:
181
- # Extract detailed error information from the response
182
- error_details = extract_api_error_details(e)
183
- ctx.log_error(f"API Error: {error_details}")
184
- raise
185
- except Exception as e:
186
- ctx.log_error(f"Error processing node: {str(e)}")
187
- ctx.log_error(f"Error type: {type(e).__name__}")
188
- raise
189
- return agent_run
190
-
191
-
192
- def estimate_request_payload(
193
- agent: Agent, user_prompt: str, history_list: ListOfDict
194
- ) -> str:
195
- system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
196
- return json.dumps(
197
- [
198
- {"role": "system", "content": "\n".join(system_prompts)},
199
- *history_list,
200
- {"role": "user", "content": user_prompt},
201
- ]
202
- )
203
-
204
-
205
- def _get_plain_printer(ctx: AnyContext):
206
- def printer(*args, **kwargs):
207
- if "plain" not in kwargs:
208
- kwargs["plain"] = True
209
- return ctx.print(*args, **kwargs)
210
-
211
- return printer
188
+ def _render_toolset_or_str_list(
189
+ ctx: AnyContext, toolset_or_str_list: list["AbstractToolset[None] | str"]
190
+ ) -> list["AbstractToolset[None]"]:
191
+ from pydantic_ai.mcp import load_mcp_servers
192
+
193
+ toolsets = []
194
+ for toolset_or_str in toolset_or_str_list:
195
+ if isinstance(toolset_or_str, str):
196
+ try:
197
+ servers = load_mcp_servers(toolset_or_str)
198
+ for server in servers:
199
+ toolsets.append(server)
200
+ except Exception as e:
201
+ ctx.log_error(f"Invalid MCP Config {toolset_or_str}: {e}")
202
+ continue
203
+ toolsets.append(toolset_or_str)
204
+ return toolsets