zrb 1.21.31__py3-none-any.whl → 1.21.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of zrb might be problematic.
- zrb/builtin/llm/chat_completion.py +48 -84
- zrb/builtin/llm/chat_session.py +1 -1
- zrb/builtin/llm/chat_session_cmd.py +28 -11
- zrb/builtin/llm/chat_trigger.py +1 -1
- zrb/builtin/llm/tool/cli.py +34 -15
- zrb/builtin/llm/tool/file.py +11 -0
- zrb/builtin/llm/tool/search/brave.py +6 -0
- zrb/builtin/llm/tool/search/searxng.py +6 -0
- zrb/builtin/llm/tool/search/serpapi.py +6 -0
- zrb/builtin/llm/tool/sub_agent.py +4 -1
- zrb/builtin/llm/tool/web.py +5 -0
- zrb/cmd/cmd_result.py +2 -1
- zrb/config/config.py +5 -1
- zrb/config/default_prompt/interactive_system_prompt.md +15 -12
- zrb/config/default_prompt/system_prompt.md +16 -18
- zrb/config/llm_rate_limitter.py +36 -13
- zrb/input/option_input.py +30 -2
- zrb/task/cmd_task.py +3 -0
- zrb/task/llm/agent_runner.py +6 -2
- zrb/task/llm/history_list.py +13 -0
- zrb/task/llm/history_processor.py +4 -13
- zrb/task/llm/print_node.py +64 -23
- zrb/task/llm/tool_wrapper.py +4 -1
- zrb/task/llm/workflow.py +41 -14
- zrb/task/llm_task.py +4 -5
- zrb/task/rsync_task.py +2 -0
- zrb/util/cmd/command.py +33 -10
- zrb/util/match.py +71 -0
- {zrb-1.21.31.dist-info → zrb-1.21.37.dist-info}/METADATA +1 -1
- {zrb-1.21.31.dist-info → zrb-1.21.37.dist-info}/RECORD +32 -30
- {zrb-1.21.31.dist-info → zrb-1.21.37.dist-info}/WHEEL +0 -0
- {zrb-1.21.31.dist-info → zrb-1.21.37.dist-info}/entry_points.txt +0 -0
zrb/config/llm_rate_limitter.py
CHANGED
@@ -129,56 +129,79 @@ class LLMRateLimitter:
     async def throttle(
         self,
         prompt: Any,
-        throttle_notif_callback: Callable[
+        throttle_notif_callback: Callable[..., Any] | None = None,
     ):
         now = time.time()
         str_prompt = self._prompt_to_str(prompt)
-
+        new_requested_tokens = self.count_token(str_prompt)
         # Clean up old entries
         while self.request_times and now - self.request_times[0] > 60:
             self.request_times.popleft()
         while self.token_times and now - self.token_times[0][0] > 60:
             self.token_times.popleft()
         # Check per-request token limit
-        if
+        if new_requested_tokens > self.max_tokens_per_request:
             raise ValueError(
                 (
-                    "
-                    f"({
+                    "New request exceeds max_tokens_per_request "
+                    f"({new_requested_tokens} > {self.max_tokens_per_request})."
                 )
             )
-        if
+        if new_requested_tokens > self.max_tokens_per_minute:
             raise ValueError(
                 (
-                    "
-                    f"({
+                    "New request exceeds max_tokens_per_minute "
+                    f"({new_requested_tokens} > {self.max_tokens_per_minute})."
                )
            )
         # Wait if over per-minute request or token limit
+        callback_new_line = True
+        ever_throttled = False
         while (
             len(self.request_times) >= self.max_requests_per_minute
-            or sum(t for _, t in self.token_times) +
+            or sum(t for _, t in self.token_times) + new_requested_tokens
+            > self.max_tokens_per_minute
         ):
+            ever_throttled = True
             if throttle_notif_callback is not None:
                 if len(self.request_times) >= self.max_requests_per_minute:
+                    limit = self.max_requests_per_minute
                     rpm = len(self.request_times)
+                    wait_time = max(0, 60 - (now - self.request_times[0]))
                     throttle_notif_callback(
-                        f"Max request per minute exceeded: {rpm} of {
+                        f"Max request per minute exceeded: {rpm} of {limit}. "
+                        f"Waiting for {wait_time:.2f} seconds.",
+                        new_line=callback_new_line,
                     )
                 else:
-
+                    limit = self.max_tokens_per_minute
+                    current_tokens = sum(t for _, t in self.token_times)
+                    tpm = current_tokens + new_requested_tokens
+                    needed = tpm - self.max_tokens_per_minute
+                    freed = 0
+                    wait_time = 0
+                    for t_time, t_count in self.token_times:
+                        freed += t_count
+                        if freed >= needed:
+                            wait_time = max(0, 60 - (now - t_time))
+                            break
                     throttle_notif_callback(
-                        f"Max token per minute exceeded: {tpm} of {
+                        f"Max token per minute exceeded: {tpm} of {limit}. "
+                        f"Waiting for {wait_time:.2f} seconds.",
+                        new_line=callback_new_line,
                     )
+            callback_new_line = False
             await asyncio.sleep(self.throttle_sleep)
             now = time.time()
             while self.request_times and now - self.request_times[0] > 60:
                 self.request_times.popleft()
             while self.token_times and now - self.token_times[0][0] > 60:
                 self.token_times.popleft()
+        if ever_throttled and throttle_notif_callback is not None:
+            throttle_notif_callback("", new_line=True)
         # Record this request
         self.request_times.append(now)
-        self.token_times.append((now,
+        self.token_times.append((now, new_requested_tokens))
 
     def _prompt_to_str(self, prompt: Any) -> str:
         try:
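To see how the reworked throttle() contract is driven from the caller's side, here is a minimal sketch. The throttle_notif_callback(text, new_line=...) shape comes straight from the diff; the limiter construction and the callback body are illustrative assumptions.

import asyncio

from zrb.config.llm_rate_limitter import LLMRateLimitter  # class name per the hunk header


def notify(text: str, new_line: bool = True) -> None:
    # new_line=False lets repeated throttle notices overwrite a single line via \r
    print(text, end="\n" if new_line else "\r", flush=True)


async def main() -> None:
    limiter = LLMRateLimitter()  # assumed default-constructible
    # Raises ValueError when the prompt alone exceeds the per-request or
    # per-minute token budget; otherwise sleeps until capacity frees up.
    await limiter.throttle("hello world", throttle_notif_callback=notify)


asyncio.run(main())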
zrb/input/option_input.py
CHANGED
@@ -1,7 +1,13 @@
+from typing import TYPE_CHECKING
+
 from zrb.attr.type import StrAttr, StrListAttr
 from zrb.context.any_shared_context import AnySharedContext
 from zrb.input.base_input import BaseInput
 from zrb.util.attr import get_str_list_attr
+from zrb.util.match import fuzzy_match
+
+if TYPE_CHECKING:
+    from prompt_toolkit.completion import Completer
 
 
 class OptionInput(BaseInput):
@@ -58,10 +64,32 @@ class OptionInput(BaseInput):
         self, shared_ctx: AnySharedContext, prompt_message: str, options: list[str]
     ) -> str:
         from prompt_toolkit import PromptSession
-        from prompt_toolkit.completion import WordCompleter
 
         if shared_ctx.is_tty:
             reader = PromptSession()
-            option_completer =
+            option_completer = self._get_option_completer(options)
             return reader.prompt(f"{prompt_message}: ", completer=option_completer)
         return input(f"{prompt_message}: ")
+
+    def _get_option_completer(self, options: list[str]) -> "Completer":
+        from prompt_toolkit.completion import CompleteEvent, Completer, Completion
+        from prompt_toolkit.document import Document
+
+        class OptionCompleter(Completer):
+            def __init__(self, options: list[str]):
+                self._options = options
+
+            def get_completions(
+                self, document: Document, complete_event: CompleteEvent
+            ):
+                search_pattern = document.get_word_before_cursor(WORD=True)
+                candidates = []
+                for option in self._options:
+                    matched, score = fuzzy_match(option, search_pattern)
+                    if matched:
+                        candidates.append((score, option))
+                candidates.sort(key=lambda x: (x[0], x[1]))
+                for _, option in candidates:
+                    yield Completion(option, start_position=-len(search_pattern))
+
+        return OptionCompleter(options)
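The new completer is the standard prompt_toolkit custom-Completer pattern, scored by zrb.util.match.fuzzy_match (a new module in this release whose body is not shown in this section). A self-contained sketch of the same pattern, with a stand-in subsequence scorer in place of the real fuzzy_match:

from prompt_toolkit import prompt
from prompt_toolkit.completion import Completer, Completion


def fuzzy_match(option: str, pattern: str) -> tuple[bool, int]:
    # Stand-in scorer: subsequence match, lower score = earlier/tighter match.
    # The real implementation lives in zrb/util/match.py and may differ.
    pos, score = 0, 0
    for char in pattern.lower():
        pos = option.lower().find(char, pos)
        if pos == -1:
            return False, 0
        score += pos
        pos += 1
    return True, score


class FuzzyOptionCompleter(Completer):
    def __init__(self, options: list[str]):
        self._options = options

    def get_completions(self, document, complete_event):
        word = document.get_word_before_cursor(WORD=True)
        candidates = []
        for option in self._options:
            matched, score = fuzzy_match(option, word)
            if matched:
                candidates.append((score, option))
        candidates.sort()  # best (lowest) score first, ties broken alphabetically
        for _, option in candidates:
            yield Completion(option, start_position=-len(word))


if __name__ == "__main__":
    answer = prompt("env: ", completer=FuzzyOptionCompleter(["dev", "staging", "production"]))
    print(f"You picked: {answer}")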
zrb/task/cmd_task.py
CHANGED
@@ -48,6 +48,7 @@ class CmdTask(BaseTask):
         warn_unrecommended_command: bool | None = None,
         max_output_line: int = 1000,
         max_error_line: int = 1000,
+        execution_timeout: int = 3600,
         is_interactive: bool = False,
         execute_condition: BoolAttr = True,
         retries: int = 2,
@@ -103,6 +104,7 @@ class CmdTask(BaseTask):
         self._render_cwd = render_cwd
         self._max_output_line = max_output_line
         self._max_error_line = max_error_line
+        self._execution_timeout = execution_timeout
         self._should_plain_print = plain_print
         self._should_warn_unrecommended_command = warn_unrecommended_command
         self._is_interactive = is_interactive
@@ -142,6 +144,7 @@ class CmdTask(BaseTask):
             register_pid_method=lambda pid: ctx.xcom.get(xcom_pid_key).push(pid),
             max_output_line=self._max_output_line,
             max_error_line=self._max_error_line,
+            timeout=self._execution_timeout,
             is_interactive=self._is_interactive,
         )
         # Check for errors
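Usage-wise the new knob is a plain constructor argument; a minimal sketch (the task wiring and the top-level import are illustrative assumptions, only execution_timeout is introduced by this diff):

from zrb import CmdTask  # assuming the usual top-level re-export

backup = CmdTask(
    name="backup",
    cmd="tar -czf /tmp/backup.tar.gz ./data",
    execution_timeout=600,  # abort the command after 10 minutes (default: 3600)
)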
zrb/task/llm/agent_runner.py
CHANGED
@@ -112,8 +112,12 @@ async def _run_single_agent_iteration(
 
 
 def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[str], None]:
-    def _print_throttle_notif(
-
+    def _print_throttle_notif(text: str, *args: Any, **kwargs: Any):
+        new_line = kwargs.get("new_line", True)
+        prefix = "\r" if not new_line else "\n"
+        if text != "":
+            prefix = f"{prefix} 🐢 Request Throttled: "
+        ctx.print(stylize_faint(f"{prefix}{text}"), plain=True, end="")
 
     return _print_throttle_notif
 
zrb/task/llm/history_list.py
ADDED
@@ -0,0 +1,13 @@
+from typing import Any
+
+
+def remove_system_prompt_and_instruction(
+    history_list: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    for msg in history_list:
+        if msg.get("instructions"):
+            msg["instructions"] = ""
+        msg["parts"] = [
+            p for p in msg.get("parts", []) if p.get("part_kind") != "system-prompt"
+        ]
+    return history_list
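The helper mutates the serialized message list in place and returns the same list; an illustrative round trip, with the key names (instructions, parts, part_kind, system-prompt) taken from the diff:

from zrb.task.llm.history_list import remove_system_prompt_and_instruction

history = [
    {
        "instructions": "You are a helpful assistant.",
        "parts": [
            {"part_kind": "system-prompt", "content": "You are a helpful assistant."},
            {"part_kind": "user-prompt", "content": "hello"},
        ],
    }
]
cleaned = remove_system_prompt_and_instruction(history)
assert cleaned[0]["instructions"] == ""
assert all(p["part_kind"] != "system-prompt" for p in cleaned[0]["parts"])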
zrb/task/llm/history_processor.py
CHANGED
@@ -88,30 +88,21 @@ def create_summarize_history_processor(
         messages: list[ModelMessage],
     ) -> list[ModelMessage]:
         history_list = json.loads(ModelMessagesTypeAdapter.dump_json(messages))
-
+        history_list_str = json.dumps(history_list)
         # Estimate token usage
         # Note: Pydantic ai has run context parameter
         # (https://ai.pydantic.dev/message-history/#runcontext-parameter)
         # But we cannot use run_ctx.usage.total_tokens because total token keep increasing
         # even after summariztion.
-        estimated_token_usage = rate_limitter.count_token(
+        estimated_token_usage = rate_limitter.count_token(history_list_str)
         _print_request_info(
             ctx, estimated_token_usage, summarization_token_threshold, messages
         )
         if estimated_token_usage < summarization_token_threshold or len(messages) == 1:
             return messages
-
-            {
-                key: obj[key]
-                for key in obj
-                if index == len(history_list) - 1 or key != "instructions"
-            }
-            for index, obj in enumerate(history_list)
-        ]
-        history_json_str_without_instruction = json.dumps(
-            history_list_without_instruction
+        summarization_message = (
+            f"Summarize the following conversation: {history_list_str}"
         )
-        summarization_message = f"Summarize the following conversation: {history_json_str_without_instruction}"
         summarization_agent = Agent[None, ConversationSummary](
             model=summarization_model,
             output_type=save_conversation_summary,
zrb/task/llm/print_node.py
CHANGED
@@ -24,15 +24,19 @@ async def print_node(
     )
 
     meta = getattr(node, "id", None) or getattr(node, "request_id", None)
+    progress_char_list = ["|", "/", "-", "\\"]
+    progress_index = 0
     if Agent.is_user_prompt_node(node):
         print_func(_format_header("🔠 Receiving input...", log_indent_level))
-
+        return
+    if Agent.is_model_request_node(node):
         # A model request node => We can stream tokens from the model's request
         print_func(_format_header("🧠 Processing...", log_indent_level))
         # Reference: https://ai.pydantic.dev/agents/#streaming-all-events-and-output
         try:
             async with node.stream(agent_run.ctx) as request_stream:
                 is_streaming = False
+                is_tool_processing = False
                 async for event in request_stream:
                     if isinstance(event, PartStartEvent) and event.part:
                         if is_streaming:
@@ -40,29 +44,59 @@ async def print_node(
                         content = _get_event_part_content(event)
                         print_func(_format_content(content, log_indent_level), end="")
                         is_streaming = True
-
+                        is_tool_processing = False
+                        continue
+                    if isinstance(event, PartDeltaEvent):
                         if isinstance(event.delta, TextPartDelta):
                             content_delta = event.delta.content_delta
                             print_func(
                                 _format_stream_content(content_delta, log_indent_level),
                                 end="",
                             )
-
+                            is_tool_processing = False
+                            is_streaming = True
+                            continue
+                        if isinstance(event.delta, ThinkingPartDelta):
                             content_delta = event.delta.content_delta
                             print_func(
                                 _format_stream_content(content_delta, log_indent_level),
                                 end="",
                             )
-
-
-
-
+                            is_tool_processing = False
+                            is_streaming = True
+                            continue
+                        if isinstance(event.delta, ToolCallPartDelta):
+                            if CFG.LLM_SHOW_TOOL_CALL_PREPARATION:
+                                args_delta = event.delta.args_delta
+                                if isinstance(args_delta, dict):
+                                    args_delta = json.dumps(args_delta)
+                                print_func(
+                                    _format_stream_content(
+                                        args_delta, log_indent_level
+                                    ),
+                                    end="",
+                                )
+                                is_streaming = True
+                                is_tool_processing = True
+                                continue
+                            prefix = "\n" if not is_tool_processing else ""
+                            progress_char = progress_char_list[progress_index]
                             print_func(
-
+                                _format_content(
+                                    f"Preparing Tool Parameters... {progress_char}",
+                                    log_indent_level,
+                                    prefix=f"\r{prefix}",
+                                ),
                                 end="",
                             )
+                            progress_index += 1
+                            if progress_index >= len(progress_char_list):
+                                progress_index = 0
+                            is_tool_processing = True
+                            is_streaming = True
+                            continue
                     is_streaming = True
-
+                    if isinstance(event, FinalResultEvent) and event.tool_name:
                         if is_streaming:
                             print_func("")
                         tool_name = event.tool_name
@@ -72,6 +106,7 @@ async def print_node(
                         )
                     )
                     is_streaming = False
+                    is_tool_processing = False
                 if is_streaming:
                     print_func("")
         except UnexpectedModelBehavior as e:
@@ -85,7 +120,8 @@ async def print_node(
                     log_indent_level,
                 )
             )
-
+        return
+    if Agent.is_call_tools_node(node):
         # A handle-response node => The model returned some data, potentially calls a tool
         print_func(_format_header("🧰 Calling Tool...", log_indent_level))
         try:
@@ -100,7 +136,8 @@ async def print_node(
                             f"{call_id} | Call {tool_name} {args}", log_indent_level
                         )
                     )
-
+                    continue
+                if (
                     isinstance(event, FunctionToolResultEvent)
                     and event.tool_call_id
                 ):
@@ -113,12 +150,10 @@ async def print_node(
                             log_indent_level,
                         )
                     )
-
-
-
-
-        )
-    )
+                    continue
+                print_func(
+                    _format_content(f"{call_id} | Executed", log_indent_level)
+                )
         except UnexpectedModelBehavior as e:
             print_func("")  # ensure newline consistency
             print_func(
@@ -130,9 +165,11 @@ async def print_node(
                     log_indent_level,
                 )
             )
-
+        return
+    if Agent.is_end_node(node):
         # Once an End node is reached, the agent run is complete
         print_func(_format_header("✅ Completed...", log_indent_level))
+        return
 
 
 def _format_header(text: str | None, log_indent_level: int = 0) -> str:
@@ -145,8 +182,10 @@ def _format_header(text: str | None, log_indent_level: int = 0) -> str:
     )
 
 
-def _format_content(
-
+def _format_content(
+    text: str | None, log_indent_level: int = 0, prefix: str = ""
+) -> str:
+    return prefix + _format(
         text,
         base_indent=2,
         first_indent=3,
@@ -155,8 +194,10 @@ def _format_content(text: str | None, log_indent_level: int = 0) -> str:
     )
 
 
-def _format_stream_content(
-
+def _format_stream_content(
+    text: str | None, log_indent_level: int = 0, prefix: str = ""
+) -> str:
+    return prefix + _format(
         text,
         base_indent=2,
         indent=3,
@@ -207,7 +248,7 @@ def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
     return {key: _truncate_arg(val) for key, val in kwargs.items()}
 
 
-def _truncate_arg(arg: str, length: int =
+def _truncate_arg(arg: str, length: int = 30) -> str:
     if isinstance(arg, str) and len(arg) > length:
         return f"{arg[:length-4]} ..."
     return arg
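The "Preparing Tool Parameters..." indicator above is a classic carriage-return spinner; a standalone sketch of the pattern, with the frame set taken from progress_char_list in the diff:

import itertools
import sys
import time

progress_chars = ["|", "/", "-", "\\"]
for char in itertools.islice(itertools.cycle(progress_chars), 20):
    # \r returns the cursor to column 0, so each frame overwrites the previous one
    sys.stdout.write(f"\rPreparing Tool Parameters... {char}")
    sys.stdout.flush()
    time.sleep(0.1)
sys.stdout.write("\n")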
zrb/task/llm/tool_wrapper.py
CHANGED
@@ -136,9 +136,12 @@ def _create_wrapper(
             result = await run_async(func(*args, **kwargs))
             _check_tool_call_result_limit(result)
             if has_ever_edited:
+                serializable_kwargs = kwargs.copy()
+                if any_context_param_name is not None:
+                    serializable_kwargs.pop(any_context_param_name, None)
                 return {
                     "tool_call_result": result,
-                    "new_tool_parameters":
+                    "new_tool_parameters": serializable_kwargs,
                     "message": "User correction: Tool was called with user's parameters",
                 }
             return result
zrb/task/llm/workflow.py
CHANGED
@@ -44,26 +44,40 @@ def get_available_workflows() -> dict[str, LLMWorkflow]:
         workflow_name.strip().lower(): workflow
         for workflow_name, workflow in llm_context_config.get_workflows().items()
     }
-
-
-
-
-
-
-
-
+    workflow_hidden_folder = f".{CFG.ROOT_GROUP_NAME}"
+    # Define workflow locations in order of precedence
+    default_workflow_locations = (
+        [
+            # Project specific + user specific workflows
+            os.path.join(
+                os.path.dirname(__file__), workflow_hidden_folder, "workflows"
+            ),
+            os.path.join(os.path.dirname(__file__), workflow_hidden_folder, "skills"),
+            os.path.join(os.path.dirname(__file__), ".claude", "skills"),
+            os.path.join(os.path.expanduser("~"), workflow_hidden_folder, "workflows"),
+            os.path.join(os.path.expanduser("~"), workflow_hidden_folder, "skills"),
+            os.path.join(os.path.expanduser("~"), ".claude", "skills"),
+        ]
+        + [
+            # User defined builtin workflows
+            os.path.expanduser(additional_builtin_workflow_path)
+            for additional_builtin_workflow_path in CFG.LLM_BUILTIN_WORKFLOW_PATHS
+            if os.path.isdir(os.path.expanduser(additional_builtin_workflow_path))
+        ]
+        + [
+            # Zrb builtin workflows
+            os.path.join(os.path.dirname(__file__), "default_workflow"),
+        ]
     )
     # Load workflows from all locations
-    for workflow_location in
+    for workflow_location in default_workflow_locations:
         if not os.path.isdir(workflow_location):
             continue
         for workflow_name in os.listdir(workflow_location):
             workflow_dir = os.path.join(workflow_location, workflow_name)
-            workflow_file =
-            if not
-
-            if not os.path.isfile(path=workflow_file):
-                continue
+            workflow_file = _get_workflow_file_name(workflow_dir)
+            if not workflow_file:
+                continue
             # Only add if not already defined (earlier locations have precedence)
             if workflow_name not in available_workflows:
                 with open(workflow_file, "r") as f:
@@ -74,3 +88,16 @@ def get_available_workflows() -> dict[str, LLMWorkflow]:
             content=workflow_content,
         )
     return available_workflows
+
+
+def _get_workflow_file_name(workflow_dir: str) -> str | None:
+    workflow_file = os.path.join(workflow_dir, "workflow.md")
+    if os.path.isfile(workflow_file):
+        return workflow_file
+    workflow_file = os.path.join(workflow_dir, "WORKFLOW.md")
+    if os.path.isfile(workflow_file):
+        return workflow_file
+    workflow_file = os.path.join(workflow_dir, "SKILL.md")
+    if os.path.isfile(workflow_file):
+        return workflow_file
+    return None
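The three-filename fallback in _get_workflow_file_name is a first-match lookup; a behaviorally equivalent loop, shown only to restate the search order (workflow.md, then WORKFLOW.md, then SKILL.md):

import os


def find_workflow_file(workflow_dir: str) -> str | None:
    for candidate in ("workflow.md", "WORKFLOW.md", "SKILL.md"):
        path = os.path.join(workflow_dir, candidate)
        if os.path.isfile(path):
            return path
    return None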
zrb/task/llm_task.py
CHANGED
@@ -22,6 +22,7 @@ from zrb.task.llm.conversation_history import (
     write_conversation_history,
 )
 from zrb.task.llm.conversation_history_model import ConversationHistory
+from zrb.task.llm.history_list import remove_system_prompt_and_instruction
 from zrb.task.llm.history_summarization import get_history_summarization_token_threshold
 from zrb.task.llm.prompt import (
     get_attachments,
@@ -134,9 +135,6 @@ class LLMTask(BaseTask):
         upstream: list[AnyTask] | AnyTask | None = None,
         fallback: list[AnyTask] | AnyTask | None = None,
         successor: list[AnyTask] | AnyTask | None = None,
-        conversation_context: (
-            dict[str, Any] | Callable[[AnyContext], dict[str, Any]] | None
-        ) = None,
     ):
         super().__init__(
             name=name,
@@ -204,7 +202,6 @@ class LLMTask(BaseTask):
             render_history_summarization_token_threshold
         )
         self._max_call_iteration = max_call_iteration
-        self._conversation_context = conversation_context
         self._yolo_mode = yolo_mode
         self._render_yolo_mode = render_yolo_mode
         self._attachment = attachment
@@ -362,7 +359,9 @@ class LLMTask(BaseTask):
             rate_limitter=self._rate_limitter,
         )
         if agent_run and agent_run.result:
-            new_history_list =
+            new_history_list = remove_system_prompt_and_instruction(
+                json.loads(agent_run.result.all_messages_json())
+            )
             conversation_history.history = new_history_list
             xcom_usage_key = f"{self.name}-usage"
             if xcom_usage_key not in ctx.xcom:
zrb/task/rsync_task.py
CHANGED
@@ -44,6 +44,7 @@ class RsyncTask(CmdTask):
         plain_print: bool = False,
         max_output_line: int = 1000,
         max_error_line: int = 1000,
+        execution_timeout: int = 3600,
         execute_condition: BoolAttr = True,
         retries: int = 2,
         retry_period: float = 0,
@@ -77,6 +78,7 @@ class RsyncTask(CmdTask):
             plain_print=plain_print,
             max_output_line=max_output_line,
             max_error_line=max_error_line,
+            execution_timeout=execution_timeout,
             execute_condition=execute_condition,
             retries=retries,
             retry_period=retry_period,
zrb/util/cmd/command.py
CHANGED
@@ -5,7 +5,7 @@ import signal
 import sys
 from collections import deque
 from collections.abc import Callable
-from typing import TextIO
+from typing import Any, TextIO
 
 import psutil
 
@@ -62,6 +62,8 @@ async def run_command(
     register_pid_method: Callable[[int], None] | None = None,
     max_output_line: int = 1000,
     max_error_line: int = 1000,
+    max_display_line: int | None = None,
+    timeout: int = 3600,
     is_interactive: bool = False,
 ) -> tuple[CmdResult, int]:
     """
@@ -77,6 +79,8 @@
     actual_print_method = print_method if print_method is not None else print
     if cwd is None:
         cwd = os.getcwd()
+    if max_display_line is None:
+        max_display_line = max(max_output_line, max_error_line)
     # While environment variables alone weren't the fix, they are still
     # good practice for encouraging simpler output from tools.
     child_env = (env_map or os.environ).copy()
@@ -95,17 +99,33 @@
     if register_pid_method is not None:
         register_pid_method(cmd_process.pid)
     # Use the new, simple, and correct stream reader.
+    display_lines = deque(maxlen=max_display_line if max_display_line > 0 else 0)
     stdout_task = asyncio.create_task(
-        __read_stream(
+        __read_stream(
+            cmd_process.stdout, actual_print_method, max_output_line, display_lines
+        )
     )
     stderr_task = asyncio.create_task(
-        __read_stream(
+        __read_stream(
+            cmd_process.stderr, actual_print_method, max_error_line, display_lines
+        )
+    )
+    timeout_task = (
+        asyncio.create_task(asyncio.sleep(timeout)) if timeout and timeout > 0 else None
    )
    try:
-
+        wait_task = asyncio.create_task(cmd_process.wait())
+        done, pending = await asyncio.wait(
+            {wait_task, timeout_task} if timeout_task else {wait_task},
+            return_when=asyncio.FIRST_COMPLETED,
+        )
+        if timeout_task and timeout_task in done:
+            raise asyncio.TimeoutError()
+        return_code = wait_task.result()
         stdout, stderr = await asyncio.gather(stdout_task, stderr_task)
-
-
+        display = "\r\n".join(display_lines)
+        return CmdResult(stdout, stderr, display=display), return_code
+    except (KeyboardInterrupt, asyncio.CancelledError, asyncio.TimeoutError):
         try:
             os.killpg(cmd_process.pid, signal.SIGINT)
             await asyncio.wait_for(cmd_process.wait(), timeout=2.0)
@@ -133,13 +153,14 @@ def __get_cmd_stdin(is_interactive: bool) -> int | TextIO:
 async def __read_stream(
     stream: asyncio.StreamReader,
     print_method: Callable[..., None],
-
+    max_line: int,
+    display_queue: deque[Any],
 ) -> str:
     """
     Reads from the stream using the robust `readline()` and correctly
     interprets carriage returns (`\r`) as distinct print events.
     """
-    captured_lines = deque(maxlen=
+    captured_lines = deque(maxlen=max_line if max_line > 0 else 0)
     while True:
         try:
             line_bytes = await stream.readline()
@@ -149,8 +170,9 @@
             # Safety valve for the memory limit.
             error_msg = "[ERROR] A single line of output was too long to process."
             print_method(error_msg)
-            if
+            if max_line > 0:
                 captured_lines.append(error_msg)
+            display_queue.append(error_msg)
             break
         except (KeyboardInterrupt, asyncio.CancelledError):
             raise
@@ -165,8 +187,9 @@
                 print_method(clean_part, end="\r\n")
             except Exception:
                 print_method(clean_part)
-        if
+        if max_line > 0:
             captured_lines.append(clean_part)
+        display_queue.append(clean_part)
     return "\r\n".join(captured_lines)
 
 
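The timeout handling in run_command boils down to racing the process against a sleep and treating the sleep finishing first as a timeout. A minimal self-contained sketch of that pattern (not the shipped code, which also kills the whole process group and captures output):

import asyncio


async def wait_with_timeout(proc: asyncio.subprocess.Process, timeout: float) -> int:
    wait_task = asyncio.create_task(proc.wait())
    timeout_task = asyncio.create_task(asyncio.sleep(timeout))
    done, _ = await asyncio.wait(
        {wait_task, timeout_task}, return_when=asyncio.FIRST_COMPLETED
    )
    if timeout_task in done and wait_task not in done:
        # The sleep won the race: the process ran too long.
        wait_task.cancel()
        raise asyncio.TimeoutError()
    timeout_task.cancel()
    return wait_task.result()


async def main() -> None:
    proc = await asyncio.create_subprocess_shell("sleep 1")
    print(await wait_with_timeout(proc, timeout=5))


asyncio.run(main())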