zrb 1.21.28__py3-none-any.whl → 1.21.37__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of zrb might be problematic.
- zrb/builtin/llm/chat_completion.py +199 -222
- zrb/builtin/llm/chat_session.py +1 -1
- zrb/builtin/llm/chat_session_cmd.py +28 -11
- zrb/builtin/llm/chat_trigger.py +3 -4
- zrb/builtin/llm/tool/cli.py +45 -14
- zrb/builtin/llm/tool/code.py +5 -1
- zrb/builtin/llm/tool/file.py +17 -0
- zrb/builtin/llm/tool/note.py +7 -7
- zrb/builtin/llm/tool/search/__init__.py +1 -0
- zrb/builtin/llm/tool/search/brave.py +66 -0
- zrb/builtin/llm/tool/search/searxng.py +61 -0
- zrb/builtin/llm/tool/search/serpapi.py +61 -0
- zrb/builtin/llm/tool/sub_agent.py +4 -1
- zrb/builtin/llm/tool/web.py +17 -72
- zrb/cmd/cmd_result.py +2 -1
- zrb/config/config.py +5 -1
- zrb/config/default_prompt/interactive_system_prompt.md +15 -12
- zrb/config/default_prompt/system_prompt.md +16 -18
- zrb/config/llm_rate_limitter.py +36 -13
- zrb/input/option_input.py +30 -2
- zrb/task/cmd_task.py +3 -0
- zrb/task/llm/agent_runner.py +6 -2
- zrb/task/llm/history_list.py +13 -0
- zrb/task/llm/history_processor.py +4 -13
- zrb/task/llm/print_node.py +64 -23
- zrb/task/llm/tool_confirmation_completer.py +41 -0
- zrb/task/llm/tool_wrapper.py +6 -4
- zrb/task/llm/workflow.py +41 -14
- zrb/task/llm_task.py +4 -5
- zrb/task/rsync_task.py +2 -0
- zrb/util/cmd/command.py +33 -10
- zrb/util/match.py +71 -0
- {zrb-1.21.28.dist-info → zrb-1.21.37.dist-info}/METADATA +1 -1
- {zrb-1.21.28.dist-info → zrb-1.21.37.dist-info}/RECORD +36 -29
- {zrb-1.21.28.dist-info → zrb-1.21.37.dist-info}/WHEEL +0 -0
- {zrb-1.21.28.dist-info → zrb-1.21.37.dist-info}/entry_points.txt +0 -0
zrb/config/llm_rate_limitter.py
CHANGED
```diff
@@ -129,56 +129,79 @@ class LLMRateLimitter:
     async def throttle(
         self,
         prompt: Any,
-        throttle_notif_callback: Callable[
+        throttle_notif_callback: Callable[..., Any] | None = None,
     ):
         now = time.time()
         str_prompt = self._prompt_to_str(prompt)
-
+        new_requested_tokens = self.count_token(str_prompt)
         # Clean up old entries
         while self.request_times and now - self.request_times[0] > 60:
             self.request_times.popleft()
         while self.token_times and now - self.token_times[0][0] > 60:
             self.token_times.popleft()
         # Check per-request token limit
-        if
+        if new_requested_tokens > self.max_tokens_per_request:
             raise ValueError(
                 (
-                    "
-                    f"({
+                    "New request exceeds max_tokens_per_request "
+                    f"({new_requested_tokens} > {self.max_tokens_per_request})."
                 )
             )
-        if
+        if new_requested_tokens > self.max_tokens_per_minute:
             raise ValueError(
                 (
-                    "
-                    f"({
+                    "New request exceeds max_tokens_per_minute "
+                    f"({new_requested_tokens} > {self.max_tokens_per_minute})."
                 )
             )
         # Wait if over per-minute request or token limit
+        callback_new_line = True
+        ever_throttled = False
         while (
             len(self.request_times) >= self.max_requests_per_minute
-            or sum(t for _, t in self.token_times) +
+            or sum(t for _, t in self.token_times) + new_requested_tokens
+            > self.max_tokens_per_minute
         ):
+            ever_throttled = True
             if throttle_notif_callback is not None:
                 if len(self.request_times) >= self.max_requests_per_minute:
+                    limit = self.max_requests_per_minute
                     rpm = len(self.request_times)
+                    wait_time = max(0, 60 - (now - self.request_times[0]))
                     throttle_notif_callback(
-                        f"Max request per minute exceeded: {rpm} of {
+                        f"Max request per minute exceeded: {rpm} of {limit}. "
+                        f"Waiting for {wait_time:.2f} seconds.",
+                        new_line=callback_new_line,
                     )
                 else:
-
+                    limit = self.max_tokens_per_minute
+                    current_tokens = sum(t for _, t in self.token_times)
+                    tpm = current_tokens + new_requested_tokens
+                    needed = tpm - self.max_tokens_per_minute
+                    freed = 0
+                    wait_time = 0
+                    for t_time, t_count in self.token_times:
+                        freed += t_count
+                        if freed >= needed:
+                            wait_time = max(0, 60 - (now - t_time))
+                            break
                     throttle_notif_callback(
-                        f"Max token per minute exceeded: {tpm} of {
+                        f"Max token per minute exceeded: {tpm} of {limit}. "
+                        f"Waiting for {wait_time:.2f} seconds.",
+                        new_line=callback_new_line,
                    )
+                callback_new_line = False
             await asyncio.sleep(self.throttle_sleep)
             now = time.time()
             while self.request_times and now - self.request_times[0] > 60:
                 self.request_times.popleft()
             while self.token_times and now - self.token_times[0][0] > 60:
                 self.token_times.popleft()
+        if ever_throttled and throttle_notif_callback is not None:
+            throttle_notif_callback("", new_line=True)
         # Record this request
         self.request_times.append(now)
-        self.token_times.append((now,
+        self.token_times.append((now, new_requested_tokens))
 
     def _prompt_to_str(self, prompt: Any) -> str:
         try:
```
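The rewritten `throttle` names the pending request's token count (`new_requested_tokens`), remembers whether it ever had to wait (`ever_throttled`, used to emit a final newline), and now tells the callback how long the wait will be: it walks the sliding window of `(timestamp, token_count)` pairs until enough tokens would age out of the 60-second window. A standalone sketch of that wait-time estimate with illustrative numbers (the deque and limits here are stand-ins, not the real class state):

```python
import time
from collections import deque

# Stand-in sliding window of (timestamp, token_count) pairs, oldest first.
token_times = deque([(time.time() - 50, 4000), (time.time() - 20, 3000)])
max_tokens_per_minute = 8000
new_requested_tokens = 2000

now = time.time()
tpm = sum(t for _, t in token_times) + new_requested_tokens  # 9000
needed = tpm - max_tokens_per_minute  # 1000 tokens must age out first
freed, wait_time = 0, 0.0
for t_time, t_count in token_times:
    freed += t_count
    if freed >= needed:
        # An entry recorded at t_time leaves the window at t_time + 60.
        wait_time = max(0, 60 - (now - t_time))
        break
print(f"~{wait_time:.2f}s until {needed} tokens age out")  # ~10s here
```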
zrb/input/option_input.py
CHANGED
```diff
@@ -1,7 +1,13 @@
+from typing import TYPE_CHECKING
+
 from zrb.attr.type import StrAttr, StrListAttr
 from zrb.context.any_shared_context import AnySharedContext
 from zrb.input.base_input import BaseInput
 from zrb.util.attr import get_str_list_attr
+from zrb.util.match import fuzzy_match
+
+if TYPE_CHECKING:
+    from prompt_toolkit.completion import Completer
 
 
 class OptionInput(BaseInput):
@@ -58,10 +64,32 @@ class OptionInput(BaseInput):
         self, shared_ctx: AnySharedContext, prompt_message: str, options: list[str]
     ) -> str:
         from prompt_toolkit import PromptSession
-        from prompt_toolkit.completion import WordCompleter
 
         if shared_ctx.is_tty:
             reader = PromptSession()
-            option_completer =
+            option_completer = self._get_option_completer(options)
             return reader.prompt(f"{prompt_message}: ", completer=option_completer)
         return input(f"{prompt_message}: ")
+
+    def _get_option_completer(self, options: list[str]) -> "Completer":
+        from prompt_toolkit.completion import CompleteEvent, Completer, Completion
+        from prompt_toolkit.document import Document
+
+        class OptionCompleter(Completer):
+            def __init__(self, options: list[str]):
+                self._options = options
+
+            def get_completions(
+                self, document: Document, complete_event: CompleteEvent
+            ):
+                search_pattern = document.get_word_before_cursor(WORD=True)
+                candidates = []
+                for option in self._options:
+                    matched, score = fuzzy_match(option, search_pattern)
+                    if matched:
+                        candidates.append((score, option))
+                candidates.sort(key=lambda x: (x[0], x[1]))
+                for _, option in candidates:
+                    yield Completion(option, start_position=-len(search_pattern))
+
+        return OptionCompleter(options)
```
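`WordCompleter` is replaced by a custom completer that scores every option with the new `zrb.util.match.fuzzy_match` helper (added in this release, +71 lines) and yields matches sorted by score, then alphabetically. The helper's implementation is not shown in this diff; from the call site it returns a `(matched, score)` tuple where lower scores rank first. A rough subsequence-based stand-in that satisfies the same contract (the scoring here is an assumption, not the shipped algorithm):

```python
def fuzzy_match(text: str, pattern: str) -> tuple[bool, int]:
    """Stand-in with the (matched, score) contract inferred from the call
    site; the real zrb.util.match implementation may score differently."""
    if pattern == "":
        return True, 0
    score, pos = 0, 0
    for char in pattern.lower():
        found = text.lower().find(char, pos)
        if found == -1:
            return False, 0
        score += found - pos  # penalize gaps between matched characters
        pos = found + 1
    return True, score

print(fuzzy_match("chat_session", "css"))  # (True, 5): subsequence match
print(fuzzy_match("chat_session", "zz"))   # (False, 0): no match
```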
zrb/task/cmd_task.py
CHANGED
```diff
@@ -48,6 +48,7 @@ class CmdTask(BaseTask):
         warn_unrecommended_command: bool | None = None,
         max_output_line: int = 1000,
         max_error_line: int = 1000,
+        execution_timeout: int = 3600,
         is_interactive: bool = False,
         execute_condition: BoolAttr = True,
         retries: int = 2,
@@ -103,6 +104,7 @@ class CmdTask(BaseTask):
         self._render_cwd = render_cwd
         self._max_output_line = max_output_line
         self._max_error_line = max_error_line
+        self._execution_timeout = execution_timeout
         self._should_plain_print = plain_print
         self._should_warn_unrecommended_command = warn_unrecommended_command
         self._is_interactive = is_interactive
@@ -142,6 +144,7 @@ class CmdTask(BaseTask):
             register_pid_method=lambda pid: ctx.xcom.get(xcom_pid_key).push(pid),
             max_output_line=self._max_output_line,
             max_error_line=self._max_error_line,
+            timeout=self._execution_timeout,
             is_interactive=self._is_interactive,
         )
         # Check for errors
```
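`execution_timeout` defaults to 3600 seconds and is forwarded to the command runner as `timeout`, so long-running commands are now cut off after an hour unless the caller raises the limit. A hypothetical task definition using the new parameter:

```python
from zrb import CmdTask, cli

# Hypothetical task: abort the dump if it runs longer than 10 minutes.
backup = cli.add_task(
    CmdTask(
        name="backup-db",
        cmd="pg_dump mydb > backup.sql",
        execution_timeout=600,  # seconds; new in this release, default 3600
    )
)
```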
zrb/task/llm/agent_runner.py
CHANGED
```diff
@@ -112,8 +112,12 @@ async def _run_single_agent_iteration(
 
 
 def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[str], None]:
-    def _print_throttle_notif(
-
+    def _print_throttle_notif(text: str, *args: Any, **kwargs: Any):
+        new_line = kwargs.get("new_line", True)
+        prefix = "\r" if not new_line else "\n"
+        if text != "":
+            prefix = f"{prefix} 🐢 Request Throttled: "
+        ctx.print(stylize_faint(f"{prefix}{text}"), plain=True, end="")
 
     return _print_throttle_notif
 
```
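The callback now accepts a `new_line` keyword: the first notification starts on a fresh line (`\n`), while follow-ups begin with `\r` so the throttle message redraws in place instead of scrolling one line per poll. The same terminal idiom in isolation (plain `print` instead of `ctx.print`/`stylize_faint`):

```python
import time

def print_throttle_notif(text: str, new_line: bool = True) -> None:
    prefix = "\n" if new_line else "\r"
    if text != "":
        prefix = f"{prefix} 🐢 Request Throttled: "
    print(f"{prefix}{text}", end="", flush=True)

for remaining in range(10, 0, -1):
    print_throttle_notif(
        f"Max token per minute exceeded: 9000 of 8000. "
        f"Waiting for {remaining:.2f} seconds.",
        new_line=(remaining == 10),  # only the first message opens a new line
    )
    time.sleep(1)
print_throttle_notif("", new_line=True)  # mirrors the ever_throttled cleanup
```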
zrb/task/llm/history_list.py
ADDED

```diff
@@ -0,0 +1,13 @@
+from typing import Any
+
+
+def remove_system_prompt_and_instruction(
+    history_list: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    for msg in history_list:
+        if msg.get("instructions"):
+            msg["instructions"] = ""
+        msg["parts"] = [
+            p for p in msg.get("parts", []) if p.get("part_kind") != "system-prompt"
+        ]
+    return history_list
```
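The new module exists so that stored conversation histories no longer carry the (potentially large) system prompt on every message. A quick illustration with a hand-built history shaped like pydantic-ai's serialized messages:

```python
from zrb.task.llm.history_list import remove_system_prompt_and_instruction

history = [
    {
        "instructions": "You are a helpful assistant.",
        "parts": [
            {"part_kind": "system-prompt", "content": "You are a helpful assistant."},
            {"part_kind": "user-prompt", "content": "hello"},
        ],
    }
]
cleaned = remove_system_prompt_and_instruction(history)
assert cleaned[0]["instructions"] == ""
assert all(p["part_kind"] != "system-prompt" for p in cleaned[0]["parts"])
```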
zrb/task/llm/history_processor.py
CHANGED

```diff
@@ -88,30 +88,21 @@ def create_summarize_history_processor(
         messages: list[ModelMessage],
     ) -> list[ModelMessage]:
         history_list = json.loads(ModelMessagesTypeAdapter.dump_json(messages))
-
+        history_list_str = json.dumps(history_list)
         # Estimate token usage
         # Note: Pydantic ai has run context parameter
         # (https://ai.pydantic.dev/message-history/#runcontext-parameter)
         # But we cannot use run_ctx.usage.total_tokens because total token keep increasing
         # even after summariztion.
-        estimated_token_usage = rate_limitter.count_token(
+        estimated_token_usage = rate_limitter.count_token(history_list_str)
         _print_request_info(
             ctx, estimated_token_usage, summarization_token_threshold, messages
         )
         if estimated_token_usage < summarization_token_threshold or len(messages) == 1:
             return messages
-
-            {
-                key: obj[key]
-                for key in obj
-                if index == len(history_list) - 1 or key != "instructions"
-            }
-            for index, obj in enumerate(history_list)
-        ]
-        history_json_str_without_instruction = json.dumps(
-            history_list_without_instruction
+        summarization_message = (
+            f"Summarize the following conversation: {history_list_str}"
         )
-        summarization_message = f"Summarize the following conversation: {history_json_str_without_instruction}"
         summarization_agent = Agent[None, ConversationSummary](
             model=summarization_model,
             output_type=save_conversation_summary,
```
zrb/task/llm/print_node.py
CHANGED
```diff
@@ -24,15 +24,19 @@ async def print_node(
     )
 
     meta = getattr(node, "id", None) or getattr(node, "request_id", None)
+    progress_char_list = ["|", "/", "-", "\\"]
+    progress_index = 0
     if Agent.is_user_prompt_node(node):
         print_func(_format_header("🔠 Receiving input...", log_indent_level))
-
+        return
+    if Agent.is_model_request_node(node):
         # A model request node => We can stream tokens from the model's request
         print_func(_format_header("🧠 Processing...", log_indent_level))
         # Reference: https://ai.pydantic.dev/agents/#streaming-all-events-and-output
         try:
             async with node.stream(agent_run.ctx) as request_stream:
                 is_streaming = False
+                is_tool_processing = False
                 async for event in request_stream:
                     if isinstance(event, PartStartEvent) and event.part:
                         if is_streaming:
@@ -40,29 +44,59 @@ async def print_node(
                         content = _get_event_part_content(event)
                         print_func(_format_content(content, log_indent_level), end="")
                         is_streaming = True
-
+                        is_tool_processing = False
+                        continue
+                    if isinstance(event, PartDeltaEvent):
                         if isinstance(event.delta, TextPartDelta):
                             content_delta = event.delta.content_delta
                             print_func(
                                 _format_stream_content(content_delta, log_indent_level),
                                 end="",
                             )
-
+                            is_tool_processing = False
+                            is_streaming = True
+                            continue
+                        if isinstance(event.delta, ThinkingPartDelta):
                             content_delta = event.delta.content_delta
                             print_func(
                                 _format_stream_content(content_delta, log_indent_level),
                                 end="",
                             )
-
-
-
-
+                            is_tool_processing = False
+                            is_streaming = True
+                            continue
+                        if isinstance(event.delta, ToolCallPartDelta):
+                            if CFG.LLM_SHOW_TOOL_CALL_PREPARATION:
+                                args_delta = event.delta.args_delta
+                                if isinstance(args_delta, dict):
+                                    args_delta = json.dumps(args_delta)
+                                print_func(
+                                    _format_stream_content(
+                                        args_delta, log_indent_level
+                                    ),
+                                    end="",
+                                )
+                                is_streaming = True
+                                is_tool_processing = True
+                                continue
+                            prefix = "\n" if not is_tool_processing else ""
+                            progress_char = progress_char_list[progress_index]
                             print_func(
-
+                                _format_content(
+                                    f"Preparing Tool Parameters... {progress_char}",
+                                    log_indent_level,
+                                    prefix=f"\r{prefix}",
+                                ),
                                 end="",
                             )
+                            progress_index += 1
+                            if progress_index >= len(progress_char_list):
+                                progress_index = 0
+                            is_tool_processing = True
+                            is_streaming = True
+                            continue
                         is_streaming = True
-
+                    if isinstance(event, FinalResultEvent) and event.tool_name:
                         if is_streaming:
                             print_func("")
                         tool_name = event.tool_name
@@ -72,6 +106,7 @@ async def print_node(
                             )
                         )
                         is_streaming = False
+                        is_tool_processing = False
                 if is_streaming:
                     print_func("")
         except UnexpectedModelBehavior as e:
@@ -85,7 +120,8 @@ async def print_node(
                     log_indent_level,
                 )
             )
-
+        return
+    if Agent.is_call_tools_node(node):
         # A handle-response node => The model returned some data, potentially calls a tool
         print_func(_format_header("🧰 Calling Tool...", log_indent_level))
         try:
@@ -100,7 +136,8 @@ async def print_node(
                                 f"{call_id} | Call {tool_name} {args}", log_indent_level
                             )
                         )
-
+                        continue
+                    if (
                         isinstance(event, FunctionToolResultEvent)
                         and event.tool_call_id
                     ):
@@ -113,12 +150,10 @@ async def print_node(
                                 log_indent_level,
                             )
                         )
-
-
-
-
-                        )
-                    )
+                        continue
+                    print_func(
+                        _format_content(f"{call_id} | Executed", log_indent_level)
+                    )
         except UnexpectedModelBehavior as e:
             print_func("")  # ensure newline consistency
             print_func(
@@ -130,9 +165,11 @@ async def print_node(
                     log_indent_level,
                 )
             )
-
+        return
+    if Agent.is_end_node(node):
         # Once an End node is reached, the agent run is complete
         print_func(_format_header("✅ Completed...", log_indent_level))
+        return
 
 
 def _format_header(text: str | None, log_indent_level: int = 0) -> str:
@@ -145,8 +182,10 @@ def _format_header(text: str | None, log_indent_level: int = 0) -> str:
     )
 
 
-def _format_content(
-
+def _format_content(
+    text: str | None, log_indent_level: int = 0, prefix: str = ""
+) -> str:
+    return prefix + _format(
         text,
         base_indent=2,
         first_indent=3,
@@ -155,8 +194,10 @@ def _format_content(text: str | None, log_indent_level: int = 0) -> str:
     )
 
 
-def _format_stream_content(
-
+def _format_stream_content(
+    text: str | None, log_indent_level: int = 0, prefix: str = ""
+) -> str:
+    return prefix + _format(
         text,
         base_indent=2,
         indent=3,
@@ -207,7 +248,7 @@ def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
     return {key: _truncate_arg(val) for key, val in kwargs.items()}
 
 
-def _truncate_arg(arg: str, length: int =
+def _truncate_arg(arg: str, length: int = 30) -> str:
     if isinstance(arg, str) and len(arg) > length:
         return f"{arg[:length-4]} ..."
     return arg
```
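Tool-call preparation now gets visible feedback: when `LLM_SHOW_TOOL_CALL_PREPARATION` is off, each `ToolCallPartDelta` advances a four-frame spinner redrawn with `\r` rather than printing raw argument deltas. The spinner idiom on its own (timings are illustrative):

```python
import itertools
import sys
import time

# "\r" returns the cursor to the start of the line, so each frame
# overwrites the previous one instead of scrolling.
for progress_char in itertools.islice(itertools.cycle("|/-\\"), 20):
    sys.stdout.write(f"\r   Preparing Tool Parameters... {progress_char}")
    sys.stdout.flush()
    time.sleep(0.1)
sys.stdout.write("\n")
```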
zrb/task/llm/tool_confirmation_completer.py
ADDED

```diff
@@ -0,0 +1,41 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from prompt_toolkit.completion import Completer
+
+
+def get_tool_confirmation_completer(
+    options: list[str], meta_dict: dict[str, str]
+) -> "Completer":
+    from prompt_toolkit.completion import Completer, Completion
+
+    class ToolConfirmationCompleter(Completer):
+        """Custom completer for tool confirmation that doesn't auto-complete partial words."""
+
+        def __init__(self, options, meta_dict):
+            self.options = options
+            self.meta_dict = meta_dict
+
+        def get_completions(self, document, complete_event):
+            text = document.text.strip()
+            # 1. Input is empty, OR
+            # 2. Input exactly matches the beginning of an option
+            if text == "":
+                # Show all options when nothing is typed
+                for option in self.options:
+                    yield Completion(
+                        option,
+                        start_position=0,
+                        display_meta=self.meta_dict.get(option, ""),
+                    )
+                return
+            # Only complete if text exactly matches the beginning of an option
+            for option in self.options:
+                if option.startswith(text):
+                    yield Completion(
+                        option,
+                        start_position=-len(text),
+                        display_meta=self.meta_dict.get(option, ""),
+                    )
+
+    return ToolConfirmationCompleter(options, meta_dict)
```
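Unlike `WordCompleter`, this completer only suggests when the input is empty or is an exact prefix of an option, so a half-typed free-text answer is never hijacked by autocompletion. A hedged usage sketch (the real call site uses `prompt_async` inside `_read_line`):

```python
from prompt_toolkit import PromptSession

from zrb.task.llm.tool_confirmation_completer import get_tool_confirmation_completer

options = ["yes", "no", "edit"]
meta_dict = {"yes": "Run the tool", "no": "Reject the call", "edit": "Edit parameters"}
completer = get_tool_confirmation_completer(options, meta_dict)
# An empty prompt lists all three options; typing "e" suggests only "edit".
answer = PromptSession().prompt("Approve tool call? ", completer=completer)
```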
zrb/task/llm/tool_wrapper.py
CHANGED
```diff
@@ -11,11 +11,11 @@ from zrb.config.llm_rate_limitter import llm_rate_limitter
 from zrb.context.any_context import AnyContext
 from zrb.task.llm.error import ToolExecutionError
 from zrb.task.llm.file_replacement import edit_replacement, is_single_path_replacement
+from zrb.task.llm.tool_confirmation_completer import get_tool_confirmation_completer
 from zrb.util.callable import get_callable_name
 from zrb.util.cli.markdown import render_markdown
 from zrb.util.cli.style import (
     stylize_blue,
-    stylize_error,
     stylize_faint,
     stylize_green,
     stylize_yellow,
@@ -136,9 +136,12 @@ def _create_wrapper(
             result = await run_async(func(*args, **kwargs))
             _check_tool_call_result_limit(result)
             if has_ever_edited:
+                serializable_kwargs = kwargs.copy()
+                if any_context_param_name is not None:
+                    serializable_kwargs.pop(any_context_param_name, None)
                 return {
                     "tool_call_result": result,
-                    "new_tool_parameters":
+                    "new_tool_parameters": serializable_kwargs,
                     "message": "User correction: Tool was called with user's parameters",
                 }
             return result
@@ -296,7 +299,6 @@ def _truncate_arg(arg: str, length: int = 19) -> str:
 
 async def _read_line(args: list[Any] | tuple[Any], kwargs: dict[str, Any]):
     from prompt_toolkit import PromptSession
-    from prompt_toolkit.completion import WordCompleter
 
     options = ["yes", "no", "edit"]
     meta_dict = {
@@ -307,7 +309,7 @@ async def _read_line(args: list[Any] | tuple[Any], kwargs: dict[str, Any]):
     for key in kwargs:
         options.append(f"edit {key}")
         meta_dict[f"edit {key}"] = f"Edit tool execution parameter: {key}"
-    completer =
+    completer = get_tool_confirmation_completer(options, meta_dict)
     reader = PromptSession()
     return await reader.prompt_async(completer=completer)
 
```
zrb/task/llm/workflow.py
CHANGED
```diff
@@ -44,26 +44,40 @@ def get_available_workflows() -> dict[str, LLMWorkflow]:
         workflow_name.strip().lower(): workflow
         for workflow_name, workflow in llm_context_config.get_workflows().items()
     }
-
-
-
-
-
-
-
-
+    workflow_hidden_folder = f".{CFG.ROOT_GROUP_NAME}"
+    # Define workflow locations in order of precedence
+    default_workflow_locations = (
+        [
+            # Project specific + user specific workflows
+            os.path.join(
+                os.path.dirname(__file__), workflow_hidden_folder, "workflows"
+            ),
+            os.path.join(os.path.dirname(__file__), workflow_hidden_folder, "skills"),
+            os.path.join(os.path.dirname(__file__), ".claude", "skills"),
+            os.path.join(os.path.expanduser("~"), workflow_hidden_folder, "workflows"),
+            os.path.join(os.path.expanduser("~"), workflow_hidden_folder, "skills"),
+            os.path.join(os.path.expanduser("~"), ".claude", "skills"),
+        ]
+        + [
+            # User defined builtin workflows
+            os.path.expanduser(additional_builtin_workflow_path)
+            for additional_builtin_workflow_path in CFG.LLM_BUILTIN_WORKFLOW_PATHS
+            if os.path.isdir(os.path.expanduser(additional_builtin_workflow_path))
+        ]
+        + [
+            # Zrb builtin workflows
+            os.path.join(os.path.dirname(__file__), "default_workflow"),
+        ]
     )
     # Load workflows from all locations
-    for workflow_location in
+    for workflow_location in default_workflow_locations:
         if not os.path.isdir(workflow_location):
             continue
         for workflow_name in os.listdir(workflow_location):
             workflow_dir = os.path.join(workflow_location, workflow_name)
-            workflow_file =
-            if not
-
-            if not os.path.isfile(path=workflow_file):
-                continue
+            workflow_file = _get_workflow_file_name(workflow_dir)
+            if not workflow_file:
+                continue
             # Only add if not already defined (earlier locations have precedence)
             if workflow_name not in available_workflows:
                 with open(workflow_file, "r") as f:
@@ -74,3 +88,16 @@ def get_available_workflows() -> dict[str, LLMWorkflow]:
             content=workflow_content,
         )
     return available_workflows
+
+
+def _get_workflow_file_name(workflow_dir: str) -> str | None:
+    workflow_file = os.path.join(workflow_dir, "workflow.md")
+    if os.path.isfile(workflow_file):
+        return workflow_file
+    workflow_file = os.path.join(workflow_dir, "WORKFLOW.md")
+    if os.path.isfile(workflow_file):
+        return workflow_file
+    workflow_file = os.path.join(workflow_dir, "SKILL.md")
+    if os.path.isfile(workflow_file):
+        return workflow_file
+    return None
```
zrb/task/llm_task.py
CHANGED
```diff
@@ -22,6 +22,7 @@ from zrb.task.llm.conversation_history import (
     write_conversation_history,
 )
 from zrb.task.llm.conversation_history_model import ConversationHistory
+from zrb.task.llm.history_list import remove_system_prompt_and_instruction
 from zrb.task.llm.history_summarization import get_history_summarization_token_threshold
 from zrb.task.llm.prompt import (
     get_attachments,
@@ -134,9 +135,6 @@ class LLMTask(BaseTask):
         upstream: list[AnyTask] | AnyTask | None = None,
         fallback: list[AnyTask] | AnyTask | None = None,
         successor: list[AnyTask] | AnyTask | None = None,
-        conversation_context: (
-            dict[str, Any] | Callable[[AnyContext], dict[str, Any]] | None
-        ) = None,
     ):
         super().__init__(
             name=name,
@@ -204,7 +202,6 @@ class LLMTask(BaseTask):
             render_history_summarization_token_threshold
         )
         self._max_call_iteration = max_call_iteration
-        self._conversation_context = conversation_context
         self._yolo_mode = yolo_mode
         self._render_yolo_mode = render_yolo_mode
         self._attachment = attachment
@@ -362,7 +359,9 @@ class LLMTask(BaseTask):
             rate_limitter=self._rate_limitter,
         )
         if agent_run and agent_run.result:
-            new_history_list =
+            new_history_list = remove_system_prompt_and_instruction(
+                json.loads(agent_run.result.all_messages_json())
+            )
             conversation_history.history = new_history_list
             xcom_usage_key = f"{self.name}-usage"
             if xcom_usage_key not in ctx.xcom:
```
zrb/task/rsync_task.py
CHANGED
```diff
@@ -44,6 +44,7 @@ class RsyncTask(CmdTask):
         plain_print: bool = False,
         max_output_line: int = 1000,
         max_error_line: int = 1000,
+        execution_timeout: int = 3600,
         execute_condition: BoolAttr = True,
         retries: int = 2,
         retry_period: float = 0,
@@ -77,6 +78,7 @@ class RsyncTask(CmdTask):
             plain_print=plain_print,
             max_output_line=max_output_line,
             max_error_line=max_error_line,
+            execution_timeout=execution_timeout,
             execute_condition=execute_condition,
             retries=retries,
             retry_period=retry_period,
```