zrb 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/chat_session.py +38 -9
- zrb/builtin/llm/llm_ask.py +11 -0
- zrb/builtin/llm/tool/sub_agent.py +5 -5
- zrb/config/config.py +4 -0
- zrb/config/default_prompt/interactive_system_prompt.md +8 -4
- zrb/config/default_prompt/summarization_prompt.md +16 -42
- zrb/config/default_prompt/system_prompt.md +8 -4
- zrb/config/llm_config.py +23 -30
- zrb/config/llm_context/config.py +100 -45
- zrb/config/llm_context/config_parser.py +46 -0
- zrb/context/shared_context.py +4 -1
- zrb/task/llm/agent.py +28 -13
- zrb/task/llm/conversation_history_model.py +18 -68
- zrb/{config/default_workflow/code.md → task/llm/default_workflow/coding.md} +0 -2
- zrb/{config/default_workflow/content.md → task/llm/default_workflow/copywriting.md} +0 -2
- zrb/{config/default_workflow/research.md → task/llm/default_workflow/researching.md} +0 -2
- zrb/task/llm/history_summarization.py +2 -4
- zrb/task/llm/print_node.py +4 -1
- zrb/task/llm/prompt.py +78 -27
- zrb/task/llm/tool_wrapper.py +30 -18
- zrb/task/llm_task.py +31 -35
- zrb/util/callable.py +23 -0
- zrb/util/llm/prompt.py +19 -10
- {zrb-1.11.0.dist-info → zrb-1.13.0.dist-info}/METADATA +3 -3
- {zrb-1.11.0.dist-info → zrb-1.13.0.dist-info}/RECORD +27 -26
- zrb/config/llm_context/config_handler.py +0 -238
- {zrb-1.11.0.dist-info → zrb-1.13.0.dist-info}/WHEEL +0 -0
- {zrb-1.11.0.dist-info → zrb-1.13.0.dist-info}/entry_points.txt +0 -0
zrb/task/llm/agent.py
CHANGED
@@ -7,15 +7,15 @@ from zrb.context.any_context import AnyContext
|
|
7
7
|
from zrb.context.any_shared_context import AnySharedContext
|
8
8
|
from zrb.task.llm.error import extract_api_error_details
|
9
9
|
from zrb.task.llm.print_node import print_node
|
10
|
-
from zrb.task.llm.tool_wrapper import wrap_tool
|
10
|
+
from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
|
11
11
|
from zrb.task.llm.typing import ListOfDict
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from pydantic_ai import Agent, Tool
|
15
15
|
from pydantic_ai.agent import AgentRun
|
16
|
-
from pydantic_ai.mcp import MCPServer
|
17
16
|
from pydantic_ai.models import Model
|
18
17
|
from pydantic_ai.settings import ModelSettings
|
18
|
+
from pydantic_ai.toolsets import AbstractToolset
|
19
19
|
|
20
20
|
ToolOrCallable = Tool | Callable
|
21
21
|
else:
|
@@ -28,26 +28,43 @@ def create_agent_instance(
|
|
28
28
|
system_prompt: str = "",
|
29
29
|
model_settings: "ModelSettings | None" = None,
|
30
30
|
tools: list[ToolOrCallable] = [],
|
31
|
-
|
31
|
+
toolsets: list["AbstractToolset[Agent]"] = [],
|
32
32
|
retries: int = 3,
|
33
33
|
) -> "Agent":
|
34
34
|
"""Creates a new Agent instance with configured tools and servers."""
|
35
35
|
from pydantic_ai import Agent, Tool
|
36
|
+
from pydantic_ai.tools import GenerateToolJsonSchema
|
36
37
|
|
37
38
|
# Normalize tools
|
38
39
|
tool_list = []
|
39
40
|
for tool_or_callable in tools:
|
40
41
|
if isinstance(tool_or_callable, Tool):
|
41
42
|
tool_list.append(tool_or_callable)
|
43
|
+
# Update tool's function
|
44
|
+
tool = tool_or_callable
|
45
|
+
tool_list.append(
|
46
|
+
Tool(
|
47
|
+
function=wrap_func(tool.function),
|
48
|
+
takes_ctx=tool.takes_ctx,
|
49
|
+
max_retries=tool.max_retries,
|
50
|
+
name=tool.name,
|
51
|
+
description=tool.description,
|
52
|
+
prepare=tool.prepare,
|
53
|
+
docstring_format=tool.docstring_format,
|
54
|
+
require_parameter_descriptions=tool.require_parameter_descriptions,
|
55
|
+
schema_generator=GenerateToolJsonSchema,
|
56
|
+
strict=tool.strict,
|
57
|
+
)
|
58
|
+
)
|
42
59
|
else:
|
43
|
-
#
|
60
|
+
# Turn function into tool
|
44
61
|
tool_list.append(wrap_tool(tool_or_callable, ctx))
|
45
62
|
# Return Agent
|
46
63
|
return Agent(
|
47
64
|
model=model,
|
48
65
|
system_prompt=system_prompt,
|
49
66
|
tools=tool_list,
|
50
|
-
toolsets=
|
67
|
+
toolsets=toolsets,
|
51
68
|
model_settings=model_settings,
|
52
69
|
retries=retries,
|
53
70
|
)
|
@@ -63,8 +80,8 @@ def get_agent(
|
|
63
80
|
list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
|
64
81
|
),
|
65
82
|
additional_tools: list[ToolOrCallable],
|
66
|
-
|
67
|
-
|
83
|
+
toolsets_attr: "list[AbstractToolset[Agent]] | Callable[[AnySharedContext], list[AbstractToolset[Agent]]]", # noqa
|
84
|
+
additional_toolsets: "list[AbstractToolset[Agent]]",
|
68
85
|
retries: int = 3,
|
69
86
|
) -> "Agent":
|
70
87
|
"""Retrieves the configured Agent instance or creates one if necessary."""
|
@@ -85,18 +102,16 @@ def get_agent(
|
|
85
102
|
# Get tools for agent
|
86
103
|
tools = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
|
87
104
|
tools.extend(additional_tools)
|
88
|
-
# Get
|
89
|
-
|
90
|
-
|
91
|
-
)
|
92
|
-
mcp_servers.extend(additional_mcp_servers)
|
105
|
+
# Get Toolsets for agent
|
106
|
+
tool_sets = list(toolsets_attr(ctx) if callable(toolsets_attr) else toolsets_attr)
|
107
|
+
tool_sets.extend(additional_toolsets)
|
93
108
|
# If no agent provided, create one using the configuration
|
94
109
|
return create_agent_instance(
|
95
110
|
ctx=ctx,
|
96
111
|
model=model,
|
97
112
|
system_prompt=system_prompt,
|
98
113
|
tools=tools,
|
99
|
-
|
114
|
+
toolsets=tool_sets,
|
100
115
|
model_settings=model_settings,
|
101
116
|
retries=retries,
|
102
117
|
)
|
@@ -176,46 +176,23 @@ class ConversationHistory:
|
|
176
176
|
"""
|
177
177
|
return json.dumps({"content": self._fetch_long_term_note()})
|
178
178
|
|
179
|
-
def
|
179
|
+
def write_long_term_note(self, content: str) -> str:
|
180
180
|
"""
|
181
|
-
|
181
|
+
Write the entire content of the long-term references.
|
182
|
+
This will overwrite any existing long-term notes.
|
182
183
|
|
183
184
|
Args:
|
184
|
-
|
185
|
+
content (str): The full content of the long-term notes.
|
185
186
|
|
186
187
|
Returns:
|
187
|
-
str: JSON
|
188
|
-
|
189
|
-
Raises:
|
190
|
-
Exception: If the note cannot be read.
|
191
|
-
"""
|
192
|
-
llm_context_config.add_to_context(new_info, cwd="/")
|
193
|
-
return json.dumps({"success": True, "content": self._fetch_long_term_note()})
|
194
|
-
|
195
|
-
def remove_long_term_info(self, irrelevant_info: str) -> str:
|
188
|
+
str: JSON indicating success.
|
196
189
|
"""
|
197
|
-
|
198
|
-
|
199
|
-
Args:
|
200
|
-
irrelevant_info (str): Irrelevant info to be removed from long-term references.
|
201
|
-
|
202
|
-
Returns:
|
203
|
-
str: JSON with new content of the notes and deletion status.
|
204
|
-
|
205
|
-
Raises:
|
206
|
-
Exception: If the note cannot be read.
|
207
|
-
"""
|
208
|
-
was_removed = llm_context_config.remove_from_context(irrelevant_info, cwd="/")
|
209
|
-
return json.dumps(
|
210
|
-
{
|
211
|
-
"success": was_removed,
|
212
|
-
"content": self._fetch_long_term_note(),
|
213
|
-
}
|
214
|
-
)
|
190
|
+
llm_context_config.write_context(content, context_path="/")
|
191
|
+
return json.dumps({"success": True})
|
215
192
|
|
216
193
|
def read_contextual_note(self) -> str:
|
217
194
|
"""
|
218
|
-
Read the content of the contextual references.
|
195
|
+
Read the content of the contextual references for the current project.
|
219
196
|
|
220
197
|
This tool helps you retrieve knowledge or notes stored for contextual reference.
|
221
198
|
If the note does not exist, you may want to create it using the write tool.
|
@@ -228,52 +205,25 @@ class ConversationHistory:
|
|
228
205
|
"""
|
229
206
|
return json.dumps({"content": self._fetch_contextual_note()})
|
230
207
|
|
231
|
-
def
|
232
|
-
|
233
|
-
Add new info for contextual reference.
|
234
|
-
|
235
|
-
Args:
|
236
|
-
new_info (str): New info to be added into contextual references.
|
237
|
-
context_path (str, optional): contextual directory path for new info
|
238
|
-
|
239
|
-
Returns:
|
240
|
-
str: JSON with new content of the notes.
|
241
|
-
|
242
|
-
Raises:
|
243
|
-
Exception: If the note cannot be read.
|
244
|
-
"""
|
245
|
-
if context_path is None:
|
246
|
-
context_path = self.project_path
|
247
|
-
llm_context_config.add_to_context(new_info, context_path=context_path)
|
248
|
-
return json.dumps({"success": True, "content": self._fetch_contextual_note()})
|
249
|
-
|
250
|
-
def remove_contextual_info(
|
251
|
-
self, irrelevant_info: str, context_path: str | None
|
208
|
+
def write_contextual_note(
|
209
|
+
self, content: str, context_path: str | None = None
|
252
210
|
) -> str:
|
253
211
|
"""
|
254
|
-
|
212
|
+
Write the entire content of the contextual references for a specific path.
|
213
|
+
This will overwrite any existing contextual notes for that path.
|
255
214
|
|
256
215
|
Args:
|
257
|
-
|
258
|
-
context_path (str, optional):
|
216
|
+
content (str): The full content of the contextual notes.
|
217
|
+
context_path (str, optional): The directory path for the context.
|
218
|
+
Defaults to the current project path.
|
259
219
|
|
260
220
|
Returns:
|
261
|
-
str: JSON
|
262
|
-
|
263
|
-
Raises:
|
264
|
-
Exception: If the note cannot be read.
|
221
|
+
str: JSON indicating success.
|
265
222
|
"""
|
266
223
|
if context_path is None:
|
267
224
|
context_path = self.project_path
|
268
|
-
|
269
|
-
|
270
|
-
)
|
271
|
-
return json.dumps(
|
272
|
-
{
|
273
|
-
"success": was_removed,
|
274
|
-
"content": self._fetch_contextual_note(),
|
275
|
-
}
|
276
|
-
)
|
225
|
+
llm_context_config.write_context(content, context_path=context_path)
|
226
|
+
return json.dumps({"success": True})
|
277
227
|
|
278
228
|
def _fetch_long_term_note(self):
|
279
229
|
contexts = llm_context_config.get_contexts(cwd=self.project_path)
|
@@ -146,11 +146,9 @@ async def summarize_history(
|
|
146
146
|
conversation_history.write_past_conversation_summary,
|
147
147
|
conversation_history.write_past_conversation_transcript,
|
148
148
|
conversation_history.read_long_term_note,
|
149
|
-
conversation_history.
|
150
|
-
conversation_history.remove_long_term_info,
|
149
|
+
conversation_history.write_long_term_note,
|
151
150
|
conversation_history.read_contextual_note,
|
152
|
-
conversation_history.
|
153
|
-
conversation_history.remove_contextual_info,
|
151
|
+
conversation_history.write_contextual_note,
|
154
152
|
],
|
155
153
|
)
|
156
154
|
try:
|
zrb/task/llm/print_node.py
CHANGED
@@ -14,6 +14,7 @@ async def print_node(print_func: Callable, agent_run: Any, node: Any):
|
|
14
14
|
PartDeltaEvent,
|
15
15
|
PartStartEvent,
|
16
16
|
TextPartDelta,
|
17
|
+
ThinkingPartDelta,
|
17
18
|
ToolCallPartDelta,
|
18
19
|
)
|
19
20
|
|
@@ -33,7 +34,9 @@ async def print_node(print_func: Callable, agent_run: Any, node: Any):
|
|
33
34
|
)
|
34
35
|
is_streaming = False
|
35
36
|
elif isinstance(event, PartDeltaEvent):
|
36
|
-
if isinstance(event.delta, TextPartDelta)
|
37
|
+
if isinstance(event.delta, TextPartDelta) or isinstance(
|
38
|
+
event.delta, ThinkingPartDelta
|
39
|
+
):
|
37
40
|
print_func(
|
38
41
|
stylize_faint(f"{event.delta.content_delta}"),
|
39
42
|
end="",
|
zrb/task/llm/prompt.py
CHANGED
@@ -3,11 +3,12 @@ import platform
|
|
3
3
|
import re
|
4
4
|
from datetime import datetime, timezone
|
5
5
|
|
6
|
-
from zrb.attr.type import StrAttr
|
6
|
+
from zrb.attr.type import StrAttr, StrListAttr
|
7
7
|
from zrb.config.llm_config import llm_config as llm_config
|
8
|
+
from zrb.config.llm_context.config import llm_context_config
|
8
9
|
from zrb.context.any_context import AnyContext
|
9
10
|
from zrb.task.llm.conversation_history_model import ConversationHistory
|
10
|
-
from zrb.util.attr import get_attr, get_str_attr
|
11
|
+
from zrb.util.attr import get_attr, get_str_attr, get_str_list_attr
|
11
12
|
from zrb.util.file import read_dir, read_file_with_line_numbers
|
12
13
|
from zrb.util.llm.prompt import make_prompt_section
|
13
14
|
|
@@ -15,13 +16,14 @@ from zrb.util.llm.prompt import make_prompt_section
|
|
15
16
|
def get_persona(
|
16
17
|
ctx: AnyContext,
|
17
18
|
persona_attr: StrAttr | None,
|
19
|
+
render_persona: bool,
|
18
20
|
) -> str:
|
19
21
|
"""Gets the persona, prioritizing task-specific, then default."""
|
20
22
|
persona = get_attr(
|
21
23
|
ctx,
|
22
24
|
persona_attr,
|
23
25
|
None,
|
24
|
-
auto_render=
|
26
|
+
auto_render=render_persona,
|
25
27
|
)
|
26
28
|
if persona is not None:
|
27
29
|
return persona
|
@@ -31,13 +33,14 @@ def get_persona(
|
|
31
33
|
def get_base_system_prompt(
|
32
34
|
ctx: AnyContext,
|
33
35
|
system_prompt_attr: StrAttr | None,
|
36
|
+
render_system_prompt: bool,
|
34
37
|
) -> str:
|
35
38
|
"""Gets the base system prompt, prioritizing task-specific, then default."""
|
36
39
|
system_prompt = get_attr(
|
37
40
|
ctx,
|
38
41
|
system_prompt_attr,
|
39
42
|
None,
|
40
|
-
auto_render=
|
43
|
+
auto_render=render_system_prompt,
|
41
44
|
)
|
42
45
|
if system_prompt is not None:
|
43
46
|
return system_prompt
|
@@ -47,33 +50,95 @@ def get_base_system_prompt(
|
|
47
50
|
def get_special_instruction_prompt(
|
48
51
|
ctx: AnyContext,
|
49
52
|
special_instruction_prompt_attr: StrAttr | None,
|
53
|
+
render_spcecial_instruction_prompt: bool,
|
50
54
|
) -> str:
|
51
55
|
"""Gets the special instruction prompt, prioritizing task-specific, then default."""
|
52
56
|
special_instruction = get_attr(
|
53
57
|
ctx,
|
54
58
|
special_instruction_prompt_attr,
|
55
59
|
None,
|
56
|
-
auto_render=
|
60
|
+
auto_render=render_spcecial_instruction_prompt,
|
57
61
|
)
|
58
62
|
if special_instruction is not None:
|
59
63
|
return special_instruction
|
60
64
|
return llm_config.default_special_instruction_prompt
|
61
65
|
|
62
66
|
|
67
|
+
def get_modes(
|
68
|
+
ctx: AnyContext,
|
69
|
+
modes_attr: StrAttr | None,
|
70
|
+
render_modes: bool,
|
71
|
+
) -> str:
|
72
|
+
"""Gets the modes, prioritizing task-specific, then default."""
|
73
|
+
raw_modes = get_str_list_attr(
|
74
|
+
ctx,
|
75
|
+
modes_attr,
|
76
|
+
auto_render=render_modes,
|
77
|
+
)
|
78
|
+
modes = [mode.strip() for mode in raw_modes if mode.strip() != ""]
|
79
|
+
if len(modes) > 0:
|
80
|
+
return modes
|
81
|
+
return llm_config.default_modes or []
|
82
|
+
|
83
|
+
|
84
|
+
def get_workflow_prompt(
|
85
|
+
ctx: AnyContext,
|
86
|
+
modes_attr: StrAttr | None,
|
87
|
+
render_modes: bool,
|
88
|
+
) -> str:
|
89
|
+
modes = get_modes(ctx, modes_attr, render_modes)
|
90
|
+
# Get user-defined workflows
|
91
|
+
workflows = {
|
92
|
+
workflow_name: content
|
93
|
+
for workflow_name, content in llm_context_config.get_workflows().items()
|
94
|
+
if workflow_name in modes
|
95
|
+
}
|
96
|
+
# Get requested builtin-workflow names
|
97
|
+
requested_builtin_workflow_names = [
|
98
|
+
workflow_name
|
99
|
+
for workflow_name in ("coding", "copywriting", "researching")
|
100
|
+
if workflow_name in modes and workflow_name not in workflows
|
101
|
+
]
|
102
|
+
# add builtin-workflows if requested
|
103
|
+
if len(requested_builtin_workflow_names) > 0:
|
104
|
+
dir_path = os.path.dirname(__file__)
|
105
|
+
for workflow_name in requested_builtin_workflow_names:
|
106
|
+
workflow_file_path = os.path.join(
|
107
|
+
dir_path, "default_workflow", f"{workflow_name}.md"
|
108
|
+
)
|
109
|
+
with open(workflow_file_path, "r") as f:
|
110
|
+
workflows[workflow_name] = f.read()
|
111
|
+
return "\n".join(
|
112
|
+
[
|
113
|
+
make_prompt_section(header.capitalize(), content)
|
114
|
+
for header, content in workflows.items()
|
115
|
+
if header.lower() in modes
|
116
|
+
]
|
117
|
+
)
|
118
|
+
|
119
|
+
|
63
120
|
def get_system_and_user_prompt(
|
64
121
|
ctx: AnyContext,
|
65
122
|
user_message: str,
|
66
123
|
persona_attr: StrAttr | None = None,
|
124
|
+
render_persona: bool = False,
|
67
125
|
system_prompt_attr: StrAttr | None = None,
|
126
|
+
render_system_prompt: bool = False,
|
68
127
|
special_instruction_prompt_attr: StrAttr | None = None,
|
128
|
+
render_special_instruction_prompt: bool = False,
|
129
|
+
modes_attr: StrListAttr | None = None,
|
130
|
+
render_modes: bool = False,
|
69
131
|
conversation_history: ConversationHistory | None = None,
|
70
132
|
) -> tuple[str, str]:
|
71
133
|
"""Combines persona, base system prompt, and special instructions."""
|
72
|
-
persona = get_persona(ctx, persona_attr)
|
73
|
-
base_system_prompt = get_base_system_prompt(
|
74
|
-
|
75
|
-
ctx, special_instruction_prompt_attr
|
134
|
+
persona = get_persona(ctx, persona_attr, render_persona)
|
135
|
+
base_system_prompt = get_base_system_prompt(
|
136
|
+
ctx, system_prompt_attr, render_system_prompt
|
76
137
|
)
|
138
|
+
special_instruction_prompt = get_special_instruction_prompt(
|
139
|
+
ctx, special_instruction_prompt_attr, render_special_instruction_prompt
|
140
|
+
)
|
141
|
+
workflow_prompt = get_workflow_prompt(ctx, modes_attr, render_modes)
|
77
142
|
if conversation_history is None:
|
78
143
|
conversation_history = ConversationHistory()
|
79
144
|
conversation_context, new_user_message = extract_conversation_context(user_message)
|
@@ -81,7 +146,8 @@ def get_system_and_user_prompt(
|
|
81
146
|
[
|
82
147
|
make_prompt_section("Persona", persona),
|
83
148
|
make_prompt_section("System Prompt", base_system_prompt),
|
84
|
-
make_prompt_section("Special Instruction",
|
149
|
+
make_prompt_section("Special Instruction", special_instruction_prompt),
|
150
|
+
make_prompt_section("Special Workflows", workflow_prompt),
|
85
151
|
make_prompt_section(
|
86
152
|
"Past Conversation",
|
87
153
|
"\n".join(
|
@@ -194,30 +260,15 @@ def get_user_message(
|
|
194
260
|
def get_summarization_system_prompt(
|
195
261
|
ctx: AnyContext,
|
196
262
|
summarization_prompt_attr: StrAttr | None,
|
263
|
+
render_summarization_prompt: bool,
|
197
264
|
) -> str:
|
198
265
|
"""Gets the summarization prompt, rendering if configured and handling defaults."""
|
199
266
|
summarization_prompt = get_attr(
|
200
267
|
ctx,
|
201
268
|
summarization_prompt_attr,
|
202
269
|
None,
|
203
|
-
auto_render=
|
270
|
+
auto_render=render_summarization_prompt,
|
204
271
|
)
|
205
272
|
if summarization_prompt is not None:
|
206
273
|
return summarization_prompt
|
207
274
|
return llm_config.default_summarization_prompt
|
208
|
-
|
209
|
-
|
210
|
-
def get_context_enrichment_prompt(
|
211
|
-
ctx: AnyContext,
|
212
|
-
context_enrichment_prompt_attr: StrAttr | None,
|
213
|
-
) -> str:
|
214
|
-
"""Gets the context enrichment prompt, rendering if configured and handling defaults."""
|
215
|
-
context_enrichment_prompt = get_attr(
|
216
|
-
ctx,
|
217
|
-
context_enrichment_prompt_attr,
|
218
|
-
None,
|
219
|
-
auto_render=False,
|
220
|
-
)
|
221
|
-
if context_enrichment_prompt is not None:
|
222
|
-
return context_enrichment_prompt
|
223
|
-
return llm_config.default_context_enrichment_prompt
|
zrb/task/llm/tool_wrapper.py
CHANGED
@@ -5,9 +5,12 @@ import typing
|
|
5
5
|
from collections.abc import Callable
|
6
6
|
from typing import TYPE_CHECKING
|
7
7
|
|
8
|
+
from zrb.config.config import CFG
|
8
9
|
from zrb.context.any_context import AnyContext
|
9
10
|
from zrb.task.llm.error import ToolExecutionError
|
11
|
+
from zrb.util.callable import get_callable_name
|
10
12
|
from zrb.util.run import run_async
|
13
|
+
from zrb.util.string.conversion import to_boolean
|
11
14
|
|
12
15
|
if TYPE_CHECKING:
|
13
16
|
from pydantic_ai import Tool
|
@@ -18,16 +21,19 @@ def wrap_tool(func: Callable, ctx: AnyContext) -> "Tool":
|
|
18
21
|
from pydantic_ai import RunContext, Tool
|
19
22
|
|
20
23
|
original_sig = inspect.signature(func)
|
21
|
-
# Use helper function for clarity
|
22
24
|
needs_run_context_for_pydantic = _has_context_parameter(original_sig, RunContext)
|
25
|
+
wrapper = wrap_func(func, ctx)
|
26
|
+
return Tool(wrapper, takes_ctx=needs_run_context_for_pydantic)
|
27
|
+
|
28
|
+
|
29
|
+
def wrap_func(func: Callable, ctx: AnyContext) -> Callable:
|
30
|
+
original_sig = inspect.signature(func)
|
23
31
|
needs_any_context_for_injection = _has_context_parameter(original_sig, AnyContext)
|
24
32
|
takes_no_args = len(original_sig.parameters) == 0
|
25
33
|
# Pass individual flags to the wrapper creator
|
26
34
|
wrapper = _create_wrapper(func, original_sig, ctx, needs_any_context_for_injection)
|
27
|
-
# Adjust signature - _adjust_signature determines exclusions based on type
|
28
35
|
_adjust_signature(wrapper, original_sig, takes_no_args)
|
29
|
-
|
30
|
-
return Tool(wrapper, takes_ctx=needs_run_context_for_pydantic)
|
36
|
+
return wrapper
|
31
37
|
|
32
38
|
|
33
39
|
def _has_context_parameter(original_sig: inspect.Signature, context_type: type) -> bool:
|
@@ -71,13 +77,11 @@ def _create_wrapper(
|
|
71
77
|
async def wrapper(*args, **kwargs):
|
72
78
|
# Identify AnyContext parameter name from the original signature if needed
|
73
79
|
any_context_param_name = None
|
74
|
-
|
75
80
|
if needs_any_context_for_injection:
|
76
81
|
for param in original_sig.parameters.values():
|
77
82
|
if _is_annotated_with_context(param.annotation, AnyContext):
|
78
83
|
any_context_param_name = param.name
|
79
84
|
break # Found it, no need to continue
|
80
|
-
|
81
85
|
if any_context_param_name is None:
|
82
86
|
# This should not happen if needs_any_context_for_injection is True,
|
83
87
|
# but check for safety
|
@@ -87,24 +91,25 @@ def _create_wrapper(
|
|
87
91
|
# Inject the captured ctx into kwargs. This will overwrite if the LLM
|
88
92
|
# somehow provided it.
|
89
93
|
kwargs[any_context_param_name] = ctx
|
90
|
-
|
91
94
|
# If the dummy argument was added for schema generation and is present in kwargs,
|
92
95
|
# remove it before calling the original function, unless the original function
|
93
96
|
# actually expects a parameter named '_dummy'.
|
94
97
|
if "_dummy" in kwargs and "_dummy" not in original_sig.parameters:
|
95
98
|
del kwargs["_dummy"]
|
96
|
-
|
97
99
|
try:
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
100
|
+
if not CFG.LLM_YOLO_MODE and not ctx.is_web_mode and ctx.is_tty:
|
101
|
+
func_name = get_callable_name(func)
|
102
|
+
ctx.print(f"✅ >> Allow to run tool: {func_name} (Y/n)", plain=True)
|
103
|
+
user_confirmation_str = await _read_line()
|
104
|
+
try:
|
105
|
+
user_confirmation = to_boolean(user_confirmation_str)
|
106
|
+
except Exception:
|
107
|
+
user_confirmation = False
|
108
|
+
if not user_confirmation:
|
109
|
+
ctx.print(f"❌ >> Rejecting {func_name} call. Why?", plain=True)
|
110
|
+
reason = await _read_line()
|
111
|
+
ctx.print("", plain=True)
|
112
|
+
raise ValueError(f"User disapproval: {reason}")
|
108
113
|
return await run_async(func(*args, **kwargs))
|
109
114
|
except Exception as e:
|
110
115
|
error_model = ToolExecutionError(
|
@@ -118,6 +123,13 @@ def _create_wrapper(
|
|
118
123
|
return wrapper
|
119
124
|
|
120
125
|
|
126
|
+
async def _read_line():
|
127
|
+
from prompt_toolkit import PromptSession
|
128
|
+
|
129
|
+
reader = PromptSession()
|
130
|
+
return await reader.prompt_async()
|
131
|
+
|
132
|
+
|
121
133
|
def _adjust_signature(
|
122
134
|
wrapper: Callable, original_sig: inspect.Signature, takes_no_args: bool
|
123
135
|
):
|