zrb 1.13.1__py3-none-any.whl → 1.21.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. zrb/__init__.py +2 -6
  2. zrb/attr/type.py +8 -8
  3. zrb/builtin/__init__.py +2 -0
  4. zrb/builtin/group.py +31 -15
  5. zrb/builtin/http.py +7 -8
  6. zrb/builtin/llm/attachment.py +40 -0
  7. zrb/builtin/llm/chat_session.py +130 -144
  8. zrb/builtin/llm/chat_session_cmd.py +226 -0
  9. zrb/builtin/llm/chat_trigger.py +73 -0
  10. zrb/builtin/llm/history.py +4 -4
  11. zrb/builtin/llm/llm_ask.py +218 -110
  12. zrb/builtin/llm/tool/api.py +74 -62
  13. zrb/builtin/llm/tool/cli.py +35 -16
  14. zrb/builtin/llm/tool/code.py +49 -47
  15. zrb/builtin/llm/tool/file.py +262 -251
  16. zrb/builtin/llm/tool/note.py +84 -0
  17. zrb/builtin/llm/tool/rag.py +25 -18
  18. zrb/builtin/llm/tool/sub_agent.py +29 -22
  19. zrb/builtin/llm/tool/web.py +135 -143
  20. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  22. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  23. zrb/builtin/searxng/config/settings.yml +5671 -0
  24. zrb/builtin/searxng/start.py +21 -0
  25. zrb/builtin/setup/latex/ubuntu.py +1 -0
  26. zrb/builtin/setup/ubuntu.py +1 -1
  27. zrb/builtin/shell/autocomplete/bash.py +4 -3
  28. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  29. zrb/config/config.py +255 -78
  30. zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
  31. zrb/config/default_prompt/interactive_system_prompt.md +24 -30
  32. zrb/config/default_prompt/persona.md +1 -1
  33. zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
  34. zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
  35. zrb/config/default_prompt/summarization_prompt.md +8 -13
  36. zrb/config/default_prompt/system_prompt.md +36 -30
  37. zrb/config/llm_config.py +129 -24
  38. zrb/config/llm_context/config.py +127 -90
  39. zrb/config/llm_context/config_parser.py +1 -7
  40. zrb/config/llm_context/workflow.py +81 -0
  41. zrb/config/llm_rate_limitter.py +89 -45
  42. zrb/context/any_shared_context.py +7 -1
  43. zrb/context/context.py +8 -2
  44. zrb/context/shared_context.py +6 -8
  45. zrb/group/any_group.py +12 -5
  46. zrb/group/group.py +67 -3
  47. zrb/input/any_input.py +5 -1
  48. zrb/input/base_input.py +18 -6
  49. zrb/input/text_input.py +7 -24
  50. zrb/runner/cli.py +21 -20
  51. zrb/runner/common_util.py +24 -19
  52. zrb/runner/web_route/task_input_api_route.py +5 -5
  53. zrb/runner/web_route/task_session_api_route.py +1 -4
  54. zrb/runner/web_util/user.py +7 -3
  55. zrb/session/any_session.py +12 -6
  56. zrb/session/session.py +39 -18
  57. zrb/task/any_task.py +24 -3
  58. zrb/task/base/context.py +17 -9
  59. zrb/task/base/execution.py +15 -8
  60. zrb/task/base/lifecycle.py +8 -4
  61. zrb/task/base/monitoring.py +12 -7
  62. zrb/task/base_task.py +69 -5
  63. zrb/task/base_trigger.py +12 -5
  64. zrb/task/llm/agent.py +138 -52
  65. zrb/task/llm/config.py +45 -13
  66. zrb/task/llm/conversation_history.py +76 -6
  67. zrb/task/llm/conversation_history_model.py +0 -168
  68. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  69. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  70. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  71. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  72. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  73. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  74. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  75. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  76. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  77. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  78. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  79. zrb/task/llm/file_replacement.py +206 -0
  80. zrb/task/llm/file_tool_model.py +57 -0
  81. zrb/task/llm/history_summarization.py +22 -35
  82. zrb/task/llm/history_summarization_tool.py +24 -0
  83. zrb/task/llm/print_node.py +182 -63
  84. zrb/task/llm/prompt.py +213 -153
  85. zrb/task/llm/tool_wrapper.py +210 -53
  86. zrb/task/llm/workflow.py +76 -0
  87. zrb/task/llm_task.py +98 -47
  88. zrb/task/make_task.py +2 -3
  89. zrb/task/rsync_task.py +25 -10
  90. zrb/task/scheduler.py +4 -4
  91. zrb/util/attr.py +50 -40
  92. zrb/util/cli/markdown.py +12 -0
  93. zrb/util/cli/text.py +30 -0
  94. zrb/util/file.py +27 -11
  95. zrb/util/{llm/prompt.py → markdown.py} +2 -3
  96. zrb/util/string/conversion.py +1 -1
  97. zrb/util/truncate.py +23 -0
  98. zrb/util/yaml.py +204 -0
  99. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/METADATA +40 -20
  100. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/RECORD +102 -79
  101. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/WHEEL +1 -1
  102. zrb/task/llm/default_workflow/coding.md +0 -24
  103. zrb/task/llm/default_workflow/copywriting.md +0 -17
  104. zrb/task/llm/default_workflow/researching.md +0 -18
  105. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/entry_points.txt +0 -0
zrb/task/base_task.py CHANGED
@@ -21,6 +21,7 @@ from zrb.task.base.execution import (
21
21
  )
22
22
  from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
23
23
  from zrb.task.base.operators import handle_lshift, handle_rshift
24
+ from zrb.util.string.conversion import to_snake_case
24
25
 
25
26
 
26
27
  class BaseTask(AnyTask):
@@ -216,7 +217,10 @@ class BaseTask(AnyTask):
216
217
  return build_task_context(self, session)
217
218
 
218
219
  def run(
219
- self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
220
+ self,
221
+ session: AnySession | None = None,
222
+ str_kwargs: dict[str, str] | None = None,
223
+ kwargs: dict[str, Any] | None = None,
220
224
  ) -> Any:
221
225
  """
222
226
  Synchronously runs the task and its dependencies, handling async setup and cleanup.
@@ -235,12 +239,19 @@ class BaseTask(AnyTask):
235
239
  Any: The final result of the main task execution.
236
240
  """
237
241
  # Use asyncio.run() to execute the async cleanup wrapper
238
- return asyncio.run(run_and_cleanup(self, session, str_kwargs))
242
+ return asyncio.run(
243
+ run_and_cleanup(self, session=session, str_kwargs=str_kwargs, kwargs=kwargs)
244
+ )
239
245
 
240
246
  async def async_run(
241
- self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
247
+ self,
248
+ session: AnySession | None = None,
249
+ str_kwargs: dict[str, str] | None = None,
250
+ kwargs: dict[str, Any] | None = None,
242
251
  ) -> Any:
243
- return await run_task_async(self, session, str_kwargs)
252
+ return await run_task_async(
253
+ self, session=session, str_kwargs=str_kwargs, kwargs=kwargs
254
+ )
244
255
 
245
256
  async def exec_root_tasks(self, session: AnySession):
246
257
  return await execute_root_tasks(self, session)
@@ -276,7 +287,60 @@ class BaseTask(AnyTask):
276
287
  # Add definition location to the error
277
288
  if hasattr(e, "add_note"):
278
289
  e.add_note(additional_error_note)
279
- else:
290
+ elif hasattr(e, "__notes__"):
280
291
  # fallback: use the __notes__ attribute directly
281
292
  e.__notes__ = getattr(e, "__notes__", []) + [additional_error_note]
282
293
  raise e
294
+
295
+ def to_function(self) -> Callable[..., Any]:
296
+ from zrb.context.shared_context import SharedContext
297
+ from zrb.session.session import Session
298
+
299
+ def task_runner_fn(**kwargs) -> Any:
300
+ task_kwargs = self._get_func_kwargs(kwargs)
301
+ shared_ctx = SharedContext()
302
+ session = Session(shared_ctx=shared_ctx)
303
+ return self.run(session=session, kwargs=task_kwargs)
304
+
305
+ task_runner_fn.__doc__ = self._create_fn_docstring()
306
+ task_runner_fn.__signature__ = self._create_fn_signature()
307
+ task_runner_fn.__name__ = self.name
308
+ return task_runner_fn
309
+
310
+ def _get_func_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
311
+ fn_kwargs = {}
312
+ for inp in self.inputs:
313
+ snake_input_name = to_snake_case(inp.name)
314
+ if snake_input_name in kwargs:
315
+ fn_kwargs[inp.name] = kwargs[snake_input_name]
316
+ return fn_kwargs
317
+
318
+ def _create_fn_docstring(self) -> str:
319
+ from zrb.context.shared_context import SharedContext
320
+
321
+ stub_shared_ctx = SharedContext()
322
+ str_input_default_values = {}
323
+ for inp in self.inputs:
324
+ str_input_default_values[inp.name] = inp.get_default_str(stub_shared_ctx)
325
+ # Create docstring
326
+ doc = f"{self.description}\n\n"
327
+ if len(self.inputs) > 0:
328
+ doc += "Args:\n"
329
+ for inp in self.inputs:
330
+ str_input_default = str_input_default_values.get(inp.name, "")
331
+ doc += (
332
+ f" {inp.name}: {inp.description} (default: {str_input_default})"
333
+ )
334
+ doc += "\n"
335
+ return doc
336
+
337
+ def _create_fn_signature(self) -> inspect.Signature:
338
+ params = []
339
+ for inp in self.inputs:
340
+ params.append(
341
+ inspect.Parameter(
342
+ name=to_snake_case(inp.name),
343
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
344
+ )
345
+ )
346
+ return inspect.Signature(params)
zrb/task/base_trigger.py CHANGED
@@ -36,7 +36,7 @@ class BaseTrigger(BaseTask):
36
36
  input: list[AnyInput | None] | AnyInput | None = None,
37
37
  env: list[AnyEnv | None] | AnyEnv | None = None,
38
38
  action: fstring | Callable[[AnyContext], Any] | None = None,
39
- execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
39
+ execute_condition: bool | str | Callable[[AnyContext], bool] = True,
40
40
  queue_name: fstring | None = None,
41
41
  callback: list[AnyCallback] | AnyCallback = [],
42
42
  retries: int = 2,
@@ -127,7 +127,7 @@ class BaseTrigger(BaseTask):
127
127
  return self._callbacks
128
128
 
129
129
  async def exec_root_tasks(self, session: AnySession):
130
- exchange_xcom = self.get_exchange_xcom(session)
130
+ exchange_xcom = self._get_exchange_xcom(session)
131
131
  exchange_xcom.add_push_callback(lambda: self._exchange_push_callback(session))
132
132
  return await super().exec_root_tasks(session)
133
133
 
@@ -136,8 +136,7 @@ class BaseTrigger(BaseTask):
136
136
  session.defer_coro(coro)
137
137
 
138
138
  async def _fanout_and_trigger_callback(self, session: AnySession):
139
- exchange_xcom = self.get_exchange_xcom(session)
140
- data = exchange_xcom.pop()
139
+ data = self.pop_exchange_xcom(session)
141
140
  coros = []
142
141
  for callback in self.callbacks:
143
142
  xcom_dict = DotDict({self.queue_name: Xcom([data])})
@@ -156,8 +155,16 @@ class BaseTrigger(BaseTask):
156
155
  )
157
156
  await asyncio.gather(*coros)
158
157
 
159
- def get_exchange_xcom(self, session: AnySession) -> Xcom:
158
+ def _get_exchange_xcom(self, session: AnySession) -> Xcom:
160
159
  shared_ctx = session.shared_ctx
161
160
  if self.queue_name not in shared_ctx.xcom:
162
161
  shared_ctx.xcom[self.queue_name] = Xcom()
163
162
  return shared_ctx.xcom[self.queue_name]
163
+
164
+ def push_exchange_xcom(self, session: AnySession, data: Any):
165
+ exchange_xcom = self._get_exchange_xcom(session)
166
+ exchange_xcom.push(data)
167
+
168
+ def pop_exchange_xcom(self, session: AnySession) -> Any:
169
+ exchange_xcom = self._get_exchange_xcom(session)
170
+ return exchange_xcom.pop()
zrb/task/llm/agent.py CHANGED
@@ -1,5 +1,7 @@
1
+ import inspect
1
2
  import json
2
3
  from collections.abc import Callable
4
+ from dataclasses import dataclass
3
5
  from typing import TYPE_CHECKING, Any
4
6
 
5
7
  from zrb.config.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
@@ -9,32 +11,68 @@ from zrb.task.llm.error import extract_api_error_details
9
11
  from zrb.task.llm.print_node import print_node
10
12
  from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
11
13
  from zrb.task.llm.typing import ListOfDict
14
+ from zrb.util.cli.style import stylize_faint
12
15
 
13
16
  if TYPE_CHECKING:
14
17
  from pydantic_ai import Agent, Tool
15
18
  from pydantic_ai.agent import AgentRun
19
+ from pydantic_ai.messages import UserContent
16
20
  from pydantic_ai.models import Model
21
+ from pydantic_ai.output import OutputDataT, OutputSpec
17
22
  from pydantic_ai.settings import ModelSettings
18
23
  from pydantic_ai.toolsets import AbstractToolset
19
24
 
20
25
  ToolOrCallable = Tool | Callable
21
- else:
22
- ToolOrCallable = Any
23
26
 
24
27
 
25
28
  def create_agent_instance(
26
29
  ctx: AnyContext,
27
- model: "str | Model | None" = None,
30
+ model: "str | Model",
31
+ output_type: "OutputSpec[OutputDataT]" = str,
28
32
  system_prompt: str = "",
29
33
  model_settings: "ModelSettings | None" = None,
30
- tools: list[ToolOrCallable] = [],
31
- toolsets: list["AbstractToolset[Agent]"] = [],
34
+ tools: "list[ToolOrCallable]" = [],
35
+ toolsets: list["AbstractToolset[None]"] = [],
32
36
  retries: int = 3,
33
- ) -> "Agent":
37
+ yolo_mode: bool | list[str] | None = None,
38
+ ) -> "Agent[None, Any]":
34
39
  """Creates a new Agent instance with configured tools and servers."""
35
- from pydantic_ai import Agent, Tool
40
+ from pydantic_ai import Agent, RunContext, Tool
36
41
  from pydantic_ai.tools import GenerateToolJsonSchema
42
+ from pydantic_ai.toolsets import ToolsetTool, WrapperToolset
43
+
44
+ @dataclass
45
+ class ConfirmationWrapperToolset(WrapperToolset):
46
+ ctx: AnyContext
47
+ yolo_mode: bool | list[str]
48
+
49
+ async def call_tool(
50
+ self, name: str, tool_args: dict, ctx: RunContext, tool: ToolsetTool[None]
51
+ ) -> Any:
52
+ # The `tool` object is passed in. Use it for inspection.
53
+ # Define a temporary function that performs the actual tool call.
54
+ async def execute_delegated_tool_call(**params):
55
+ # Pass all arguments down the chain.
56
+ return await self.wrapped.call_tool(name, tool_args, ctx, tool)
37
57
 
58
+ # For the confirmation UI, make our temporary function look like the real one.
59
+ try:
60
+ execute_delegated_tool_call.__name__ = name
61
+ execute_delegated_tool_call.__doc__ = tool.function.__doc__
62
+ execute_delegated_tool_call.__signature__ = inspect.signature(
63
+ tool.function
64
+ )
65
+ except (AttributeError, TypeError):
66
+ pass # Ignore if we can't inspect the original function
67
+ # Use the existing wrap_func to get the confirmation logic
68
+ wrapped_executor = wrap_func(
69
+ execute_delegated_tool_call, self.ctx, self.yolo_mode
70
+ )
71
+ # Call the wrapped executor. This will trigger the confirmation prompt.
72
+ return await wrapped_executor(**tool_args)
73
+
74
+ if yolo_mode is None:
75
+ yolo_mode = False
38
76
  # Normalize tools
39
77
  tool_list = []
40
78
  for tool_or_callable in tools:
@@ -44,7 +82,7 @@ def create_agent_instance(
44
82
  tool = tool_or_callable
45
83
  tool_list.append(
46
84
  Tool(
47
- function=wrap_func(tool.function),
85
+ function=wrap_func(tool.function, ctx, yolo_mode),
48
86
  takes_ctx=tool.takes_ctx,
49
87
  max_retries=tool.max_retries,
50
88
  name=tool.name,
@@ -58,13 +96,19 @@ def create_agent_instance(
58
96
  )
59
97
  else:
60
98
  # Turn function into tool
61
- tool_list.append(wrap_tool(tool_or_callable, ctx))
99
+ tool_list.append(wrap_tool(tool_or_callable, ctx, yolo_mode))
100
+ # Wrap toolsets
101
+ wrapped_toolsets = [
102
+ ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
103
+ for toolset in toolsets
104
+ ]
62
105
  # Return Agent
63
- return Agent(
106
+ return Agent[None, Any](
64
107
  model=model,
65
- system_prompt=system_prompt,
108
+ output_type=output_type,
109
+ instructions=system_prompt,
66
110
  tools=tool_list,
67
- toolsets=toolsets,
111
+ toolsets=wrapped_toolsets,
68
112
  model_settings=model_settings,
69
113
  retries=retries,
70
114
  )
@@ -72,58 +116,71 @@ def create_agent_instance(
72
116
 
73
117
  def get_agent(
74
118
  ctx: AnyContext,
75
- agent_attr: "Agent | Callable[[AnySharedContext], Agent] | None",
76
- model: "str | Model | None",
77
- system_prompt: str,
78
- model_settings: "ModelSettings | None",
119
+ model: "str | Model",
120
+ output_type: "OutputSpec[OutputDataT]" = str,
121
+ system_prompt: str = "",
122
+ model_settings: "ModelSettings | None" = None,
79
123
  tools_attr: (
80
- list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
81
- ),
82
- additional_tools: list[ToolOrCallable],
83
- toolsets_attr: "list[AbstractToolset[Agent]] | Callable[[AnySharedContext], list[AbstractToolset[Agent]]]", # noqa
84
- additional_toolsets: "list[AbstractToolset[Agent]]",
124
+ "list[ToolOrCallable] | Callable[[AnyContext], list[ToolOrCallable]]"
125
+ ) = [],
126
+ additional_tools: "list[ToolOrCallable]" = [],
127
+ toolsets_attr: "list[AbstractToolset[None] | str] | Callable[[AnyContext], list[AbstractToolset[None] | str]]" = [], # noqa
128
+ additional_toolsets: "list[AbstractToolset[None] | str]" = [],
85
129
  retries: int = 3,
130
+ yolo_mode: bool | list[str] | None = None,
86
131
  ) -> "Agent":
87
132
  """Retrieves the configured Agent instance or creates one if necessary."""
88
- from pydantic_ai import Agent
89
-
90
- # Render agent instance and return if agent_attr is already an agent
91
- if isinstance(agent_attr, Agent):
92
- return agent_attr
93
- if callable(agent_attr):
94
- agent_instance = agent_attr(ctx)
95
- if not isinstance(agent_instance, Agent):
96
- err_msg = (
97
- "Callable agent factory did not return an Agent instance, "
98
- f"got: {type(agent_instance)}"
99
- )
100
- raise TypeError(err_msg)
101
- return agent_instance
102
133
  # Get tools for agent
103
134
  tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
104
135
  tools.extend(additional_tools)
105
136
  # Get Toolsets for agent
106
- tool_sets = list(toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr)
107
- tool_sets.extend(additional_toolsets)
137
+ toolset_or_str_list = list(
138
+ toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr
139
+ )
140
+ toolset_or_str_list.extend(additional_toolsets)
141
+ toolsets = _render_toolset_or_str_list(ctx, toolset_or_str_list)
108
142
  # If no agent provided, create one using the configuration
109
143
  return create_agent_instance(
110
144
  ctx=ctx,
111
145
  model=model,
146
+ output_type=output_type,
112
147
  system_prompt=system_prompt,
113
148
  tools=tools,
114
- toolsets=tool_sets,
149
+ toolsets=toolsets,
115
150
  model_settings=model_settings,
116
151
  retries=retries,
152
+ yolo_mode=yolo_mode,
117
153
  )
118
154
 
119
155
 
156
+ def _render_toolset_or_str_list(
157
+ ctx: AnyContext, toolset_or_str_list: list["AbstractToolset[None] | str"]
158
+ ) -> list["AbstractToolset[None]"]:
159
+ from pydantic_ai.mcp import load_mcp_servers
160
+
161
+ toolsets = []
162
+ for toolset_or_str in toolset_or_str_list:
163
+ if isinstance(toolset_or_str, str):
164
+ try:
165
+ servers = load_mcp_servers(toolset_or_str)
166
+ for server in servers:
167
+ toolsets.append(server)
168
+ except Exception as e:
169
+ ctx.log_error(f"Invalid MCP Config {toolset_or_str}: {e}")
170
+ continue
171
+ toolsets.append(toolset_or_str)
172
+ return toolsets
173
+
174
+
120
175
  async def run_agent_iteration(
121
176
  ctx: AnyContext,
122
- agent: "Agent",
177
+ agent: "Agent[None, Any]",
123
178
  user_prompt: str,
124
- history_list: ListOfDict,
179
+ attachments: "list[UserContent] | None" = None,
180
+ history_list: ListOfDict | None = None,
125
181
  rate_limitter: LLMRateLimiter | None = None,
126
182
  max_retry: int = 2,
183
+ log_indent_level: int = 0,
127
184
  ) -> "AgentRun":
128
185
  """
129
186
  Runs a single iteration of the agent execution loop.
@@ -149,8 +206,12 @@ async def run_agent_iteration(
149
206
  ctx=ctx,
150
207
  agent=agent,
151
208
  user_prompt=user_prompt,
152
- history_list=history_list,
153
- rate_limitter=rate_limitter,
209
+ attachments=[] if attachments is None else attachments,
210
+ history_list=[] if history_list is None else history_list,
211
+ rate_limitter=(
212
+ llm_rate_limitter if rate_limitter is None else rate_limitter
213
+ ),
214
+ log_indent_level=log_indent_level,
154
215
  )
155
216
  except BaseException:
156
217
  attempt += 1
@@ -163,28 +224,34 @@ async def _run_single_agent_iteration(
163
224
  ctx: AnyContext,
164
225
  agent: "Agent",
165
226
  user_prompt: str,
227
+ attachments: "list[UserContent]",
166
228
  history_list: ListOfDict,
167
- rate_limitter: LLMRateLimiter | None = None,
229
+ rate_limitter: LLMRateLimiter,
230
+ log_indent_level: int,
168
231
  ) -> "AgentRun":
169
232
  from openai import APIError
170
233
  from pydantic_ai.messages import ModelMessagesTypeAdapter
171
234
 
172
- agent_payload = estimate_request_payload(agent, user_prompt, history_list)
235
+ agent_payload = _estimate_request_payload(
236
+ agent, user_prompt, attachments, history_list
237
+ )
238
+ callback = _create_print_throttle_notif(ctx)
173
239
  if rate_limitter:
174
- await rate_limitter.throttle(agent_payload)
240
+ await rate_limitter.throttle(agent_payload, callback)
175
241
  else:
176
- await llm_rate_limitter.throttle(agent_payload)
177
-
242
+ await llm_rate_limitter.throttle(agent_payload, callback)
243
+ user_prompt_with_attachments = [user_prompt] + attachments
178
244
  async with agent:
179
245
  async with agent.iter(
180
- user_prompt=user_prompt,
246
+ user_prompt=user_prompt_with_attachments,
181
247
  message_history=ModelMessagesTypeAdapter.validate_python(history_list),
182
248
  ) as agent_run:
183
249
  async for node in agent_run:
184
250
  # Each node represents a step in the agent's execution
185
- # Reference: https://ai.pydantic.dev/agents/#streaming
186
251
  try:
187
- await print_node(_get_plain_printer(ctx), agent_run, node)
252
+ await print_node(
253
+ _get_plain_printer(ctx), agent_run, node, log_indent_level
254
+ )
188
255
  except APIError as e:
189
256
  # Extract detailed error information from the response
190
257
  error_details = extract_api_error_details(e)
@@ -197,8 +264,18 @@ async def _run_single_agent_iteration(
197
264
  return agent_run
198
265
 
199
266
 
200
- def estimate_request_payload(
201
- agent: "Agent", user_prompt: str, history_list: ListOfDict
267
+ def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[], None]:
268
+ def _print_throttle_notif():
269
+ ctx.print(stylize_faint(" ⌛>> Request Throttled"), plain=True)
270
+
271
+ return _print_throttle_notif
272
+
273
+
274
+ def _estimate_request_payload(
275
+ agent: "Agent",
276
+ user_prompt: str,
277
+ attachments: "list[UserContent]",
278
+ history_list: ListOfDict,
202
279
  ) -> str:
203
280
  system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
204
281
  return json.dumps(
@@ -206,10 +283,19 @@ def estimate_request_payload(
206
283
  {"role": "system", "content": "\n".join(system_prompts)},
207
284
  *history_list,
208
285
  {"role": "user", "content": user_prompt},
286
+ *[_estimate_attachment_payload(attachment) for attachment in attachments],
209
287
  ]
210
288
  )
211
289
 
212
290
 
291
+ def _estimate_attachment_payload(attachment: "UserContent") -> Any:
292
+ if hasattr(attachment, "url"):
293
+ return {"role": "user", "content": attachment.url}
294
+ if hasattr(attachment, "data"):
295
+ return {"role": "user", "content": "x" * len(attachment.data)}
296
+ return ""
297
+
298
+
213
299
  def _get_plain_printer(ctx: AnyContext):
214
300
  def printer(*args, **kwargs):
215
301
  if "plain" not in kwargs:
zrb/task/llm/config.py CHANGED
@@ -1,20 +1,43 @@
1
- from typing import TYPE_CHECKING, Any, Callable
1
+ from typing import TYPE_CHECKING, Callable
2
2
 
3
3
  if TYPE_CHECKING:
4
4
  from pydantic_ai.models import Model
5
5
  from pydantic_ai.settings import ModelSettings
6
6
 
7
- from zrb.attr.type import StrAttr, fstring
7
+ from zrb.attr.type import BoolAttr, StrAttr, StrListAttr
8
8
  from zrb.config.llm_config import LLMConfig, llm_config
9
9
  from zrb.context.any_context import AnyContext
10
- from zrb.context.any_shared_context import AnySharedContext
11
- from zrb.util.attr import get_attr
10
+ from zrb.util.attr import get_attr, get_bool_attr, get_str_list_attr
11
+
12
+
13
+ def get_yolo_mode(
14
+ ctx: AnyContext,
15
+ yolo_mode_attr: (
16
+ Callable[[AnyContext], list[str] | bool | None] | StrListAttr | BoolAttr | None
17
+ ) = None,
18
+ render_yolo_mode: bool = True,
19
+ ) -> bool | list[str]:
20
+ if yolo_mode_attr is None:
21
+ return llm_config.default_yolo_mode
22
+ try:
23
+ return get_bool_attr(
24
+ ctx,
25
+ yolo_mode_attr,
26
+ False,
27
+ auto_render=render_yolo_mode,
28
+ )
29
+ except Exception:
30
+ return get_str_list_attr(
31
+ ctx,
32
+ yolo_mode_attr,
33
+ auto_render=render_yolo_mode,
34
+ )
12
35
 
13
36
 
14
37
  def get_model_settings(
15
38
  ctx: AnyContext,
16
39
  model_settings_attr: (
17
- "ModelSettings | Callable[[AnySharedContext], ModelSettings] | None"
40
+ "ModelSettings | Callable[[AnyContext], ModelSettings] | None"
18
41
  ) = None,
19
42
  ) -> "ModelSettings | None":
20
43
  """Gets the model settings, resolving callables if necessary."""
@@ -47,7 +70,7 @@ def get_model_api_key(
47
70
  ) -> str | None:
48
71
  """Gets the model API key, rendering if configured."""
49
72
  api_key = get_attr(ctx, model_api_key_attr, None, auto_render=render_model_api_key)
50
- if api_key is None and llm_config.default_api_key is not None:
73
+ if api_key is None and llm_config.default_model_api_key is not None:
51
74
  return llm_config.default_model_api_key
52
75
  if isinstance(api_key, str) or api_key is None:
53
76
  return api_key
@@ -56,18 +79,21 @@ def get_model_api_key(
56
79
 
57
80
  def get_model(
58
81
  ctx: AnyContext,
59
- model_attr: "Callable[[AnySharedContext], Model | str | fstring] | Model | None",
82
+ model_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None",
60
83
  render_model: bool,
61
- model_base_url_attr: StrAttr | None = None,
84
+ model_base_url_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None",
62
85
  render_model_base_url: bool = True,
63
- model_api_key_attr: StrAttr | None = None,
86
+ model_api_key_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None" = None,
64
87
  render_model_api_key: bool = True,
65
- ) -> "str | Model | None":
88
+ is_small_model: bool = False,
89
+ ) -> "str | Model":
66
90
  """Gets the model instance or name, handling defaults and configuration."""
67
91
  from pydantic_ai.models import Model
68
92
 
69
93
  model = get_attr(ctx, model_attr, None, auto_render=render_model)
70
94
  if model is None:
95
+ if is_small_model:
96
+ return llm_config.default_small_model
71
97
  return llm_config.default_model
72
98
  if isinstance(model, str):
73
99
  model_base_url = get_model_base_url(
@@ -76,11 +102,11 @@ def get_model(
76
102
  model_api_key = get_model_api_key(ctx, model_api_key_attr, render_model_api_key)
77
103
  new_llm_config = LLMConfig(
78
104
  default_model_name=model,
79
- default_base_url=model_base_url,
80
- default_api_key=model_api_key,
105
+ default_model_base_url=model_base_url,
106
+ default_model_api_key=model_api_key,
81
107
  )
82
108
  if model_base_url is None and model_api_key is None:
83
- default_model_provider = llm_config.default_model_provider
109
+ default_model_provider = _get_default_model_provider(is_small_model)
84
110
  if default_model_provider is not None:
85
111
  new_llm_config.set_default_model_provider(default_model_provider)
86
112
  return new_llm_config.default_model
@@ -88,3 +114,9 @@ def get_model(
88
114
  if isinstance(model, Model):
89
115
  return model
90
116
  raise ValueError(f"Invalid model type resolved: {type(model)}, value: {model}")
117
+
118
+
119
+ def _get_default_model_provider(is_small_model: bool = False):
120
+ if is_small_model:
121
+ return llm_config.default_small_model_provider
122
+ return llm_config.default_model_provider