zrb 1.15.3__py3-none-any.whl → 1.21.29__py3-none-any.whl

This diff compares the contents of two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic.

Files changed (108)
  1. zrb/__init__.py +2 -6
  2. zrb/attr/type.py +10 -7
  3. zrb/builtin/__init__.py +2 -0
  4. zrb/builtin/git.py +12 -1
  5. zrb/builtin/group.py +31 -15
  6. zrb/builtin/llm/attachment.py +40 -0
  7. zrb/builtin/llm/chat_completion.py +274 -0
  8. zrb/builtin/llm/chat_session.py +126 -167
  9. zrb/builtin/llm/chat_session_cmd.py +288 -0
  10. zrb/builtin/llm/chat_trigger.py +79 -0
  11. zrb/builtin/llm/history.py +4 -4
  12. zrb/builtin/llm/llm_ask.py +217 -135
  13. zrb/builtin/llm/tool/api.py +74 -70
  14. zrb/builtin/llm/tool/cli.py +35 -21
  15. zrb/builtin/llm/tool/code.py +55 -73
  16. zrb/builtin/llm/tool/file.py +278 -344
  17. zrb/builtin/llm/tool/note.py +84 -0
  18. zrb/builtin/llm/tool/rag.py +27 -34
  19. zrb/builtin/llm/tool/sub_agent.py +54 -41
  20. zrb/builtin/llm/tool/web.py +74 -98
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  23. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  24. zrb/builtin/searxng/config/settings.yml +5671 -0
  25. zrb/builtin/searxng/start.py +21 -0
  26. zrb/builtin/shell/autocomplete/bash.py +4 -3
  27. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  28. zrb/config/config.py +202 -27
  29. zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
  30. zrb/config/default_prompt/interactive_system_prompt.md +24 -30
  31. zrb/config/default_prompt/persona.md +1 -1
  32. zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
  33. zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
  34. zrb/config/default_prompt/summarization_prompt.md +57 -16
  35. zrb/config/default_prompt/system_prompt.md +36 -30
  36. zrb/config/llm_config.py +119 -23
  37. zrb/config/llm_context/config.py +127 -90
  38. zrb/config/llm_context/config_parser.py +1 -7
  39. zrb/config/llm_context/workflow.py +81 -0
  40. zrb/config/llm_rate_limitter.py +100 -47
  41. zrb/context/any_shared_context.py +7 -1
  42. zrb/context/context.py +8 -2
  43. zrb/context/shared_context.py +3 -7
  44. zrb/group/any_group.py +3 -3
  45. zrb/group/group.py +3 -3
  46. zrb/input/any_input.py +5 -1
  47. zrb/input/base_input.py +18 -6
  48. zrb/input/option_input.py +13 -1
  49. zrb/input/text_input.py +7 -24
  50. zrb/runner/cli.py +21 -20
  51. zrb/runner/common_util.py +24 -19
  52. zrb/runner/web_route/task_input_api_route.py +5 -5
  53. zrb/runner/web_util/user.py +7 -3
  54. zrb/session/any_session.py +12 -6
  55. zrb/session/session.py +39 -18
  56. zrb/task/any_task.py +24 -3
  57. zrb/task/base/context.py +17 -9
  58. zrb/task/base/execution.py +15 -8
  59. zrb/task/base/lifecycle.py +8 -4
  60. zrb/task/base/monitoring.py +12 -7
  61. zrb/task/base_task.py +69 -5
  62. zrb/task/base_trigger.py +12 -5
  63. zrb/task/llm/agent.py +128 -167
  64. zrb/task/llm/agent_runner.py +152 -0
  65. zrb/task/llm/config.py +39 -20
  66. zrb/task/llm/conversation_history.py +110 -29
  67. zrb/task/llm/conversation_history_model.py +4 -179
  68. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  69. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  70. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  71. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  72. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  73. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  74. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  75. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  76. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  77. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  78. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  79. zrb/task/llm/file_replacement.py +206 -0
  80. zrb/task/llm/file_tool_model.py +57 -0
  81. zrb/task/llm/history_processor.py +206 -0
  82. zrb/task/llm/history_summarization.py +2 -193
  83. zrb/task/llm/print_node.py +184 -64
  84. zrb/task/llm/prompt.py +175 -179
  85. zrb/task/llm/subagent_conversation_history.py +41 -0
  86. zrb/task/llm/tool_wrapper.py +226 -85
  87. zrb/task/llm/workflow.py +76 -0
  88. zrb/task/llm_task.py +109 -71
  89. zrb/task/make_task.py +2 -3
  90. zrb/task/rsync_task.py +25 -10
  91. zrb/task/scheduler.py +4 -4
  92. zrb/util/attr.py +54 -39
  93. zrb/util/cli/markdown.py +12 -0
  94. zrb/util/cli/text.py +30 -0
  95. zrb/util/file.py +12 -3
  96. zrb/util/git.py +2 -2
  97. zrb/util/{llm/prompt.py → markdown.py} +2 -3
  98. zrb/util/string/conversion.py +1 -1
  99. zrb/util/truncate.py +23 -0
  100. zrb/util/yaml.py +204 -0
  101. zrb/xcom/xcom.py +10 -0
  102. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/METADATA +38 -18
  103. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/RECORD +105 -79
  104. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
  105. zrb/task/llm/default_workflow/coding.md +0 -24
  106. zrb/task/llm/default_workflow/copywriting.md +0 -17
  107. zrb/task/llm/default_workflow/researching.md +0 -18
  108. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
@@ -53,7 +53,9 @@ def check_execute_condition(task: "BaseTask", session: AnySession) -> bool:
     Evaluates the task's execute_condition attribute.
     """
     ctx = task.get_ctx(session)
-    execute_condition_attr = getattr(task, "_execute_condition", True)
+    execute_condition_attr = (
+        task._execute_condition if task._execute_condition is not None else True
+    )
     return get_bool_attr(ctx, execute_condition_attr, True, auto_render=True)


@@ -63,8 +65,12 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
     """
     ctx = task.get_ctx(session)
     readiness_checks = task.readiness_checks
-    readiness_check_delay = getattr(task, "_readiness_check_delay", 0.5)
-    monitor_readiness = getattr(task, "_monitor_readiness", False)
+    readiness_check_delay = (
+        task._readiness_check_delay if task._readiness_check_delay is not None else 0.5
+    )
+    monitor_readiness = (
+        task._monitor_readiness if task._monitor_readiness is not None else False
+    )

     if not readiness_checks:  # Simplified check for empty list
         ctx.log_info("No readiness checks")
@@ -140,8 +146,8 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
     handling success (triggering successors) and failure (triggering fallbacks).
     """
     ctx = task.get_ctx(session)
-    retries = getattr(task, "_retries", 2)
-    retry_period = getattr(task, "_retry_period", 0)
+    retries = task._retries if task._retries is not None else 2
+    retry_period = task._retry_period if task._retry_period is not None else 0
     max_attempt = retries + 1
     ctx.set_max_attempt(max_attempt)

@@ -163,8 +169,9 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
         session.get_task_status(task).mark_as_completed()

         # Store result in XCom
-        task_xcom: Xcom = ctx.xcom.get(task.name)
-        task_xcom.push(result)
+        task_xcom: Xcom | None = ctx.xcom.get(task.name)
+        if task_xcom is not None:
+            task_xcom.push(result)

         # Skip fallbacks and execute successors on success
         skip_fallbacks(task, session)
@@ -201,7 +208,7 @@ async def run_default_action(task: "BaseTask", ctx: AnyContext) -> Any:
     This is the default implementation called by BaseTask._exec_action.
     Subclasses like LLMTask override _exec_action with their own logic.
     """
-    action = getattr(task, "_action", None)
+    action = task._action
     if action is None:
         ctx.log_debug("No action defined for this task.")
         return None
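
Note: the hunks above replace getattr(task, "...", default) lookups with explicit "is not None" checks. The practical difference is that an attribute that is explicitly set to None now falls back to the default instead of propagating None. A minimal sketch of the two behaviors (the class and values below are illustrative, not part of zrb):

    class _Demo:
        _retries = None  # the attribute exists but holds None

    task = _Demo()
    # Old pattern: getattr only falls back when the attribute is missing,
    # so an explicit None leaks through.
    old_value = getattr(task, "_retries", 2)                        # -> None
    # New pattern: None is treated as "use the default".
    new_value = task._retries if task._retries is not None else 2   # -> 2
    print(old_value, new_value)
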
zrb/task/base/lifecycle.py CHANGED
@@ -12,7 +12,8 @@ from zrb.util.run import run_async
 async def run_and_cleanup(
     task: AnyTask,
     session: AnySession | None = None,
-    str_kwargs: dict[str, str] = {},
+    str_kwargs: dict[str, str] | None = None,
+    kwargs: dict[str, Any] | None = None,
 ) -> Any:
     """
     Wrapper for async_run that ensures session termination and cleanup of
@@ -23,7 +24,9 @@ async def run_and_cleanup(
         session = Session(shared_ctx=SharedContext())

     # Create the main task execution coroutine
-    main_task_coro = asyncio.create_task(run_task_async(task, session, str_kwargs))
+    main_task_coro = asyncio.create_task(
+        run_task_async(task, session, str_kwargs, kwargs)
+    )

     try:
         result = await main_task_coro
@@ -67,7 +70,8 @@ async def run_and_cleanup(
 async def run_task_async(
     task: AnyTask,
     session: AnySession | None = None,
-    str_kwargs: dict[str, str] = {},
+    str_kwargs: dict[str, str] | None = None,
+    kwargs: dict[str, Any] | None = None,
 ) -> Any:
     """
     Asynchronous entry point for running a task (`task.async_run()`).
@@ -77,7 +81,7 @@ async def run_task_async(
         session = Session(shared_ctx=SharedContext())

     # Populate shared context with inputs and environment variables
-    fill_shared_context_inputs(task, session.shared_ctx, str_kwargs)
+    fill_shared_context_inputs(session.shared_ctx, task, str_kwargs, kwargs)
     fill_shared_context_envs(session.shared_ctx)  # Inject OS env vars

     # Start the execution chain from the root tasks
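
Note: run_and_cleanup and run_task_async now take an optional kwargs dict of already-typed values next to the string-only str_kwargs, and both parameters default to None instead of a shared mutable {}. A hedged usage sketch (assumes Task and StrInput are exported from the zrb package root, as in the project README; the task itself is hypothetical):

    from zrb import StrInput, Task

    greet = Task(
        name="greet",
        input=StrInput(name="name"),
        action=lambda ctx: f"hello {ctx.input.name}",
    )

    # String values, e.g. as parsed from a CLI invocation:
    greet.run(str_kwargs={"name": "alice"})
    # Already-typed values, passed through without string parsing:
    greet.run(kwargs={"name": "alice"})
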
zrb/task/base/monitoring.py CHANGED
@@ -17,9 +17,13 @@ async def monitor_task_readiness(
     """
     ctx = task.get_ctx(session)
     readiness_checks = task.readiness_checks
-    readiness_check_period = getattr(task, "_readiness_check_period", 5.0)
-    readiness_failure_threshold = getattr(task, "_readiness_failure_threshold", 1)
-    readiness_timeout = getattr(task, "_readiness_timeout", 60)
+    readiness_check_period = (
+        task._readiness_check_period if task._readiness_check_period else 5.0
+    )
+    readiness_failure_threshold = (
+        task._readiness_failure_threshold if task._readiness_failure_threshold else 1
+    )
+    readiness_timeout = task._readiness_timeout if task._readiness_timeout else 60

     if not readiness_checks:
         ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
@@ -41,8 +45,9 @@ async def monitor_task_readiness(
             session.get_task_status(check).reset_history()
             session.get_task_status(check).reset()
             # Clear previous XCom data for the check task if needed
-            check_xcom: Xcom = ctx.xcom.get(check.name)
-            check_xcom.clear()
+            check_xcom: Xcom | None = ctx.xcom.get(check.name)
+            if check_xcom is not None:
+                check_xcom.clear()

         readiness_check_coros = [
             run_async(check.exec_chain(session)) for check in readiness_checks
@@ -77,7 +82,7 @@ async def monitor_task_readiness(
             )
             # Ensure check tasks are marked as failed on timeout
             for check in readiness_checks:
-                if not session.get_task_status(check).is_finished:
+                if not session.get_task_status(check).is_ready:
                     session.get_task_status(check).mark_as_failed()

         except (asyncio.CancelledError, KeyboardInterrupt):
@@ -92,7 +97,7 @@ async def monitor_task_readiness(
             )
             # Mark checks as failed
             for check in readiness_checks:
-                if not session.get_task_status(check).is_finished:
+                if not session.get_task_status(check).is_ready:
                     session.get_task_status(check).mark_as_failed()

         # If failure threshold is reached
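
Note: unlike the execution.py hunks, the monitoring defaults above use truthiness (if task._readiness_check_period) rather than an "is not None" check, so a value explicitly set to 0 also falls back to the default. A minimal illustration (the class below is illustrative, not part of zrb):

    class _Demo:
        _readiness_check_period = 0  # explicitly configured as zero

    task = _Demo()
    period = task._readiness_check_period if task._readiness_check_period else 5.0
    print(period)  # prints 5.0: zero behaves like "unset" under a truthiness check
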
zrb/task/base_task.py CHANGED
@@ -21,6 +21,7 @@ from zrb.task.base.execution import (
 )
 from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
 from zrb.task.base.operators import handle_lshift, handle_rshift
+from zrb.util.string.conversion import to_snake_case


 class BaseTask(AnyTask):
@@ -216,7 +217,10 @@ class BaseTask(AnyTask):
         return build_task_context(self, session)

     def run(
-        self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
+        self,
+        session: AnySession | None = None,
+        str_kwargs: dict[str, str] | None = None,
+        kwargs: dict[str, Any] | None = None,
     ) -> Any:
         """
         Synchronously runs the task and its dependencies, handling async setup and cleanup.
@@ -235,12 +239,19 @@ class BaseTask(AnyTask):
             Any: The final result of the main task execution.
         """
         # Use asyncio.run() to execute the async cleanup wrapper
-        return asyncio.run(run_and_cleanup(self, session, str_kwargs))
+        return asyncio.run(
+            run_and_cleanup(self, session=session, str_kwargs=str_kwargs, kwargs=kwargs)
+        )

     async def async_run(
-        self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
+        self,
+        session: AnySession | None = None,
+        str_kwargs: dict[str, str] | None = None,
+        kwargs: dict[str, Any] | None = None,
     ) -> Any:
-        return await run_task_async(self, session, str_kwargs)
+        return await run_task_async(
+            self, session=session, str_kwargs=str_kwargs, kwargs=kwargs
+        )

     async def exec_root_tasks(self, session: AnySession):
         return await execute_root_tasks(self, session)
@@ -276,7 +287,60 @@ class BaseTask(AnyTask):
             # Add definition location to the error
             if hasattr(e, "add_note"):
                 e.add_note(additional_error_note)
-            else:
+            elif hasattr(e, "__notes__"):
                 # fallback: use the __notes__ attribute directly
                 e.__notes__ = getattr(e, "__notes__", []) + [additional_error_note]
             raise e
+
+    def to_function(self) -> Callable[..., Any]:
+        from zrb.context.shared_context import SharedContext
+        from zrb.session.session import Session
+
+        def task_runner_fn(**kwargs) -> Any:
+            task_kwargs = self._get_func_kwargs(kwargs)
+            shared_ctx = SharedContext()
+            session = Session(shared_ctx=shared_ctx)
+            return self.run(session=session, kwargs=task_kwargs)
+
+        task_runner_fn.__doc__ = self._create_fn_docstring()
+        task_runner_fn.__signature__ = self._create_fn_signature()
+        task_runner_fn.__name__ = self.name
+        return task_runner_fn
+
+    def _get_func_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
+        fn_kwargs = {}
+        for inp in self.inputs:
+            snake_input_name = to_snake_case(inp.name)
+            if snake_input_name in kwargs:
+                fn_kwargs[inp.name] = kwargs[snake_input_name]
+        return fn_kwargs
+
+    def _create_fn_docstring(self) -> str:
+        from zrb.context.shared_context import SharedContext
+
+        stub_shared_ctx = SharedContext()
+        str_input_default_values = {}
+        for inp in self.inputs:
+            str_input_default_values[inp.name] = inp.get_default_str(stub_shared_ctx)
+        # Create docstring
+        doc = f"{self.description}\n\n"
+        if len(self.inputs) > 0:
+            doc += "Args:\n"
+        for inp in self.inputs:
+            str_input_default = str_input_default_values.get(inp.name, "")
+            doc += (
+                f" {inp.name}: {inp.description} (default: {str_input_default})"
+            )
+            doc += "\n"
+        return doc
+
+    def _create_fn_signature(self) -> inspect.Signature:
+        params = []
+        for inp in self.inputs:
+            params.append(
+                inspect.Parameter(
+                    name=to_snake_case(inp.name),
+                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                )
+            )
+        return inspect.Signature(params)
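
Note: the new BaseTask.to_function() returns a plain keyword-argument callable whose parameters are the task's inputs (snake_cased); calling it runs the task in a fresh Session and returns its result. A hedged sketch (assumes Task and StrInput are exported from the zrb package root, as in the project README; the task itself is hypothetical):

    from zrb import StrInput, Task

    greet = Task(
        name="greet",
        description="Print a greeting",
        input=StrInput(name="name"),
        action=lambda ctx: f"hello {ctx.input.name}",
    )

    greet_fn = greet.to_function()
    print(greet_fn.__doc__)          # generated docstring listing the task inputs
    result = greet_fn(name="alice")  # runs the task in its own Session
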
zrb/task/base_trigger.py CHANGED
@@ -36,7 +36,7 @@ class BaseTrigger(BaseTask):
         input: list[AnyInput | None] | AnyInput | None = None,
         env: list[AnyEnv | None] | AnyEnv | None = None,
         action: fstring | Callable[[AnyContext], Any] | None = None,
-        execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
+        execute_condition: bool | str | Callable[[AnyContext], bool] = True,
         queue_name: fstring | None = None,
         callback: list[AnyCallback] | AnyCallback = [],
         retries: int = 2,
@@ -127,7 +127,7 @@ class BaseTrigger(BaseTask):
         return self._callbacks

     async def exec_root_tasks(self, session: AnySession):
-        exchange_xcom = self.get_exchange_xcom(session)
+        exchange_xcom = self._get_exchange_xcom(session)
         exchange_xcom.add_push_callback(lambda: self._exchange_push_callback(session))
         return await super().exec_root_tasks(session)

@@ -136,8 +136,7 @@ class BaseTrigger(BaseTask):
         session.defer_coro(coro)

     async def _fanout_and_trigger_callback(self, session: AnySession):
-        exchange_xcom = self.get_exchange_xcom(session)
-        data = exchange_xcom.pop()
+        data = self.pop_exchange_xcom(session)
         coros = []
         for callback in self.callbacks:
             xcom_dict = DotDict({self.queue_name: Xcom([data])})
@@ -156,8 +155,16 @@ class BaseTrigger(BaseTask):
         )
         await asyncio.gather(*coros)

-    def get_exchange_xcom(self, session: AnySession) -> Xcom:
+    def _get_exchange_xcom(self, session: AnySession) -> Xcom:
         shared_ctx = session.shared_ctx
         if self.queue_name not in shared_ctx.xcom:
             shared_ctx.xcom[self.queue_name] = Xcom()
         return shared_ctx.xcom[self.queue_name]
+
+    def push_exchange_xcom(self, session: AnySession, data: Any):
+        exchange_xcom = self._get_exchange_xcom(session)
+        exchange_xcom.push(data)
+
+    def pop_exchange_xcom(self, session: AnySession) -> Any:
+        exchange_xcom = self._get_exchange_xcom(session)
+        return exchange_xcom.pop()
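
Note: get_exchange_xcom becomes private and is replaced by the public push_exchange_xcom / pop_exchange_xcom pair, so callers no longer touch the underlying Xcom object directly. A hedged sketch (the trigger name and payload are made up; whether BaseTrigger is meant to be instantiated directly like this is an assumption, and the import paths follow the diff above):

    from zrb.context.shared_context import SharedContext
    from zrb.session.session import Session
    from zrb.task.base_trigger import BaseTrigger

    trigger = BaseTrigger(name="my-trigger", queue_name="my-queue")
    session = Session(shared_ctx=SharedContext())

    # A producer enqueues an event ...
    trigger.push_exchange_xcom(session, {"event": "tick"})
    # ... and the fan-out path (or a test) drains it again.
    payload = trigger.pop_exchange_xcom(session)  # -> {"event": "tick"}
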
zrb/task/llm/agent.py CHANGED
@@ -1,20 +1,18 @@
-import json
+import inspect
 from collections.abc import Callable
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any

-from zrb.config.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
+from zrb.config.llm_rate_limitter import LLMRateLimitter
 from zrb.context.any_context import AnyContext
-from zrb.context.any_shared_context import AnySharedContext
-from zrb.task.llm.error import extract_api_error_details
-from zrb.task.llm.print_node import print_node
+from zrb.task.llm.history_processor import create_summarize_history_processor
 from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
-from zrb.task.llm.typing import ListOfDict

 if TYPE_CHECKING:
     from pydantic_ai import Agent, Tool
-    from pydantic_ai.agent import AgentRun
-    from pydantic_ai.messages import UserContent
+    from pydantic_ai._agent_graph import HistoryProcessor
     from pydantic_ai.models import Model
+    from pydantic_ai.output import OutputDataT, OutputSpec
     from pydantic_ai.settings import ModelSettings
     from pydantic_ai.toolsets import AbstractToolset

@@ -24,19 +22,59 @@ if TYPE_CHECKING:
 def create_agent_instance(
     ctx: AnyContext,
     model: "str | Model",
+    rate_limitter: LLMRateLimitter | None = None,
+    output_type: "OutputSpec[OutputDataT]" = str,
     system_prompt: str = "",
     model_settings: "ModelSettings | None" = None,
-    tools: "list[ToolOrCallable]" = [],
-    toolsets: list["AbstractToolset[Agent]"] = [],
+    tools: list["ToolOrCallable"] = [],
+    toolsets: list["AbstractToolset[None]"] = [],
     retries: int = 3,
-    is_yolo_mode: bool | None = None,
-) -> "Agent":
+    yolo_mode: bool | list[str] | None = None,
+    summarization_model: "Model | str | None" = None,
+    summarization_model_settings: "ModelSettings | None" = None,
+    summarization_system_prompt: str | None = None,
+    summarization_retries: int = 2,
+    summarization_token_threshold: int | None = None,
+    history_processors: list["HistoryProcessor"] | None = None,
+    auto_summarize: bool = True,
+) -> "Agent[None, Any]":
     """Creates a new Agent instance with configured tools and servers."""
-    from pydantic_ai import Agent, Tool
+    from pydantic_ai import Agent, RunContext, Tool
     from pydantic_ai.tools import GenerateToolJsonSchema
+    from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
+
+    @dataclass
+    class ConfirmationWrapperToolset(WrapperToolset):
+        ctx: AnyContext
+        yolo_mode: bool | list[str]
+
+        async def call_tool(
+            self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
+        ) -> Any:
+            # The `tool` object is passed in. Use it for inspection.
+            # Define a temporary function that performs the actual tool call.
+            async def execute_delegated_tool_call(**params):
+                # Pass all arguments down the chain.
+                return await self.wrapped.call_tool(name, tool_args, ctx, tool)
+
+            # For the confirmation UI, make our temporary function look like the real one.
+            try:
+                execute_delegated_tool_call.__name__ = name
+                execute_delegated_tool_call.__doc__ = tool.function.__doc__
+                execute_delegated_tool_call.__signature__ = inspect.signature(
+                    tool.function
+                )
+            except (AttributeError, TypeError):
+                pass  # Ignore if we can't inspect the original function
+            # Use the existing wrap_func to get the confirmation logic
+            wrapped_executor = wrap_func(
+                execute_delegated_tool_call, self.ctx, self.yolo_mode
+            )
+            # Call the wrapped executor. This will trigger the confirmation prompt.
+            return await wrapped_executor(**tool_args)

-    if is_yolo_mode is None:
-        is_yolo_mode = False
+    if yolo_mode is None:
+        yolo_mode = False
     # Normalize tools
     tool_list = []
     for tool_or_callable in tools:
@@ -46,7 +84,7 @@ def create_agent_instance(
             tool = tool_or_callable
             tool_list.append(
                 Tool(
-                    function=wrap_func(tool.function, ctx, is_yolo_mode),
+                    function=wrap_func(tool.function, ctx, yolo_mode),
                     takes_ctx=tool.takes_ctx,
                     max_retries=tool.max_retries,
                     name=tool.name,
@@ -60,184 +98,107 @@ def create_agent_instance(
             )
         else:
             # Turn function into tool
-            tool_list.append(wrap_tool(tool_or_callable, ctx, is_yolo_mode))
+            tool_list.append(wrap_tool(tool_or_callable, ctx, yolo_mode))
+    # Wrap toolsets
+    wrapped_toolsets = [
+        ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
+        for toolset in toolsets
+    ]
+    # Create History processor with summarizer
+    history_processors = [] if history_processors is None else history_processors
+    if auto_summarize:
+        history_processors += [
+            create_summarize_history_processor(
+                ctx=ctx,
+                system_prompt=system_prompt,
+                rate_limitter=rate_limitter,
+                summarization_model=summarization_model,
+                summarization_model_settings=summarization_model_settings,
+                summarization_system_prompt=summarization_system_prompt,
+                summarization_token_threshold=summarization_token_threshold,
+                summarization_retries=summarization_retries,
+            )
+        ]
     # Return Agent
-    return Agent(
+    return Agent[None, Any](
         model=model,
-        system_prompt=system_prompt,
+        output_type=output_type,
+        instructions=system_prompt,
         tools=tool_list,
-        toolsets=toolsets,
+        toolsets=wrapped_toolsets,
         model_settings=model_settings,
         retries=retries,
+        history_processors=history_processors,
     )


 def get_agent(
     ctx: AnyContext,
-    agent_attr: "Agent | Callable[[AnySharedContext], Agent] | None",
     model: "str | Model",
-    system_prompt: str,
-    model_settings: "ModelSettings | None",
+    rate_limitter: LLMRateLimitter | None = None,
+    output_type: "OutputSpec[OutputDataT]" = str,
+    system_prompt: str = "",
+    model_settings: "ModelSettings | None" = None,
     tools_attr: (
-        "list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]"
-    ),
-    additional_tools: "list[ToolOrCallable]",
-    toolsets_attr: "list[AbstractToolset[Agent]] | Callable[[AnySharedContext], list[AbstractToolset[Agent]]]",  # noqa
-    additional_toolsets: "list[AbstractToolset[Agent]]",
+        "list[ToolOrCallable] | Callable[[AnyContext], list[ToolOrCallable]]"
+    ) = [],
+    additional_tools: "list[ToolOrCallable]" = [],
+    toolsets_attr: "list[AbstractToolset[None] | str] | Callable[[AnyContext], list[AbstractToolset[None] | str]]" = [],  # noqa
+    additional_toolsets: "list[AbstractToolset[None] | str]" = [],
     retries: int = 3,
-    is_yolo_mode: bool | None = None,
+    yolo_mode: bool | list[str] | None = None,
+    summarization_model: "Model | str | None" = None,
+    summarization_model_settings: "ModelSettings | None" = None,
+    summarization_system_prompt: str | None = None,
+    summarization_retries: int = 2,
+    summarization_token_threshold: int | None = None,
+    history_processors: list["HistoryProcessor"] | None = None,
 ) -> "Agent":
     """Retrieves the configured Agent instance or creates one if necessary."""
-    from pydantic_ai import Agent
-
-    # Render agent instance and return if agent_attr is already an agent
-    if isinstance(agent_attr, Agent):
-        return agent_attr
-    if callable(agent_attr):
-        agent_instance = agent_attr(ctx)
-        if not isinstance(agent_instance, Agent):
-            err_msg = (
-                "Callable agent factory did not return an Agent instance, "
-                f"got: {type(agent_instance)}"
-            )
-            raise TypeError(err_msg)
-        return agent_instance
     # Get tools for agent
     tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
     tools.extend(additional_tools)
     # Get Toolsets for agent
-    tool_sets = list(toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr)
-    tool_sets.extend(additional_toolsets)
+    toolset_or_str_list = list(
+        toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr
+    )
+    toolset_or_str_list.extend(additional_toolsets)
+    toolsets = _render_toolset_or_str_list(ctx, toolset_or_str_list)
     # If no agent provided, create one using the configuration
     return create_agent_instance(
         ctx=ctx,
         model=model,
+        rate_limitter=rate_limitter,
+        output_type=output_type,
         system_prompt=system_prompt,
         tools=tools,
-        toolsets=tool_sets,
+        toolsets=toolsets,
         model_settings=model_settings,
         retries=retries,
-        is_yolo_mode=is_yolo_mode,
-    )
-
-
-async def run_agent_iteration(
-    ctx: AnyContext,
-    agent: "Agent",
-    user_prompt: str,
-    attachments: "list[UserContent] | None" = None,
-    history_list: ListOfDict | None = None,
-    rate_limitter: LLMRateLimiter | None = None,
-    max_retry: int = 2,
-) -> "AgentRun":
-    """
-    Runs a single iteration of the agent execution loop.
-
-    Args:
-        ctx: The task context.
-        agent: The Pydantic AI agent instance.
-        user_prompt: The user's input prompt.
-        history_list: The current conversation history.
-
-    Returns:
-        The agent run result object.
-
-    Raises:
-        Exception: If any error occurs during agent execution.
-    """
-    if max_retry < 0:
-        raise ValueError("Max retry cannot be less than 0")
-    attempt = 0
-    while attempt < max_retry:
-        try:
-            return await _run_single_agent_iteration(
-                ctx=ctx,
-                agent=agent,
-                user_prompt=user_prompt,
-                attachments=[] if attachments is None else attachments,
-                history_list=[] if history_list is None else history_list,
-                rate_limitter=(
-                    llm_rate_limitter if rate_limitter is None else rate_limitter
-                ),
-            )
-        except BaseException:
-            attempt += 1
-            if attempt == max_retry:
-                raise
-    raise Exception("Max retry exceeded")
-
-
-async def _run_single_agent_iteration(
-    ctx: AnyContext,
-    agent: "Agent",
-    user_prompt: str,
-    attachments: "list[UserContent]",
-    history_list: ListOfDict,
-    rate_limitter: LLMRateLimiter,
-) -> "AgentRun":
-    from openai import APIError
-    from pydantic_ai.messages import ModelMessagesTypeAdapter
-
-    agent_payload = _estimate_request_payload(
-        agent, user_prompt, attachments, history_list
+        yolo_mode=yolo_mode,
+        summarization_model=summarization_model,
+        summarization_model_settings=summarization_model_settings,
+        summarization_system_prompt=summarization_system_prompt,
+        summarization_retries=summarization_retries,
+        summarization_token_threshold=summarization_token_threshold,
+        history_processors=history_processors,
     )
-    if rate_limitter:
-        await rate_limitter.throttle(agent_payload)
-    else:
-        await llm_rate_limitter.throttle(agent_payload)
-
-    user_prompt_with_attachments = [user_prompt] + attachments
-    async with agent:
-        async with agent.iter(
-            user_prompt=user_prompt_with_attachments,
-            message_history=ModelMessagesTypeAdapter.validate_python(history_list),
-        ) as agent_run:
-            async for node in agent_run:
-                # Each node represents a step in the agent's execution
-                # Reference: https://ai.pydantic.dev/agents/#streaming
-                try:
-                    await print_node(_get_plain_printer(ctx), agent_run, node)
-                except APIError as e:
-                    # Extract detailed error information from the response
-                    error_details = extract_api_error_details(e)
-                    ctx.log_error(f"API Error: {error_details}")
-                    raise
-                except Exception as e:
-                    ctx.log_error(f"Error processing node: {str(e)}")
-                    ctx.log_error(f"Error type: {type(e).__name__}")
-                    raise
-    return agent_run
-
-
-def _estimate_request_payload(
-    agent: "Agent",
-    user_prompt: str,
-    attachments: "list[UserContent]",
-    history_list: ListOfDict,
-) -> str:
-    system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
-    return json.dumps(
-        [
-            {"role": "system", "content": "\n".join(system_prompts)},
-            *history_list,
-            {"role": "user", "content": user_prompt},
-            *[_estimate_attachment_payload(attachment) for attachment in attachments],
-        ]
-    )
-
-
-def _estimate_attachment_payload(attachment: "UserContent") -> Any:
-    if hasattr(attachment, "url"):
-        return {"role": "user", "content": attachment.url}
-    if hasattr(attachment, "data"):
-        return {"role": "user", "content": "x" * len(attachment.data)}
-    return ""
-

-def _get_plain_printer(ctx: AnyContext):
-    def printer(*args, **kwargs):
-        if "plain" not in kwargs:
-            kwargs["plain"] = True
-        return ctx.print(*args, **kwargs)

-    return printer
+def _render_toolset_or_str_list(
+    ctx: AnyContext, toolset_or_str_list: list["AbstractToolset[None] | str"]
+) -> list["AbstractToolset[None]"]:
+    from pydantic_ai.mcp import load_mcp_servers
+
+    toolsets = []
+    for toolset_or_str in toolset_or_str_list:
+        if isinstance(toolset_or_str, str):
+            try:
+                servers = load_mcp_servers(toolset_or_str)
+                for server in servers:
+                    toolsets.append(server)
+            except Exception as e:
+                ctx.log_error(f"Invalid MCP Config {toolset_or_str}: {e}")
+            continue
+        toolsets.append(toolset_or_str)
+    return toolsets
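
Note: get_agent no longer accepts a prebuilt agent via agent_attr; it always builds the agent from configuration. Entries in toolsets_attr / additional_toolsets may now be plain strings pointing at an MCP server configuration (loaded through pydantic_ai.mcp.load_mcp_servers), and yolo_mode may be a list of tool names as well as a bool. A hedged sketch of a call site: the model id, tool, and config path are placeholders, and reading a yolo_mode list as "tools that skip the confirmation prompt" is inferred from the wrapper code above, not stated by the diff.

    from zrb.context.any_context import AnyContext
    from zrb.task.llm.agent import get_agent

    def read_changelog(path: str) -> str:
        """Read a changelog file (illustrative tool)."""
        with open(path, "r") as f:
            return f.read()

    def build_agent(ctx: AnyContext):
        return get_agent(
            ctx=ctx,                            # the AnyContext of the running task
            model="openai:gpt-4o",              # placeholder model id
            system_prompt="You are a helpful assistant.",
            tools_attr=[read_changelog],        # plain callables are wrapped into Tools
            toolsets_attr=["mcp_config.json"],  # strings are loaded as MCP server configs
            yolo_mode=["read_changelog"],       # presumably: tools exempt from confirmation
        )
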