zrb 1.15.3__py3-none-any.whl → 1.21.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic. Click here for more details.
- zrb/__init__.py +2 -6
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +2 -0
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_completion.py +274 -0
- zrb/builtin/llm/chat_session.py +126 -167
- zrb/builtin/llm/chat_session_cmd.py +288 -0
- zrb/builtin/llm/chat_trigger.py +79 -0
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/llm_ask.py +217 -135
- zrb/builtin/llm/tool/api.py +74 -70
- zrb/builtin/llm/tool/cli.py +35 -21
- zrb/builtin/llm/tool/code.py +55 -73
- zrb/builtin/llm/tool/file.py +278 -344
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +27 -34
- zrb/builtin/llm/tool/sub_agent.py +54 -41
- zrb/builtin/llm/tool/web.py +74 -98
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/config/config.py +202 -27
- zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
- zrb/config/default_prompt/interactive_system_prompt.md +24 -30
- zrb/config/default_prompt/persona.md +1 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
- zrb/config/default_prompt/summarization_prompt.md +57 -16
- zrb/config/default_prompt/system_prompt.md +36 -30
- zrb/config/llm_config.py +119 -23
- zrb/config/llm_context/config.py +127 -90
- zrb/config/llm_context/config_parser.py +1 -7
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +100 -47
- zrb/context/any_shared_context.py +7 -1
- zrb/context/context.py +8 -2
- zrb/context/shared_context.py +3 -7
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +3 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +13 -1
- zrb/input/text_input.py +7 -24
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -6
- zrb/session/session.py +39 -18
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/llm/agent.py +128 -167
- zrb/task/llm/agent_runner.py +152 -0
- zrb/task/llm/config.py +39 -20
- zrb/task/llm/conversation_history.py +110 -29
- zrb/task/llm/conversation_history_model.py +4 -179
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_processor.py +206 -0
- zrb/task/llm/history_summarization.py +2 -193
- zrb/task/llm/print_node.py +184 -64
- zrb/task/llm/prompt.py +175 -179
- zrb/task/llm/subagent_conversation_history.py +41 -0
- zrb/task/llm/tool_wrapper.py +226 -85
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +109 -71
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +25 -10
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +54 -39
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +12 -3
- zrb/util/git.py +2 -2
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/METADATA +38 -18
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/RECORD +105 -79
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
zrb/task/llm/tool_wrapper.py
CHANGED
|
@@ -1,45 +1,60 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import inspect
|
|
3
|
+
import signal
|
|
3
4
|
import traceback
|
|
4
5
|
import typing
|
|
5
6
|
from collections.abc import Callable
|
|
6
7
|
from typing import TYPE_CHECKING, Any
|
|
7
8
|
|
|
9
|
+
from zrb.config.config import CFG
|
|
10
|
+
from zrb.config.llm_rate_limitter import llm_rate_limitter
|
|
8
11
|
from zrb.context.any_context import AnyContext
|
|
9
12
|
from zrb.task.llm.error import ToolExecutionError
|
|
13
|
+
from zrb.task.llm.file_replacement import edit_replacement, is_single_path_replacement
|
|
10
14
|
from zrb.util.callable import get_callable_name
|
|
15
|
+
from zrb.util.cli.markdown import render_markdown
|
|
11
16
|
from zrb.util.cli.style import (
|
|
12
17
|
stylize_blue,
|
|
13
18
|
stylize_error,
|
|
19
|
+
stylize_faint,
|
|
14
20
|
stylize_green,
|
|
15
21
|
stylize_yellow,
|
|
16
22
|
)
|
|
23
|
+
from zrb.util.cli.text import edit_text
|
|
17
24
|
from zrb.util.run import run_async
|
|
18
25
|
from zrb.util.string.conversion import to_boolean
|
|
26
|
+
from zrb.util.yaml import edit_obj, yaml_dump
|
|
19
27
|
|
|
20
28
|
if TYPE_CHECKING:
|
|
21
29
|
from pydantic_ai import Tool
|
|
22
30
|
|
|
23
31
|
|
|
24
|
-
|
|
32
|
+
class ToolExecutionCancelled(ValueError):
|
|
33
|
+
pass
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def wrap_tool(func: Callable, ctx: AnyContext, yolo_mode: bool | list[str]) -> "Tool":
|
|
25
37
|
"""Wraps a tool function to handle exceptions and context propagation."""
|
|
26
38
|
from pydantic_ai import RunContext, Tool
|
|
27
39
|
|
|
28
40
|
original_sig = inspect.signature(func)
|
|
29
41
|
needs_run_context_for_pydantic = _has_context_parameter(original_sig, RunContext)
|
|
30
|
-
wrapper = wrap_func(func, ctx,
|
|
42
|
+
wrapper = wrap_func(func, ctx, yolo_mode)
|
|
31
43
|
return Tool(wrapper, takes_ctx=needs_run_context_for_pydantic)
|
|
32
44
|
|
|
33
45
|
|
|
34
|
-
def wrap_func(func: Callable, ctx: AnyContext,
|
|
46
|
+
def wrap_func(func: Callable, ctx: AnyContext, yolo_mode: bool | list[str]) -> Callable:
|
|
35
47
|
original_sig = inspect.signature(func)
|
|
36
48
|
needs_any_context_for_injection = _has_context_parameter(original_sig, AnyContext)
|
|
37
|
-
takes_no_args = len(original_sig.parameters) == 0
|
|
38
49
|
# Pass individual flags to the wrapper creator
|
|
39
50
|
wrapper = _create_wrapper(
|
|
40
|
-
func,
|
|
51
|
+
func=func,
|
|
52
|
+
original_signature=original_sig,
|
|
53
|
+
ctx=ctx,
|
|
54
|
+
needs_any_context_for_injection=needs_any_context_for_injection,
|
|
55
|
+
yolo_mode=yolo_mode,
|
|
41
56
|
)
|
|
42
|
-
_adjust_signature(wrapper, original_sig
|
|
57
|
+
_adjust_signature(wrapper, original_sig)
|
|
43
58
|
return wrapper
|
|
44
59
|
|
|
45
60
|
|
|
@@ -74,10 +89,10 @@ def _is_annotated_with_context(param_annotation, context_type):
|
|
|
74
89
|
|
|
75
90
|
def _create_wrapper(
|
|
76
91
|
func: Callable,
|
|
77
|
-
|
|
92
|
+
original_signature: inspect.Signature,
|
|
78
93
|
ctx: AnyContext,
|
|
79
94
|
needs_any_context_for_injection: bool,
|
|
80
|
-
|
|
95
|
+
yolo_mode: bool | list[str],
|
|
81
96
|
) -> Callable:
|
|
82
97
|
"""Creates the core wrapper function."""
|
|
83
98
|
|
|
@@ -86,7 +101,7 @@ def _create_wrapper(
|
|
|
86
101
|
# Identify AnyContext parameter name from the original signature if needed
|
|
87
102
|
any_context_param_name = None
|
|
88
103
|
if needs_any_context_for_injection:
|
|
89
|
-
for param in
|
|
104
|
+
for param in original_signature.parameters.values():
|
|
90
105
|
if _is_annotated_with_context(param.annotation, AnyContext):
|
|
91
106
|
any_context_param_name = param.name
|
|
92
107
|
break # Found it, no need to continue
|
|
@@ -99,32 +114,167 @@ def _create_wrapper(
|
|
|
99
114
|
# Inject the captured ctx into kwargs. This will overwrite if the LLM
|
|
100
115
|
# somehow provided it.
|
|
101
116
|
kwargs[any_context_param_name] = ctx
|
|
102
|
-
#
|
|
103
|
-
#
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
del kwargs["_dummy"]
|
|
117
|
+
# We will need to overwrite SIGINT handler, so that when user press ctrl + c,
|
|
118
|
+
# the program won't immediately exit
|
|
119
|
+
original_sigint_handler = signal.getsignal(signal.SIGINT)
|
|
120
|
+
tool_name = get_callable_name(func)
|
|
107
121
|
try:
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
if
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
122
|
+
has_ever_edited = False
|
|
123
|
+
if not ctx.is_web_mode and ctx.is_tty:
|
|
124
|
+
if (
|
|
125
|
+
isinstance(yolo_mode, list) and func.__name__ not in yolo_mode
|
|
126
|
+
) or not yolo_mode:
|
|
127
|
+
approval, reason, kwargs, has_ever_edited = (
|
|
128
|
+
await _handle_user_response(ctx, func, args, kwargs)
|
|
129
|
+
)
|
|
130
|
+
if not approval:
|
|
131
|
+
raise ToolExecutionCancelled(
|
|
132
|
+
f"Tool execution cancelled. User disapproving: {reason}"
|
|
133
|
+
)
|
|
134
|
+
signal.signal(signal.SIGINT, _tool_wrapper_sigint_handler)
|
|
135
|
+
ctx.print(stylize_faint(f"Run {tool_name}"), plain=True)
|
|
136
|
+
result = await run_async(func(*args, **kwargs))
|
|
137
|
+
_check_tool_call_result_limit(result)
|
|
138
|
+
if has_ever_edited:
|
|
139
|
+
return {
|
|
140
|
+
"tool_call_result": result,
|
|
141
|
+
"new_tool_parameters": kwargs,
|
|
142
|
+
"message": "User correction: Tool was called with user's parameters",
|
|
143
|
+
}
|
|
144
|
+
return result
|
|
145
|
+
except BaseException as e:
|
|
114
146
|
error_model = ToolExecutionError(
|
|
115
|
-
tool_name=
|
|
147
|
+
tool_name=tool_name,
|
|
116
148
|
error_type=type(e).__name__,
|
|
117
149
|
message=str(e),
|
|
118
150
|
details=traceback.format_exc(),
|
|
119
151
|
)
|
|
120
152
|
return error_model.model_dump_json()
|
|
153
|
+
finally:
|
|
154
|
+
signal.signal(signal.SIGINT, original_sigint_handler)
|
|
121
155
|
|
|
122
156
|
return wrapper
|
|
123
157
|
|
|
124
158
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
159
|
+
def _tool_wrapper_sigint_handler(signum, frame):
|
|
160
|
+
raise KeyboardInterrupt("SIGINT detected while running tool")
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _check_tool_call_result_limit(result: Any):
|
|
164
|
+
if (
|
|
165
|
+
llm_rate_limitter.count_token(result)
|
|
166
|
+
> llm_rate_limitter.max_tokens_per_tool_call_result
|
|
167
|
+
):
|
|
168
|
+
raise ValueError("Result value is too large, please adjust the parameter")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
async def _handle_user_response(
|
|
172
|
+
ctx: AnyContext,
|
|
173
|
+
func: Callable,
|
|
174
|
+
args: list[Any] | tuple[Any],
|
|
175
|
+
kwargs: dict[str, Any],
|
|
176
|
+
) -> tuple[bool, str, dict[str, Any], bool]:
|
|
177
|
+
has_ever_edited = False
|
|
178
|
+
while True:
|
|
179
|
+
func_call_str = _get_func_call_str(func, args, kwargs)
|
|
180
|
+
complete_confirmation_message = "\n".join(
|
|
181
|
+
[
|
|
182
|
+
f"\n🎰 >> {func_call_str}",
|
|
183
|
+
_get_detail_func_param(args, kwargs),
|
|
184
|
+
f"🎰 >> {_get_run_func_confirmation(func)}",
|
|
185
|
+
]
|
|
186
|
+
)
|
|
187
|
+
ctx.print(complete_confirmation_message, plain=True)
|
|
188
|
+
user_response = await _read_line(args, kwargs)
|
|
189
|
+
ctx.print("", plain=True)
|
|
190
|
+
new_kwargs, is_edited = _get_edited_kwargs(ctx, user_response, kwargs)
|
|
191
|
+
if is_edited:
|
|
192
|
+
kwargs = new_kwargs
|
|
193
|
+
has_ever_edited = True
|
|
194
|
+
continue
|
|
195
|
+
approval_and_reason = _get_user_approval_and_reason(
|
|
196
|
+
ctx, user_response, func_call_str
|
|
197
|
+
)
|
|
198
|
+
if approval_and_reason is None:
|
|
199
|
+
continue
|
|
200
|
+
approval, reason = approval_and_reason
|
|
201
|
+
return approval, reason, kwargs, has_ever_edited
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _get_edited_kwargs(
|
|
205
|
+
ctx: AnyContext, user_response: str, kwargs: dict[str, Any]
|
|
206
|
+
) -> tuple[dict[str, Any], bool]:
|
|
207
|
+
user_edit_responses = [val for val in user_response.split(" ", maxsplit=2)]
|
|
208
|
+
if len(user_edit_responses) >= 1 and user_edit_responses[0].lower() != "edit":
|
|
209
|
+
return kwargs, False
|
|
210
|
+
while len(user_edit_responses) < 3:
|
|
211
|
+
user_edit_responses.append("")
|
|
212
|
+
key, val_str = user_edit_responses[1:]
|
|
213
|
+
# Make sure first segment of the key is in kwargs
|
|
214
|
+
if key != "":
|
|
215
|
+
key_parts = key.split(".")
|
|
216
|
+
if len(key_parts) > 0 and key_parts[0] not in kwargs:
|
|
217
|
+
return kwargs, True
|
|
218
|
+
# Handle replacement edit
|
|
219
|
+
if len(kwargs) == 1:
|
|
220
|
+
kwarg_key = list(kwargs.keys())[0]
|
|
221
|
+
if is_single_path_replacement(kwargs[kwarg_key]) and (
|
|
222
|
+
key == "" or key == kwarg_key
|
|
223
|
+
):
|
|
224
|
+
kwargs[kwarg_key], edited = edit_replacement(kwargs[kwarg_key])
|
|
225
|
+
return kwargs, True
|
|
226
|
+
# Handle other kind of edit
|
|
227
|
+
old_val_str = yaml_dump(kwargs, key)
|
|
228
|
+
if val_str == "":
|
|
229
|
+
val_str = edit_text(
|
|
230
|
+
prompt_message=f"# {key}" if key != "" else "",
|
|
231
|
+
value=old_val_str,
|
|
232
|
+
editor=CFG.DEFAULT_EDITOR,
|
|
233
|
+
extension=".yaml",
|
|
234
|
+
)
|
|
235
|
+
if old_val_str == val_str:
|
|
236
|
+
return kwargs, True
|
|
237
|
+
edited_kwargs = edit_obj(kwargs, key, val_str)
|
|
238
|
+
return edited_kwargs, True
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _get_user_approval_and_reason(
|
|
242
|
+
ctx: AnyContext, user_response: str, func_call_str: str
|
|
243
|
+
) -> tuple[bool, str] | None:
|
|
244
|
+
user_approval_responses = [
|
|
245
|
+
val.strip() for val in user_response.split(",", maxsplit=1)
|
|
246
|
+
]
|
|
247
|
+
while len(user_approval_responses) < 2:
|
|
248
|
+
user_approval_responses.append("")
|
|
249
|
+
approval_str, reason = user_approval_responses
|
|
250
|
+
try:
|
|
251
|
+
approved = True if approval_str.strip() == "" else to_boolean(approval_str)
|
|
252
|
+
if not approved and reason == "":
|
|
253
|
+
reason = "User disapproving the tool execution"
|
|
254
|
+
return approved, reason
|
|
255
|
+
except Exception:
|
|
256
|
+
return False, user_response
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _get_run_func_confirmation(func: Callable) -> str:
|
|
260
|
+
func_name = get_callable_name(func)
|
|
261
|
+
return render_markdown(
|
|
262
|
+
f"Allow to run `{func_name}`? (✅ `Yes` | ⛔ `No, <reason>` | 📝 `Edit <param> <value>`)"
|
|
263
|
+
).strip()
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _get_detail_func_param(args: list[Any] | tuple[Any], kwargs: dict[str, Any]) -> str:
|
|
267
|
+
if not kwargs:
|
|
268
|
+
return ""
|
|
269
|
+
yaml_str = yaml_dump(kwargs)
|
|
270
|
+
# Create the final markdown string
|
|
271
|
+
markdown = f"```yaml\n{yaml_str}\n```"
|
|
272
|
+
return render_markdown(markdown)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _get_func_call_str(
|
|
276
|
+
func: Callable, args: list[Any] | tuple[Any], kwargs: dict[str, Any]
|
|
277
|
+
) -> str:
|
|
128
278
|
func_name = get_callable_name(func)
|
|
129
279
|
normalized_args = [stylize_green(_truncate_arg(arg)) for arg in args]
|
|
130
280
|
normalized_kwargs = []
|
|
@@ -133,58 +283,67 @@ async def _ask_for_approval(
|
|
|
133
283
|
normalized_kwargs.append(
|
|
134
284
|
f"{stylize_yellow(key)}={stylize_green(truncated_val)}"
|
|
135
285
|
)
|
|
136
|
-
func_param_str = ",".join(normalized_args + normalized_kwargs)
|
|
137
|
-
|
|
138
|
-
f"{stylize_blue(func_name + '(')}{func_param_str}{stylize_blue(')')}"
|
|
139
|
-
)
|
|
140
|
-
while True:
|
|
141
|
-
ctx.print(
|
|
142
|
-
f"\n✅ >> Allow to run tool: {func_call_str} (Yes | No, <reason>) ",
|
|
143
|
-
plain=True,
|
|
144
|
-
)
|
|
145
|
-
user_input = await _read_line()
|
|
146
|
-
ctx.print("", plain=True)
|
|
147
|
-
user_responses = [val.strip() for val in user_input.split(",", maxsplit=1)]
|
|
148
|
-
while len(user_responses) < 2:
|
|
149
|
-
user_responses.append("")
|
|
150
|
-
approval_str, reason = user_responses
|
|
151
|
-
try:
|
|
152
|
-
approved = True if approval_str.strip() == "" else to_boolean(approval_str)
|
|
153
|
-
if not approved and reason == "":
|
|
154
|
-
ctx.print(
|
|
155
|
-
stylize_error(
|
|
156
|
-
f"You must specify rejection reason (i.e., No, <why>) for {func_call_str}" # noqa
|
|
157
|
-
),
|
|
158
|
-
plain=True,
|
|
159
|
-
)
|
|
160
|
-
continue
|
|
161
|
-
return approved, reason
|
|
162
|
-
except Exception:
|
|
163
|
-
ctx.print(
|
|
164
|
-
stylize_error(
|
|
165
|
-
f"Invalid approval value for {func_call_str}: {approval_str}"
|
|
166
|
-
),
|
|
167
|
-
plain=True,
|
|
168
|
-
)
|
|
169
|
-
continue
|
|
286
|
+
func_param_str = ", ".join(normalized_args + normalized_kwargs)
|
|
287
|
+
return f"{stylize_blue(func_name + '(')}{func_param_str}{stylize_blue(')')}"
|
|
170
288
|
|
|
171
289
|
|
|
172
290
|
def _truncate_arg(arg: str, length: int = 19) -> str:
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
291
|
+
normalized_arg = arg.replace("\n", "\\n")
|
|
292
|
+
if len(normalized_arg) > length:
|
|
293
|
+
return f"{normalized_arg[:length-4]} ..."
|
|
294
|
+
return normalized_arg
|
|
176
295
|
|
|
177
296
|
|
|
178
|
-
async def _read_line():
|
|
297
|
+
async def _read_line(args: list[Any] | tuple[Any], kwargs: dict[str, Any]):
|
|
179
298
|
from prompt_toolkit import PromptSession
|
|
299
|
+
from prompt_toolkit.completion import Completer, Completion
|
|
300
|
+
|
|
301
|
+
class ToolConfirmationCompleter(Completer):
|
|
302
|
+
"""Custom completer for tool confirmation that doesn't auto-complete partial words."""
|
|
303
|
+
|
|
304
|
+
def __init__(self, options, meta_dict):
|
|
305
|
+
self.options = options
|
|
306
|
+
self.meta_dict = meta_dict
|
|
307
|
+
|
|
308
|
+
def get_completions(self, document, complete_event):
|
|
309
|
+
text = document.text.strip()
|
|
310
|
+
|
|
311
|
+
# Only provide completions if:
|
|
312
|
+
# 1. Input is empty, OR
|
|
313
|
+
# 2. Input exactly matches the beginning of an option
|
|
314
|
+
if text == "":
|
|
315
|
+
# Show all options when nothing is typed
|
|
316
|
+
for option in self.options:
|
|
317
|
+
yield Completion(
|
|
318
|
+
option,
|
|
319
|
+
start_position=0,
|
|
320
|
+
display_meta=self.meta_dict.get(option, ""),
|
|
321
|
+
)
|
|
322
|
+
else:
|
|
323
|
+
# Only complete if text exactly matches the beginning of an option
|
|
324
|
+
for option in self.options:
|
|
325
|
+
if option.startswith(text):
|
|
326
|
+
yield Completion(
|
|
327
|
+
option,
|
|
328
|
+
start_position=-len(text),
|
|
329
|
+
display_meta=self.meta_dict.get(option, ""),
|
|
330
|
+
)
|
|
180
331
|
|
|
332
|
+
options = ["yes", "no", "edit"]
|
|
333
|
+
meta_dict = {
|
|
334
|
+
"yes": "Approve the execution",
|
|
335
|
+
"no": "Disapprove the execution",
|
|
336
|
+
"edit": "Edit tool execution parameters",
|
|
337
|
+
}
|
|
338
|
+
for key in kwargs:
|
|
339
|
+
options.append(f"edit {key}")
|
|
340
|
+
meta_dict[f"edit {key}"] = f"Edit tool execution parameter: {key}"
|
|
341
|
+
completer = ToolConfirmationCompleter(options, meta_dict)
|
|
181
342
|
reader = PromptSession()
|
|
182
|
-
return await reader.prompt_async()
|
|
343
|
+
return await reader.prompt_async(completer=completer)
|
|
183
344
|
|
|
184
345
|
|
|
185
|
-
def _adjust_signature(
|
|
186
|
-
wrapper: Callable, original_sig: inspect.Signature, takes_no_args: bool
|
|
187
|
-
):
|
|
346
|
+
def _adjust_signature(wrapper: Callable, original_sig: inspect.Signature):
|
|
188
347
|
"""Adjusts the wrapper function's signature for schema generation."""
|
|
189
348
|
# The wrapper's signature should represent the arguments the *LLM* needs to provide.
|
|
190
349
|
# The LLM does not provide RunContext (pydantic-ai injects it) or AnyContext
|
|
@@ -199,22 +358,4 @@ def _adjust_signature(
|
|
|
199
358
|
if not _is_annotated_with_context(param.annotation, RunContext)
|
|
200
359
|
and not _is_annotated_with_context(param.annotation, AnyContext)
|
|
201
360
|
]
|
|
202
|
-
|
|
203
|
-
# If after removing context parameters, there are no parameters left,
|
|
204
|
-
# and the original function took no args, keep the dummy.
|
|
205
|
-
# If after removing context parameters, there are no parameters left,
|
|
206
|
-
# but the original function *did* take args (only context), then the schema
|
|
207
|
-
# should have no parameters.
|
|
208
|
-
if not params_for_schema and takes_no_args:
|
|
209
|
-
# Keep the dummy if the original function truly had no parameters
|
|
210
|
-
new_sig = inspect.Signature(
|
|
211
|
-
parameters=[
|
|
212
|
-
inspect.Parameter(
|
|
213
|
-
"_dummy", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
|
|
214
|
-
)
|
|
215
|
-
]
|
|
216
|
-
)
|
|
217
|
-
else:
|
|
218
|
-
new_sig = inspect.Signature(parameters=params_for_schema)
|
|
219
|
-
|
|
220
|
-
wrapper.__signature__ = new_sig
|
|
361
|
+
wrapper.__signature__ = inspect.Signature(parameters=params_for_schema)
|
zrb/task/llm/workflow.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from zrb.config.config import CFG
|
|
4
|
+
from zrb.config.llm_context.config import llm_context_config
|
|
5
|
+
from zrb.config.llm_context.workflow import LLMWorkflow
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def load_workflow(workflow_name: str | list[str]) -> str:
|
|
9
|
+
"""
|
|
10
|
+
Loads and formats one or more workflow documents for LLM consumption.
|
|
11
|
+
|
|
12
|
+
Retrieves workflows by name, formats with descriptive headers for LLM context injection.
|
|
13
|
+
|
|
14
|
+
Args:
|
|
15
|
+
workflow_name: Name or list of names of the workflow(s) to load
|
|
16
|
+
|
|
17
|
+
Returns:
|
|
18
|
+
Formatted workflow content as a string with headers
|
|
19
|
+
|
|
20
|
+
Raises:
|
|
21
|
+
ValueError: If any specified workflow name is not found
|
|
22
|
+
"""
|
|
23
|
+
names = [workflow_name] if isinstance(workflow_name, str) else workflow_name
|
|
24
|
+
available_workflows = get_available_workflows()
|
|
25
|
+
contents = []
|
|
26
|
+
for name in names:
|
|
27
|
+
workflow = available_workflows.get(name.strip().lower())
|
|
28
|
+
if workflow is None:
|
|
29
|
+
raise ValueError(f"Workflow not found: {name}")
|
|
30
|
+
contents.append(
|
|
31
|
+
"\n".join(
|
|
32
|
+
[
|
|
33
|
+
f"# {workflow.name}",
|
|
34
|
+
f"> Workflow Location: `{workflow.path}`",
|
|
35
|
+
workflow.content,
|
|
36
|
+
]
|
|
37
|
+
)
|
|
38
|
+
)
|
|
39
|
+
return "\n".join(contents)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_available_workflows() -> dict[str, LLMWorkflow]:
|
|
43
|
+
available_workflows = {
|
|
44
|
+
workflow_name.strip().lower(): workflow
|
|
45
|
+
for workflow_name, workflow in llm_context_config.get_workflows().items()
|
|
46
|
+
}
|
|
47
|
+
# Define builtin workflow locations in order of precedence
|
|
48
|
+
builtin_workflow_locations = [
|
|
49
|
+
os.path.expanduser(additional_builtin_workflow_path)
|
|
50
|
+
for additional_builtin_workflow_path in CFG.LLM_BUILTIN_WORKFLOW_PATHS
|
|
51
|
+
if os.path.isdir(os.path.expanduser(additional_builtin_workflow_path))
|
|
52
|
+
]
|
|
53
|
+
builtin_workflow_locations.append(
|
|
54
|
+
os.path.join(os.path.dirname(__file__), "default_workflow")
|
|
55
|
+
)
|
|
56
|
+
# Load workflows from all locations
|
|
57
|
+
for workflow_location in builtin_workflow_locations:
|
|
58
|
+
if not os.path.isdir(workflow_location):
|
|
59
|
+
continue
|
|
60
|
+
for workflow_name in os.listdir(workflow_location):
|
|
61
|
+
workflow_dir = os.path.join(workflow_location, workflow_name)
|
|
62
|
+
workflow_file = os.path.join(workflow_dir, "workflow.md")
|
|
63
|
+
if not os.path.isfile(workflow_file):
|
|
64
|
+
workflow_file = os.path.join(workflow_dir, "SKILL.md")
|
|
65
|
+
if not os.path.isfile(path=workflow_file):
|
|
66
|
+
continue
|
|
67
|
+
# Only add if not already defined (earlier locations have precedence)
|
|
68
|
+
if workflow_name not in available_workflows:
|
|
69
|
+
with open(workflow_file, "r") as f:
|
|
70
|
+
workflow_content = f.read()
|
|
71
|
+
available_workflows[workflow_name] = LLMWorkflow(
|
|
72
|
+
name=workflow_name,
|
|
73
|
+
path=workflow_dir,
|
|
74
|
+
content=workflow_content,
|
|
75
|
+
)
|
|
76
|
+
return available_workflows
|