zrb 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
zrb/task/llm/agent.py CHANGED
@@ -7,15 +7,15 @@ from zrb.context.any_context import AnyContext
 from zrb.context.any_shared_context import AnySharedContext
 from zrb.task.llm.error import extract_api_error_details
 from zrb.task.llm.print_node import print_node
-from zrb.task.llm.tool_wrapper import wrap_tool
+from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
 from zrb.task.llm.typing import ListOfDict
 
 if TYPE_CHECKING:
     from pydantic_ai import Agent, Tool
     from pydantic_ai.agent import AgentRun
-    from pydantic_ai.mcp import MCPServer
     from pydantic_ai.models import Model
     from pydantic_ai.settings import ModelSettings
+    from pydantic_ai.toolsets import AbstractToolset
 
     ToolOrCallable = Tool | Callable
 else:
@@ -28,26 +28,43 @@ def create_agent_instance(
     system_prompt: str = "",
     model_settings: "ModelSettings | None" = None,
     tools: list[ToolOrCallable] = [],
-    mcp_servers: list["MCPServer"] = [],
+    toolsets: list["AbstractToolset[Agent]"] = [],
     retries: int = 3,
 ) -> "Agent":
     """Creates a new Agent instance with configured tools and servers."""
     from pydantic_ai import Agent, Tool
+    from pydantic_ai.tools import GenerateToolJsonSchema
 
     # Normalize tools
     tool_list = []
     for tool_or_callable in tools:
         if isinstance(tool_or_callable, Tool):
             tool_list.append(tool_or_callable)
+            # Update tool's function
+            tool = tool_or_callable
+            tool_list.append(
+                Tool(
+                    function=wrap_func(tool.function),
+                    takes_ctx=tool.takes_ctx,
+                    max_retries=tool.max_retries,
+                    name=tool.name,
+                    description=tool.description,
+                    prepare=tool.prepare,
+                    docstring_format=tool.docstring_format,
+                    require_parameter_descriptions=tool.require_parameter_descriptions,
+                    schema_generator=GenerateToolJsonSchema,
+                    strict=tool.strict,
+                )
+            )
         else:
-            # Pass ctx to wrap_tool
+            # Turn function into tool
             tool_list.append(wrap_tool(tool_or_callable, ctx))
     # Return Agent
     return Agent(
         model=model,
         system_prompt=system_prompt,
         tools=tool_list,
-        toolsets=mcp_servers,
+        toolsets=toolsets,
         model_settings=model_settings,
         retries=retries,
     )
@@ -63,8 +80,8 @@ def get_agent(
         list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
     ),
     additional_tools: list[ToolOrCallable],
-    mcp_servers_attr: "list[MCPServer] | Callable[[AnySharedContext], list[MCPServer]]",
-    additional_mcp_servers: "list[MCPServer]",
+    toolsets_attr: "list[AbstractToolset[Agent]] | Callable[[AnySharedContext], list[AbstractToolset[Agent]]]",  # noqa
+    additional_toolsets: "list[AbstractToolset[Agent]]",
     retries: int = 3,
 ) -> "Agent":
     """Retrieves the configured Agent instance or creates one if necessary."""
@@ -85,18 +102,16 @@ def get_agent(
     # Get tools for agent
     tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
     tools.extend(additional_tools)
-    # Get MCP Servers for agent
-    mcp_servers = list(
-        mcp_servers_attr(ctx) if callable(mcp_servers_attr) else mcp_servers_attr
-    )
-    mcp_servers.extend(additional_mcp_servers)
+    # Get Toolsets for agent
+    tool_sets = list(toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr)
+    tool_sets.extend(additional_toolsets)
     # If no agent provided, create one using the configuration
     return create_agent_instance(
         ctx=ctx,
         model=model,
         system_prompt=system_prompt,
         tools=tools,
-        mcp_servers=mcp_servers,
+        toolsets=tool_sets,
         model_settings=model_settings,
         retries=retries,
     )
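Note: taken together, the agent.py hunks replace the MCP-specific `mcp_servers` parameters with pydantic-ai's generic `AbstractToolset` interface. A minimal usage sketch (hypothetical model name and server command; assumes a pydantic-ai version in which MCP servers implement `AbstractToolset`):

```python
# Sketch only: MCP servers now travel through the generic `toolsets` parameter.
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

# In recent pydantic-ai releases an MCP server is itself an AbstractToolset,
# so it can be passed wherever a toolset is expected.
server = MCPServerStdio(
    "npx", args=["-y", "@modelcontextprotocol/server-filesystem", "."]
)
agent = Agent("openai:gpt-4o", toolsets=[server])  # was: mcp_servers=[server]
```

Because `create_agent_instance` simply forwards `toolsets` to `Agent`, any object implementing `AbstractToolset[Agent]`, not just an MCP server, can now be supplied.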
@@ -176,46 +176,23 @@ class ConversationHistory:
         """
         return json.dumps({"content": self._fetch_long_term_note()})
 
-    def add_long_term_info(self, new_info: str) -> str:
+    def write_long_term_note(self, content: str) -> str:
         """
-        Add new info for long-term reference.
+        Write the entire content of the long-term references.
+        This will overwrite any existing long-term notes.
 
         Args:
-            new_info (str): New info to be added into long-term references.
+            content (str): The full content of the long-term notes.
 
         Returns:
-            str: JSON with new content of the notes.
-
-        Raises:
-            Exception: If the note cannot be read.
-        """
-        llm_context_config.add_to_context(new_info, cwd="/")
-        return json.dumps({"success": True, "content": self._fetch_long_term_note()})
-
-    def remove_long_term_info(self, irrelevant_info: str) -> str:
+            str: JSON indicating success.
         """
-        Remove irrelevant info from long-term reference.
-
-        Args:
-            irrelevant_info (str): Irrelevant info to be removed from long-term references.
-
-        Returns:
-            str: JSON with new content of the notes and deletion status.
-
-        Raises:
-            Exception: If the note cannot be read.
-        """
-        was_removed = llm_context_config.remove_from_context(irrelevant_info, cwd="/")
-        return json.dumps(
-            {
-                "success": was_removed,
-                "content": self._fetch_long_term_note(),
-            }
-        )
+        llm_context_config.write_context(content, context_path="/")
+        return json.dumps({"success": True})
 
     def read_contextual_note(self) -> str:
         """
-        Read the content of the contextual references.
+        Read the content of the contextual references for the current project.
 
         This tool helps you retrieve knowledge or notes stored for contextual reference.
         If the note does not exist, you may want to create it using the write tool.
@@ -228,52 +205,25 @@ class ConversationHistory:
         """
         return json.dumps({"content": self._fetch_contextual_note()})
 
-    def add_contextual_info(self, new_info: str, context_path: str | None) -> str:
-        """
-        Add new info for contextual reference.
-
-        Args:
-            new_info (str): New info to be added into contextual references.
-            context_path (str, optional): contextual directory path for new info
-
-        Returns:
-            str: JSON with new content of the notes.
-
-        Raises:
-            Exception: If the note cannot be read.
-        """
-        if context_path is None:
-            context_path = self.project_path
-        llm_context_config.add_to_context(new_info, context_path=context_path)
-        return json.dumps({"success": True, "content": self._fetch_contextual_note()})
-
-    def remove_contextual_info(
-        self, irrelevant_info: str, context_path: str | None
+    def write_contextual_note(
+        self, content: str, context_path: str | None = None
     ) -> str:
         """
-        Remove irrelevant info from contextual reference.
+        Write the entire content of the contextual references for a specific path.
+        This will overwrite any existing contextual notes for that path.
 
         Args:
-            irrelevant_info (str): Irrelevant info to be removed from contextual references.
-            context_path (str, optional): contextual directory path of the irrelevant info
+            content (str): The full content of the contextual notes.
+            context_path (str, optional): The directory path for the context.
+                Defaults to the current project path.
 
         Returns:
-            str: JSON with new content of the notes and deletion status.
-
-        Raises:
-            Exception: If the note cannot be read.
+            str: JSON indicating success.
         """
         if context_path is None:
             context_path = self.project_path
-        was_removed = llm_context_config.remove_from_context(
-            irrelevant_info, context_path=context_path
-        )
-        return json.dumps(
-            {
-                "success": was_removed,
-                "content": self._fetch_contextual_note(),
-            }
-        )
+        llm_context_config.write_context(content, context_path=context_path)
+        return json.dumps({"success": True})
 
     def _fetch_long_term_note(self):
         contexts = llm_context_config.get_contexts(cwd=self.project_path)
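Note: the paired add_/remove_ methods collapse into single overwrite-style writers backed by `llm_context_config.write_context`. A short sketch of the resulting tool surface (hypothetical note contents; no-argument construction of `ConversationHistory` is taken from the prompt.py hunk further below):

```python
# Sketch only: the new write-oriented note API.
history = ConversationHistory()

# Long-term notes are global (context_path="/") and fully replaced on write.
history.write_long_term_note("User prefers concise answers.")

# Contextual notes are scoped to a directory; omitting context_path
# falls back to the current project path.
history.write_contextual_note(
    "This repo is a poetry project; run tests with `pytest`.",
    context_path="/home/user/my-project",  # hypothetical path
)
```

Since each call overwrites the whole note, a caller that wants to append must first read the existing note and write back the merged content, rather than issuing incremental add/remove operations.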
@@ -1,5 +1,3 @@
-# Special Instructions for Software Engineering
-
 When the user's request involves writing or modifying code, you MUST follow these domain-specific rules in addition to your core workflow.
 
 ## 1. Critical Prohibitions
@@ -1,5 +1,3 @@
-# Special Instructions for Content Creation & Management
-
 When the user's request involves creating, refining, or organizing textual content, you MUST follow these domain-specific rules in addition to your core workflow.
 
 ## 1. Core Principles
@@ -1,5 +1,3 @@
-# Special Instructions for Research, Analysis, and Summarization
-
 When the user's request involves finding, synthesizing, or analyzing information, you MUST follow these domain-specific rules in addition to your core workflow.
 
 ## 1. Core Principles
@@ -146,11 +146,9 @@ async def summarize_history(
             conversation_history.write_past_conversation_summary,
             conversation_history.write_past_conversation_transcript,
             conversation_history.read_long_term_note,
-            conversation_history.add_long_term_info,
-            conversation_history.remove_long_term_info,
+            conversation_history.write_long_term_note,
             conversation_history.read_contextual_note,
-            conversation_history.add_contextual_info,
-            conversation_history.remove_contextual_info,
+            conversation_history.write_contextual_note,
         ],
     )
     try:
@@ -14,6 +14,7 @@ async def print_node(print_func: Callable, agent_run: Any, node: Any):
         PartDeltaEvent,
         PartStartEvent,
         TextPartDelta,
+        ThinkingPartDelta,
         ToolCallPartDelta,
     )
 
@@ -33,7 +34,9 @@ async def print_node(print_func: Callable, agent_run: Any, node: Any):
                     )
                     is_streaming = False
                 elif isinstance(event, PartDeltaEvent):
-                    if isinstance(event.delta, TextPartDelta):
+                    if isinstance(event.delta, TextPartDelta) or isinstance(
+                        event.delta, ThinkingPartDelta
+                    ):
                         print_func(
                             stylize_faint(f"{event.delta.content_delta}"),
                             end="",
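Note: with this change, thinking (reasoning) deltas stream through the same faint-styled output path as text deltas. A hedged sketch of the shared handling, assuming both delta types expose `content_delta` as pydantic-ai defines them; the two `isinstance` calls could equally be one tuple check:

```python
# Sketch only: equivalent routing of streamed deltas.
from pydantic_ai.messages import TextPartDelta, ThinkingPartDelta


def printable_delta(delta) -> str | None:
    # isinstance with a tuple of types covers both cases in one call.
    if isinstance(delta, (TextPartDelta, ThinkingPartDelta)):
        return delta.content_delta
    return None  # e.g. ToolCallPartDelta carries no plain text to echo
```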
zrb/task/llm/prompt.py CHANGED
@@ -3,11 +3,12 @@ import platform
 import re
 from datetime import datetime, timezone
 
-from zrb.attr.type import StrAttr
+from zrb.attr.type import StrAttr, StrListAttr
 from zrb.config.llm_config import llm_config as llm_config
+from zrb.config.llm_context.config import llm_context_config
 from zrb.context.any_context import AnyContext
 from zrb.task.llm.conversation_history_model import ConversationHistory
-from zrb.util.attr import get_attr, get_str_attr
+from zrb.util.attr import get_attr, get_str_attr, get_str_list_attr
 from zrb.util.file import read_dir, read_file_with_line_numbers
 from zrb.util.llm.prompt import make_prompt_section
 
@@ -15,13 +16,14 @@ from zrb.util.llm.prompt import make_prompt_section
 def get_persona(
     ctx: AnyContext,
     persona_attr: StrAttr | None,
+    render_persona: bool,
 ) -> str:
     """Gets the persona, prioritizing task-specific, then default."""
     persona = get_attr(
         ctx,
         persona_attr,
         None,
-        auto_render=False,
+        auto_render=render_persona,
     )
     if persona is not None:
         return persona
@@ -31,13 +33,14 @@
 def get_base_system_prompt(
     ctx: AnyContext,
     system_prompt_attr: StrAttr | None,
+    render_system_prompt: bool,
 ) -> str:
     """Gets the base system prompt, prioritizing task-specific, then default."""
     system_prompt = get_attr(
         ctx,
         system_prompt_attr,
         None,
-        auto_render=False,
+        auto_render=render_system_prompt,
     )
     if system_prompt is not None:
         return system_prompt
@@ -47,33 +50,95 @@
 def get_special_instruction_prompt(
     ctx: AnyContext,
     special_instruction_prompt_attr: StrAttr | None,
+    render_spcecial_instruction_prompt: bool,
 ) -> str:
     """Gets the special instruction prompt, prioritizing task-specific, then default."""
     special_instruction = get_attr(
         ctx,
         special_instruction_prompt_attr,
         None,
-        auto_render=False,
+        auto_render=render_spcecial_instruction_prompt,
     )
     if special_instruction is not None:
         return special_instruction
     return llm_config.default_special_instruction_prompt
 
 
+def get_modes(
+    ctx: AnyContext,
+    modes_attr: StrAttr | None,
+    render_modes: bool,
+) -> str:
+    """Gets the modes, prioritizing task-specific, then default."""
+    raw_modes = get_str_list_attr(
+        ctx,
+        modes_attr,
+        auto_render=render_modes,
+    )
+    modes = [mode.strip() for mode in raw_modes if mode.strip() != ""]
+    if len(modes) > 0:
+        return modes
+    return llm_config.default_modes or []
+
+
+def get_workflow_prompt(
+    ctx: AnyContext,
+    modes_attr: StrAttr | None,
+    render_modes: bool,
+) -> str:
+    modes = get_modes(ctx, modes_attr, render_modes)
+    # Get user-defined workflows
+    workflows = {
+        workflow_name: content
+        for workflow_name, content in llm_context_config.get_workflows().items()
+        if workflow_name in modes
+    }
+    # Get requested builtin-workflow names
+    requested_builtin_workflow_names = [
+        workflow_name
+        for workflow_name in ("coding", "copywriting", "researching")
+        if workflow_name in modes and workflow_name not in workflows
+    ]
+    # add builtin-workflows if requested
+    if len(requested_builtin_workflow_names) > 0:
+        dir_path = os.path.dirname(__file__)
+        for workflow_name in requested_builtin_workflow_names:
+            workflow_file_path = os.path.join(
+                dir_path, "default_workflow", f"{workflow_name}.md"
+            )
+            with open(workflow_file_path, "r") as f:
+                workflows[workflow_name] = f.read()
+    return "\n".join(
+        [
+            make_prompt_section(header.capitalize(), content)
+            for header, content in workflows.items()
+            if header.lower() in modes
+        ]
+    )
+
+
 def get_system_and_user_prompt(
     ctx: AnyContext,
     user_message: str,
     persona_attr: StrAttr | None = None,
+    render_persona: bool = False,
     system_prompt_attr: StrAttr | None = None,
+    render_system_prompt: bool = False,
     special_instruction_prompt_attr: StrAttr | None = None,
+    render_special_instruction_prompt: bool = False,
+    modes_attr: StrListAttr | None = None,
+    render_modes: bool = False,
     conversation_history: ConversationHistory | None = None,
 ) -> tuple[str, str]:
     """Combines persona, base system prompt, and special instructions."""
-    persona = get_persona(ctx, persona_attr)
-    base_system_prompt = get_base_system_prompt(ctx, system_prompt_attr)
-    special_instruction = get_special_instruction_prompt(
-        ctx, special_instruction_prompt_attr
+    persona = get_persona(ctx, persona_attr, render_persona)
+    base_system_prompt = get_base_system_prompt(
+        ctx, system_prompt_attr, render_system_prompt
     )
+    special_instruction_prompt = get_special_instruction_prompt(
+        ctx, special_instruction_prompt_attr, render_special_instruction_prompt
+    )
+    workflow_prompt = get_workflow_prompt(ctx, modes_attr, render_modes)
     if conversation_history is None:
         conversation_history = ConversationHistory()
     conversation_context, new_user_message = extract_conversation_context(user_message)
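Note: `get_workflow_prompt` is the core of the new modes feature. Each active mode resolves to a workflow: user-defined workflows from `llm_context_config.get_workflows()` take precedence, and the bundled `default_workflow/*.md` files (coding, copywriting, researching) fill in any requested built-ins. A standalone sketch of that resolution order, with plain dicts standing in for the config object and the bundled files:

```python
# Sketch only: mode -> workflow resolution with dict stand-ins.
BUILTIN_WORKFLOWS = {  # stands in for default_workflow/{name}.md
    "coding": "<contents of coding.md>",
    "copywriting": "<contents of copywriting.md>",
    "researching": "<contents of researching.md>",
}


def resolve_workflows(modes: list[str], user_defined: dict[str, str]) -> dict[str, str]:
    # User-defined workflows win; built-ins only fill modes not yet covered.
    workflows = {name: body for name, body in user_defined.items() if name in modes}
    for name, body in BUILTIN_WORKFLOWS.items():
        if name in modes and name not in workflows:
            workflows[name] = body
    return workflows


# A user-defined "coding" workflow suppresses the bundled coding.md:
assert resolve_workflows(["coding"], {"coding": "house rules"}) == {"coding": "house rules"}
```

This also explains the markdown hunks above: the `# Special Instructions ...` headings are dropped from the bundled files, presumably because `make_prompt_section` now supplies each section header from the capitalized workflow name.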
@@ -81,7 +146,8 @@ def get_system_and_user_prompt(
         [
             make_prompt_section("Persona", persona),
             make_prompt_section("System Prompt", base_system_prompt),
-            make_prompt_section("Special Instruction", special_instruction),
+            make_prompt_section("Special Instruction", special_instruction_prompt),
+            make_prompt_section("Special Workflows", workflow_prompt),
             make_prompt_section(
                 "Past Conversation",
                 "\n".join(
@@ -194,30 +260,15 @@
 def get_summarization_system_prompt(
     ctx: AnyContext,
     summarization_prompt_attr: StrAttr | None,
+    render_summarization_prompt: bool,
 ) -> str:
     """Gets the summarization prompt, rendering if configured and handling defaults."""
     summarization_prompt = get_attr(
         ctx,
         summarization_prompt_attr,
         None,
-        auto_render=False,
+        auto_render=render_summarization_prompt,
     )
     if summarization_prompt is not None:
         return summarization_prompt
     return llm_config.default_summarization_prompt
-
-
-def get_context_enrichment_prompt(
-    ctx: AnyContext,
-    context_enrichment_prompt_attr: StrAttr | None,
-) -> str:
-    """Gets the context enrichment prompt, rendering if configured and handling defaults."""
-    context_enrichment_prompt = get_attr(
-        ctx,
-        context_enrichment_prompt_attr,
-        None,
-        auto_render=False,
-    )
-    if context_enrichment_prompt is not None:
-        return context_enrichment_prompt
-    return llm_config.default_context_enrichment_prompt
@@ -5,9 +5,12 @@ import typing
 from collections.abc import Callable
 from typing import TYPE_CHECKING
 
+from zrb.config.config import CFG
 from zrb.context.any_context import AnyContext
 from zrb.task.llm.error import ToolExecutionError
+from zrb.util.callable import get_callable_name
 from zrb.util.run import run_async
+from zrb.util.string.conversion import to_boolean
 
 if TYPE_CHECKING:
     from pydantic_ai import Tool
@@ -18,16 +21,19 @@ def wrap_tool(func: Callable, ctx: AnyContext) -> "Tool":
     from pydantic_ai import RunContext, Tool
 
     original_sig = inspect.signature(func)
-    # Use helper function for clarity
     needs_run_context_for_pydantic = _has_context_parameter(original_sig, RunContext)
+    wrapper = wrap_func(func, ctx)
+    return Tool(wrapper, takes_ctx=needs_run_context_for_pydantic)
+
+
+def wrap_func(func: Callable, ctx: AnyContext) -> Callable:
+    original_sig = inspect.signature(func)
     needs_any_context_for_injection = _has_context_parameter(original_sig, AnyContext)
     takes_no_args = len(original_sig.parameters) == 0
     # Pass individual flags to the wrapper creator
     wrapper = _create_wrapper(func, original_sig, ctx, needs_any_context_for_injection)
-    # Adjust signature - _adjust_signature determines exclusions based on type
     _adjust_signature(wrapper, original_sig, takes_no_args)
-    # takes_ctx in pydantic-ai Tool is specifically for RunContext
-    return Tool(wrapper, takes_ctx=needs_run_context_for_pydantic)
+    return wrapper
 
 
 def _has_context_parameter(original_sig: inspect.Signature, context_type: type) -> bool:
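Note: the refactor extracts `wrap_func` so the context-injection wrapper can be reused without constructing a `Tool`; this is what agent.py now applies to the function inside an already-built `Tool`. A hedged sketch of direct use (assumes `ctx` is the `AnyContext` of a running zrb task; `lookup_user` is hypothetical):

```python
# Sketch only: reusing wrap_func outside of wrap_tool.
from zrb.context.any_context import AnyContext
from zrb.task.llm.tool_wrapper import wrap_func


def lookup_user(user_id: str, ctx: AnyContext) -> str:
    ctx.print(f"looking up {user_id}")  # AnyContext is injected by the wrapper
    return f"user-{user_id}"  # hypothetical payload


def build_wrapped(ctx: AnyContext):
    # The returned coroutine function should hide the AnyContext parameter
    # from its visible signature and inject ctx when the tool is called.
    return wrap_func(lookup_user, ctx)
```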
@@ -71,13 +77,11 @@ def _create_wrapper(
     async def wrapper(*args, **kwargs):
         # Identify AnyContext parameter name from the original signature if needed
         any_context_param_name = None
-
         if needs_any_context_for_injection:
             for param in original_sig.parameters.values():
                 if _is_annotated_with_context(param.annotation, AnyContext):
                     any_context_param_name = param.name
                     break  # Found it, no need to continue
-
         if any_context_param_name is None:
             # This should not happen if needs_any_context_for_injection is True,
             # but check for safety
@@ -87,24 +91,25 @@ def _create_wrapper(
         # Inject the captured ctx into kwargs. This will overwrite if the LLM
         # somehow provided it.
         kwargs[any_context_param_name] = ctx
-
         # If the dummy argument was added for schema generation and is present in kwargs,
         # remove it before calling the original function, unless the original function
         # actually expects a parameter named '_dummy'.
         if "_dummy" in kwargs and "_dummy" not in original_sig.parameters:
             del kwargs["_dummy"]
-
         try:
-            # Call the original function.
-            # pydantic-ai is responsible for injecting RunContext if takes_ctx is True.
-            # Our wrapper injects AnyContext if needed.
-            # The arguments received by the wrapper (*args, **kwargs) are those
-            # provided by the LLM, potentially with RunContext already injected by
-            # pydantic-ai if takes_ctx is True. We just need to ensure AnyContext
-            # is injected if required by the original function.
-            # The dummy argument handling is moved to _adjust_signature's logic
-            # for schema generation, it's not needed here before calling the actual
-            # function.
+            if not CFG.LLM_YOLO_MODE and not ctx.is_web_mode and ctx.is_tty:
+                func_name = get_callable_name(func)
+                ctx.print(f"✅ >> Allow to run tool: {func_name} (Y/n)", plain=True)
+                user_confirmation_str = await _read_line()
+                try:
+                    user_confirmation = to_boolean(user_confirmation_str)
+                except Exception:
+                    user_confirmation = False
+                if not user_confirmation:
+                    ctx.print(f"❌ >> Rejecting {func_name} call. Why?", plain=True)
+                    reason = await _read_line()
+                    ctx.print("", plain=True)
+                    raise ValueError(f"User disapproval: {reason}")
             return await run_async(func(*args, **kwargs))
         except Exception as e:
             error_model = ToolExecutionError(
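Note: this hunk adds a human-in-the-loop gate in front of every wrapped tool call. The gate is skipped when `CFG.LLM_YOLO_MODE` is on, when running in web mode, or when there is no TTY to prompt; otherwise the user must approve the call, and a rejection reason is raised and surfaced to the model through the surrounding `ToolExecutionError` handler. A condensed sketch of the control flow (assumption: `LLM_YOLO_MODE` is driven by zrb's `ZRB_LLM_YOLO_MODE` environment variable):

```python
# Sketch only: the approval gate the wrapper now runs before each tool call.
async def confirm_or_raise(ctx, func_name: str) -> None:
    if CFG.LLM_YOLO_MODE or ctx.is_web_mode or not ctx.is_tty:
        return  # trusted or non-interactive: execute immediately
    ctx.print(f"✅ >> Allow to run tool: {func_name} (Y/n)", plain=True)
    try:
        approved = to_boolean(await _read_line())
    except Exception:
        approved = False  # unparseable input counts as "no"
    if not approved:
        ctx.print(f"❌ >> Rejecting {func_name} call. Why?", plain=True)
        reason = await _read_line()
        raise ValueError(f"User disapproval: {reason}")
```

Raising instead of returning lets the wrapper's existing `except` clause package the rejection reason into `ToolExecutionError`, so the model sees why the call was refused.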
@@ -118,6 +123,13 @@
     return wrapper
 
 
+async def _read_line():
+    from prompt_toolkit import PromptSession
+
+    reader = PromptSession()
+    return await reader.prompt_async()
+
+
 def _adjust_signature(
     wrapper: Callable, original_sig: inspect.Signature, takes_no_args: bool
 ):