zrb 1.15.20__py3-none-any.whl → 1.15.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/tool/api.py +7 -7
- zrb/builtin/llm/tool/cli.py +6 -9
- zrb/builtin/llm/tool/file.py +50 -47
- zrb/builtin/llm/tool/rag.py +2 -2
- zrb/builtin/llm/tool/web.py +13 -16
- zrb/input/text_input.py +7 -20
- zrb/task/llm/print_node.py +1 -1
- zrb/task/llm/tool_wrapper.py +85 -62
- zrb/util/cli/text.py +28 -0
- {zrb-1.15.20.dist-info → zrb-1.15.22.dist-info}/METADATA +1 -1
- {zrb-1.15.20.dist-info → zrb-1.15.22.dist-info}/RECORD +13 -12
- {zrb-1.15.20.dist-info → zrb-1.15.22.dist-info}/WHEEL +0 -0
- {zrb-1.15.20.dist-info → zrb-1.15.22.dist-info}/entry_points.txt +0 -0
zrb/builtin/llm/tool/api.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
|
-
import json
|
2
1
|
from typing import Literal
|
3
2
|
|
4
3
|
|
5
|
-
def get_current_location() -> str:
|
4
|
+
def get_current_location() -> dict[str, float]:
|
6
5
|
"""
|
7
6
|
Fetches the user's current geographical location based on their IP address.
|
8
7
|
|
@@ -11,8 +10,9 @@ def get_current_location() -> str:
|
|
11
10
|
answered.
|
12
11
|
|
13
12
|
Returns:
|
14
|
-
str: A
|
15
|
-
location.
|
13
|
+
dict[str, float]: A dictionary containing the 'lat' and 'lon' of the current
|
14
|
+
location.
|
15
|
+
Example: {"lat": 48.8584, "lon": 2.2945}
|
16
16
|
Raises:
|
17
17
|
requests.RequestException: If the API request to the location service
|
18
18
|
fails.
|
@@ -22,7 +22,7 @@ def get_current_location() -> str:
|
|
22
22
|
try:
|
23
23
|
response = requests.get("http://ip-api.com/json?fields=lat,lon", timeout=5)
|
24
24
|
response.raise_for_status()
|
25
|
-
return
|
25
|
+
return dict(response.json())
|
26
26
|
except requests.RequestException as e:
|
27
27
|
raise requests.RequestException(f"Failed to get location: {e}") from None
|
28
28
|
|
@@ -46,7 +46,7 @@ def get_current_weather(
|
|
46
46
|
for the temperature reading.
|
47
47
|
|
48
48
|
Returns:
|
49
|
-
str: A
|
49
|
+
dict[str, Any]: A dictionary containing detailed weather data, including
|
50
50
|
temperature, wind speed, and weather code.
|
51
51
|
Raises:
|
52
52
|
requests.RequestException: If the API request to the weather service
|
@@ -66,6 +66,6 @@ def get_current_weather(
|
|
66
66
|
timeout=5,
|
67
67
|
)
|
68
68
|
response.raise_for_status()
|
69
|
-
return
|
69
|
+
return dict(response.json())
|
70
70
|
except requests.RequestException as e:
|
71
71
|
raise requests.RequestException(f"Failed to get weather data: {e}") from None
|
zrb/builtin/llm/tool/cli.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
import json
|
2
1
|
import subprocess
|
3
2
|
|
4
3
|
|
@@ -19,7 +18,7 @@ def run_shell_command(command: str) -> str:
|
|
19
18
|
command (str): The exact shell command to execute.
|
20
19
|
|
21
20
|
Returns:
|
22
|
-
str: A
|
21
|
+
dict[str, Any]: A dictionary containing return code, standard output (stdout),
|
23
22
|
and standard error (stderr) from the command.
|
24
23
|
Example: {"return_code": 0, "stdout": "ok", "stderr": ""}
|
25
24
|
"""
|
@@ -29,10 +28,8 @@ def run_shell_command(command: str) -> str:
|
|
29
28
|
capture_output=True,
|
30
29
|
text=True,
|
31
30
|
)
|
32
|
-
return
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
}
|
38
|
-
)
|
31
|
+
return {
|
32
|
+
"return_code": result.returncode,
|
33
|
+
"stdout": result.stdout,
|
34
|
+
"stderr": result.stderr,
|
35
|
+
}
|
zrb/builtin/llm/tool/file.py
CHANGED
@@ -223,35 +223,42 @@ def read_from_file(
|
|
223
223
|
end_line: Optional[int] = None,
|
224
224
|
) -> str:
|
225
225
|
"""
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
226
|
+
Reads the content of a file, optionally from a specific start line to an
|
227
|
+
end line.
|
228
|
+
|
229
|
+
This tool is essential for inspecting file contents. It can read both text
|
230
|
+
and PDF files. The returned content is prefixed with line numbers, which is
|
231
|
+
crucial for providing context when you need to modify the file later with
|
232
|
+
the `apply_diff` tool.
|
233
|
+
|
234
|
+
Use this tool to:
|
235
|
+
- Examine the source code of a file.
|
236
|
+
- Read configuration files.
|
237
|
+
- Check the contents of a document.
|
238
|
+
|
239
|
+
Args:
|
240
|
+
path (str): The path to the file to read.
|
241
|
+
start_line (int, optional): The 1-based line number to start reading
|
242
|
+
from. If omitted, reading starts from the beginning of the file.
|
243
|
+
end_line (int, optional): The 1-based line number to stop reading at
|
244
|
+
(inclusive). If omitted, reads to the end of the file.
|
245
|
+
|
246
|
+
Returns:
|
247
|
+
str: A JSON string containing the file path, the requested content
|
248
|
+
with line numbers, the start and end lines, and the total number
|
249
|
+
of lines in the file.
|
250
|
+
Example:
|
251
|
+
```
|
252
|
+
{
|
253
|
+
"path": "src/main.py",
|
254
|
+
"content": "1| import os\n2|3| print(\"Hello, World!\")",
|
255
|
+
"start_line": 1,
|
256
|
+
"end_line": 3,
|
257
|
+
"total_lines": 3
|
258
|
+
}
|
259
|
+
```
|
260
|
+
Raises:
|
261
|
+
FileNotFoundError: If the specified file does not exist.
|
255
262
|
"""
|
256
263
|
|
257
264
|
abs_path = os.path.abspath(os.path.expanduser(path))
|
@@ -309,7 +316,7 @@ def write_to_file(
|
|
309
316
|
Do not use partial content or omit any lines.
|
310
317
|
|
311
318
|
Returns:
|
312
|
-
str: A
|
319
|
+
dict[str, Any]: A dictionary indicating success or failure.
|
313
320
|
Example: '{"success": true, "path": "new_file.txt"}'
|
314
321
|
"""
|
315
322
|
try:
|
@@ -319,8 +326,7 @@ def write_to_file(
|
|
319
326
|
if directory and not os.path.exists(directory):
|
320
327
|
os.makedirs(directory, exist_ok=True)
|
321
328
|
write_file(abs_path, content)
|
322
|
-
|
323
|
-
return json.dumps(result_data)
|
329
|
+
return {"success": True, "path": path}
|
324
330
|
except (OSError, IOError) as e:
|
325
331
|
raise OSError(f"Error writing file {path}: {e}")
|
326
332
|
except Exception as e:
|
@@ -352,7 +358,7 @@ def search_files(
|
|
352
358
|
hidden files and directories. Defaults to True.
|
353
359
|
|
354
360
|
Returns:
|
355
|
-
str: A
|
361
|
+
dict[str, Any]: A dictionary containing a summary of the search and a list of
|
356
362
|
results. Each result includes the file path and a list of matches,
|
357
363
|
with each match showing the line number, line content, and a few
|
358
364
|
lines of context from before and after the match.
|
@@ -404,9 +410,7 @@ def search_files(
|
|
404
410
|
f"Found {match_count} matches in {file_match_count} files "
|
405
411
|
f"(searched {searched_file_count} files)."
|
406
412
|
)
|
407
|
-
return
|
408
|
-
search_results
|
409
|
-
) # No need for pretty printing for LLM consumption
|
413
|
+
return search_results
|
410
414
|
except (OSError, IOError) as e:
|
411
415
|
raise OSError(f"Error searching files in {path}: {e}")
|
412
416
|
except Exception as e:
|
@@ -467,7 +471,7 @@ def replace_in_file(
|
|
467
471
|
new_string (str): The new string that will replace the `old_string`.
|
468
472
|
|
469
473
|
Returns:
|
470
|
-
str: A
|
474
|
+
dict[str, Any]: A dictionary indicating the success or failure of the operation.
|
471
475
|
Raises:
|
472
476
|
FileNotFoundError: If the specified file does not exist.
|
473
477
|
ValueError: If the `old_string` is not found in the file.
|
@@ -481,7 +485,7 @@ def replace_in_file(
|
|
481
485
|
raise ValueError(f"old_string not found in file: {path}")
|
482
486
|
new_content = content.replace(old_string, new_string, 1)
|
483
487
|
write_file(abs_path, new_content)
|
484
|
-
return
|
488
|
+
return {"success": True, "path": path}
|
485
489
|
except ValueError as e:
|
486
490
|
raise e
|
487
491
|
except (OSError, IOError) as e:
|
@@ -564,11 +568,10 @@ def read_many_files(paths: list[str]) -> str:
|
|
564
568
|
if you are unsure about the exact file locations.
|
565
569
|
|
566
570
|
Returns:
|
567
|
-
str:
|
568
|
-
corresponding contents, prefixed with line numbers.
|
569
|
-
cannot be read, its value will be an error message.
|
570
|
-
Example:
|
571
|
-
"config.yaml": "1| key: value"}}'
|
571
|
+
dict[str, str]: a dictionary where keys are the file paths and values
|
572
|
+
are their corresponding contents, prefixed with line numbers.
|
573
|
+
If a file cannot be read, its value will be an error message.
|
574
|
+
Example: {"src/api.py": "1| import ...", "config.yaml": "1| key: value"}
|
572
575
|
"""
|
573
576
|
results = {}
|
574
577
|
for path in paths:
|
@@ -580,7 +583,7 @@ def read_many_files(paths: list[str]) -> str:
|
|
580
583
|
results[path] = content
|
581
584
|
except Exception as e:
|
582
585
|
results[path] = f"Error reading file: {e}"
|
583
|
-
return
|
586
|
+
return results
|
584
587
|
|
585
588
|
|
586
589
|
def write_many_files(files: list[FileToWrite]) -> str:
|
@@ -601,10 +604,10 @@ def write_many_files(files: list[FileToWrite]) -> str:
|
|
601
604
|
containing a 'path' and the complete 'content'.
|
602
605
|
|
603
606
|
Returns:
|
604
|
-
str: A
|
607
|
+
str: A dictionary summarizing the operation, listing successfully
|
605
608
|
written files and any files that failed, along with corresponding
|
606
609
|
error messages.
|
607
|
-
Example:
|
610
|
+
Example: {"success": ["file1.py", "file2.txt"], "errors": {}}
|
608
611
|
"""
|
609
612
|
success = []
|
610
613
|
errors = {}
|
@@ -623,4 +626,4 @@ def write_many_files(files: list[FileToWrite]) -> str:
|
|
623
626
|
success.append(path)
|
624
627
|
except Exception as e:
|
625
628
|
errors[path] = f"Error writing file: {e}"
|
626
|
-
return
|
629
|
+
return {"success": success, "errors": errors}
|
zrb/builtin/llm/tool/rag.py
CHANGED
@@ -201,7 +201,7 @@ def create_rag_from_directory(
|
|
201
201
|
query_embeddings=query_vector,
|
202
202
|
n_results=max_result_count_val,
|
203
203
|
)
|
204
|
-
return
|
204
|
+
return dict(results)
|
205
205
|
|
206
206
|
retrieve.__name__ = tool_name
|
207
207
|
retrieve.__doc__ = dedent(
|
@@ -210,7 +210,7 @@ def create_rag_from_directory(
|
|
210
210
|
Args:
|
211
211
|
query (str): The user query to search for in documents.
|
212
212
|
Returns:
|
213
|
-
str:
|
213
|
+
str: dictionary with search results: {{"ids": [...], "documents": [...], ...}}
|
214
214
|
"""
|
215
215
|
).strip()
|
216
216
|
return retrieve
|
zrb/builtin/llm/tool/web.py
CHANGED
@@ -2,6 +2,8 @@ import json
|
|
2
2
|
from collections.abc import Callable
|
3
3
|
from urllib.parse import urljoin
|
4
4
|
|
5
|
+
_DEFAULT_USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" # noqa
|
6
|
+
|
5
7
|
|
6
8
|
async def open_web_page(url: str) -> str:
|
7
9
|
"""
|
@@ -18,7 +20,7 @@ async def open_web_page(url: str) -> str:
|
|
18
20
|
"https://example.com/article").
|
19
21
|
|
20
22
|
Returns:
|
21
|
-
str: A JSON
|
23
|
+
str: A JSON string containing the page's content in Markdown format
|
22
24
|
and a list of all absolute links found on the page.
|
23
25
|
"""
|
24
26
|
html_content, links = await _fetch_page_content(url)
|
@@ -47,22 +49,22 @@ def create_search_internet_tool(serp_api_key: str) -> Callable[[str, int], str]:
|
|
47
49
|
"""
|
48
50
|
Performs an internet search using Google and returns a summary of the results.
|
49
51
|
|
50
|
-
Use this tool to find information on the web, answer general knowledge questions,
|
52
|
+
Use this tool to find information on the web, answer general knowledge questions,
|
53
|
+
or research topics.
|
51
54
|
|
52
55
|
Args:
|
53
56
|
query (str): The search query.
|
54
57
|
num_results (int, optional): The desired number of search results. Defaults to 10.
|
55
58
|
|
56
59
|
Returns:
|
57
|
-
str: A formatted string summarizing the search results,
|
60
|
+
str: A formatted string summarizing the search results,
|
61
|
+
including titles, links, and snippets.
|
58
62
|
"""
|
59
63
|
import requests
|
60
64
|
|
61
65
|
response = requests.get(
|
62
66
|
"https://serpapi.com/search",
|
63
|
-
headers={
|
64
|
-
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
65
|
-
},
|
67
|
+
headers={"User-Agent": _DEFAULT_USER_AGENT},
|
66
68
|
params={
|
67
69
|
"q": query,
|
68
70
|
"num": num_results,
|
@@ -73,7 +75,7 @@ def create_search_internet_tool(serp_api_key: str) -> Callable[[str, int], str]:
|
|
73
75
|
)
|
74
76
|
if response.status_code != 200:
|
75
77
|
raise Exception(
|
76
|
-
f"Error: Unable to retrieve search results (status code: {response.status_code})"
|
78
|
+
f"Error: Unable to retrieve search results (status code: {response.status_code})" # noqa
|
77
79
|
)
|
78
80
|
return response.json()
|
79
81
|
|
@@ -100,9 +102,7 @@ def search_wikipedia(query: str) -> str:
|
|
100
102
|
params = {"action": "query", "list": "search", "srsearch": query, "format": "json"}
|
101
103
|
response = requests.get(
|
102
104
|
"https://en.wikipedia.org/w/api.php",
|
103
|
-
headers={
|
104
|
-
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
105
|
-
},
|
105
|
+
headers={"User-Agent": _DEFAULT_USER_AGENT},
|
106
106
|
params=params,
|
107
107
|
)
|
108
108
|
return response.json()
|
@@ -131,9 +131,7 @@ def search_arxiv(query: str, num_results: int = 10) -> str:
|
|
131
131
|
params = {"search_query": f"all:{query}", "start": 0, "max_results": num_results}
|
132
132
|
response = requests.get(
|
133
133
|
"http://export.arxiv.org/api/query",
|
134
|
-
headers={
|
135
|
-
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
136
|
-
},
|
134
|
+
headers={"User-Agent": _DEFAULT_USER_AGENT},
|
137
135
|
params=params,
|
138
136
|
)
|
139
137
|
return response.content
|
@@ -141,14 +139,13 @@ def search_arxiv(query: str, num_results: int = 10) -> str:
|
|
141
139
|
|
142
140
|
async def _fetch_page_content(url: str) -> tuple[str, list[str]]:
|
143
141
|
"""Fetches the HTML content and all absolute links from a URL."""
|
144
|
-
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
145
142
|
try:
|
146
143
|
from playwright.async_api import async_playwright
|
147
144
|
|
148
145
|
async with async_playwright() as p:
|
149
146
|
browser = await p.chromium.launch(headless=True)
|
150
147
|
page = await browser.new_page()
|
151
|
-
await page.set_extra_http_headers({"User-Agent":
|
148
|
+
await page.set_extra_http_headers({"User-Agent": _DEFAULT_USER_AGENT})
|
152
149
|
try:
|
153
150
|
await page.goto(url, wait_until="networkidle", timeout=30000)
|
154
151
|
await page.wait_for_load_state("domcontentloaded")
|
@@ -176,7 +173,7 @@ async def _fetch_page_content(url: str) -> tuple[str, list[str]]:
|
|
176
173
|
import requests
|
177
174
|
from bs4 import BeautifulSoup
|
178
175
|
|
179
|
-
response = requests.get(url, headers={"User-Agent":
|
176
|
+
response = requests.get(url, headers={"User-Agent": _DEFAULT_USER_AGENT})
|
180
177
|
if response.status_code != 200:
|
181
178
|
raise Exception(
|
182
179
|
f"Unable to retrieve page content. Status code: {response.status_code}"
|
zrb/input/text_input.py
CHANGED
@@ -6,6 +6,7 @@ from collections.abc import Callable
|
|
6
6
|
from zrb.config.config import CFG
|
7
7
|
from zrb.context.any_shared_context import AnySharedContext
|
8
8
|
from zrb.input.base_input import BaseInput
|
9
|
+
from zrb.util.cli.text import edit_text
|
9
10
|
from zrb.util.file import read_file
|
10
11
|
|
11
12
|
|
@@ -85,24 +86,10 @@ class TextInput(BaseInput):
|
|
85
86
|
comment_prompt_message = (
|
86
87
|
f"{self.comment_start}{prompt_message}{self.comment_end}"
|
87
88
|
)
|
88
|
-
comment_prompt_message_eol = f"{comment_prompt_message}\n"
|
89
89
|
default_value = self.get_default_str(shared_ctx)
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
if default_value:
|
97
|
-
temp_file.write(default_value.encode())
|
98
|
-
temp_file.flush()
|
99
|
-
subprocess.call([self.editor_cmd, temp_file_name])
|
100
|
-
# Read the edited content
|
101
|
-
edited_content = read_file(temp_file_name)
|
102
|
-
parts = [
|
103
|
-
text.strip() for text in edited_content.split(comment_prompt_message, 1)
|
104
|
-
]
|
105
|
-
edited_content = "\n".join(parts).lstrip()
|
106
|
-
os.remove(temp_file_name)
|
107
|
-
print(f"{prompt_message}: {edited_content}")
|
108
|
-
return edited_content
|
90
|
+
return edit_text(
|
91
|
+
prompt_message=comment_prompt_message,
|
92
|
+
value=default_value,
|
93
|
+
editor=self.editor_cmd,
|
94
|
+
extension=self._extension,
|
95
|
+
)
|
zrb/task/llm/print_node.py
CHANGED
@@ -191,7 +191,7 @@ def _get_event_part_args(event: Any) -> Any:
|
|
191
191
|
|
192
192
|
|
193
193
|
def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
|
194
|
-
return {key: _truncate_arg(val) for key, val in kwargs.items()
|
194
|
+
return {key: _truncate_arg(val) for key, val in kwargs.items()}
|
195
195
|
|
196
196
|
|
197
197
|
def _truncate_arg(arg: str, length: int = 19) -> str:
|
zrb/task/llm/tool_wrapper.py
CHANGED
@@ -5,6 +5,7 @@ import typing
|
|
5
5
|
from collections.abc import Callable
|
6
6
|
from typing import TYPE_CHECKING, Any
|
7
7
|
|
8
|
+
from zrb.config.config import CFG
|
8
9
|
from zrb.context.any_context import AnyContext
|
9
10
|
from zrb.task.llm.error import ToolExecutionError
|
10
11
|
from zrb.util.callable import get_callable_name
|
@@ -15,6 +16,7 @@ from zrb.util.cli.style import (
|
|
15
16
|
stylize_green,
|
16
17
|
stylize_yellow,
|
17
18
|
)
|
19
|
+
from zrb.util.cli.text import edit_text
|
18
20
|
from zrb.util.run import run_async
|
19
21
|
from zrb.util.string.conversion import to_boolean
|
20
22
|
|
@@ -39,7 +41,6 @@ def wrap_tool(func: Callable, ctx: AnyContext, yolo_mode: bool | list[str]) -> "
|
|
39
41
|
def wrap_func(func: Callable, ctx: AnyContext, yolo_mode: bool | list[str]) -> Callable:
|
40
42
|
original_sig = inspect.signature(func)
|
41
43
|
needs_any_context_for_injection = _has_context_parameter(original_sig, AnyContext)
|
42
|
-
takes_no_args = len(original_sig.parameters) == 0
|
43
44
|
# Pass individual flags to the wrapper creator
|
44
45
|
wrapper = _create_wrapper(
|
45
46
|
func=func,
|
@@ -48,7 +49,7 @@ def wrap_func(func: Callable, ctx: AnyContext, yolo_mode: bool | list[str]) -> C
|
|
48
49
|
needs_any_context_for_injection=needs_any_context_for_injection,
|
49
50
|
yolo_mode=yolo_mode,
|
50
51
|
)
|
51
|
-
_adjust_signature(wrapper, original_sig
|
52
|
+
_adjust_signature(wrapper, original_sig)
|
52
53
|
return wrapper
|
53
54
|
|
54
55
|
|
@@ -108,17 +109,14 @@ def _create_wrapper(
|
|
108
109
|
# Inject the captured ctx into kwargs. This will overwrite if the LLM
|
109
110
|
# somehow provided it.
|
110
111
|
kwargs[any_context_param_name] = ctx
|
111
|
-
# If the dummy argument was added for schema generation and is present in kwargs,
|
112
|
-
# remove it before calling the original function, unless the original function
|
113
|
-
# actually expects a parameter named '_dummy'.
|
114
|
-
if "_dummy" in kwargs and "_dummy" not in original_sig.parameters:
|
115
|
-
del kwargs["_dummy"]
|
116
112
|
try:
|
117
113
|
if not ctx.is_web_mode and ctx.is_tty:
|
118
114
|
if (
|
119
115
|
isinstance(yolo_mode, list) and func.__name__ not in yolo_mode
|
120
116
|
) or not yolo_mode:
|
121
|
-
approval, reason = await
|
117
|
+
approval, reason = await _handle_user_response(
|
118
|
+
ctx, func, args, kwargs
|
119
|
+
)
|
122
120
|
if not approval:
|
123
121
|
raise ToolExecutionCancelled(f"User disapproving: {reason}")
|
124
122
|
return await run_async(func(*args, **kwargs))
|
@@ -136,54 +134,97 @@ def _create_wrapper(
|
|
136
134
|
return wrapper
|
137
135
|
|
138
136
|
|
139
|
-
async def
|
140
|
-
ctx: AnyContext,
|
137
|
+
async def _handle_user_response(
|
138
|
+
ctx: AnyContext,
|
139
|
+
func: Callable,
|
140
|
+
args: list[Any] | tuple[Any],
|
141
|
+
kwargs: dict[str, Any],
|
141
142
|
) -> tuple[bool, str]:
|
142
|
-
func_call_str = _get_func_call_str(func, args, kwargs)
|
143
|
-
complete_confirmation_message = "\n".join(
|
144
|
-
[
|
145
|
-
f"\n🎰 >> {func_call_str}",
|
146
|
-
_get_detail_func_param(args, kwargs),
|
147
|
-
f"🎰 >> {_get_run_func_confirmation(func)}",
|
148
|
-
]
|
149
|
-
)
|
150
143
|
while True:
|
144
|
+
func_call_str = _get_func_call_str(func, args, kwargs)
|
145
|
+
complete_confirmation_message = "\n".join(
|
146
|
+
[
|
147
|
+
f"\n🎰 >> {func_call_str}",
|
148
|
+
_get_detail_func_param(args, kwargs),
|
149
|
+
f"🎰 >> {_get_run_func_confirmation(func)}",
|
150
|
+
]
|
151
|
+
)
|
151
152
|
ctx.print(complete_confirmation_message, plain=True)
|
152
|
-
|
153
|
+
user_response = await _read_line()
|
153
154
|
ctx.print("", plain=True)
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
155
|
+
new_kwargs, is_edited = _get_edited_kwargs(ctx, user_response, kwargs)
|
156
|
+
if is_edited:
|
157
|
+
kwargs = new_kwargs
|
158
|
+
continue
|
159
|
+
approval_and_reason = _get_user_approval_and_reason(
|
160
|
+
ctx, user_response, func_call_str
|
161
|
+
)
|
162
|
+
if approval_and_reason is None:
|
163
|
+
continue
|
164
|
+
return approval_and_reason
|
165
|
+
|
166
|
+
|
167
|
+
def _get_edited_kwargs(
|
168
|
+
cx: AnyContext, user_response: str, kwargs: dict[str, Any]
|
169
|
+
) -> tuple[dict[str, Any], bool]:
|
170
|
+
user_edit_responses = [val for val in user_response.split(" ", maxsplit=2)]
|
171
|
+
if len(user_edit_responses) >= 1 and user_edit_responses[0].lower() != "edit":
|
172
|
+
return kwargs, False
|
173
|
+
while len(user_edit_responses) < 3:
|
174
|
+
user_edit_responses.append("")
|
175
|
+
key, val = user_edit_responses[1:]
|
176
|
+
if key not in kwargs:
|
177
|
+
return kwargs, True
|
178
|
+
if val != "":
|
179
|
+
kwargs[key] = val
|
180
|
+
return kwargs, True
|
181
|
+
val = edit_text(
|
182
|
+
prompt_message=f"// {key}",
|
183
|
+
value=kwargs.get(key, ""),
|
184
|
+
editor=CFG.DEFAULT_EDITOR,
|
185
|
+
)
|
186
|
+
kwargs[key] = val
|
187
|
+
return kwargs, True
|
188
|
+
|
189
|
+
|
190
|
+
def _get_user_approval_and_reason(
|
191
|
+
ctx: AnyContext, user_response: str, func_call_str: str
|
192
|
+
) -> tuple[bool, str] | None:
|
193
|
+
user_approval_responses = [
|
194
|
+
val.strip() for val in user_response.split(",", maxsplit=1)
|
195
|
+
]
|
196
|
+
while len(user_approval_responses) < 2:
|
197
|
+
user_approval_responses.append("")
|
198
|
+
approval_str, reason = user_approval_responses
|
199
|
+
try:
|
200
|
+
approved = True if approval_str.strip() == "" else to_boolean(approval_str)
|
201
|
+
if not approved and reason == "":
|
170
202
|
ctx.print(
|
171
203
|
stylize_error(
|
172
|
-
f"
|
204
|
+
f"You must specify rejection reason (i.e., No, <why>) for {func_call_str}" # noqa
|
173
205
|
),
|
174
206
|
plain=True,
|
175
207
|
)
|
176
|
-
|
208
|
+
return None
|
209
|
+
return approved, reason
|
210
|
+
except Exception:
|
211
|
+
ctx.print(
|
212
|
+
stylize_error(
|
213
|
+
f"Invalid approval value for {func_call_str}: {approval_str}"
|
214
|
+
),
|
215
|
+
plain=True,
|
216
|
+
)
|
217
|
+
return None
|
177
218
|
|
178
219
|
|
179
220
|
def _get_run_func_confirmation(func: Callable) -> str:
|
180
221
|
func_name = get_callable_name(func)
|
181
222
|
return render_markdown(
|
182
|
-
f"Allow to run `{func_name}`? (✅ `Yes` | ⛔ `No, <reason>`)"
|
223
|
+
f"Allow to run `{func_name}`? (✅ `Yes` | ⛔ `No, <reason>` | ✏️ `Edit <param> <value>`)"
|
183
224
|
).strip()
|
184
225
|
|
185
226
|
|
186
|
-
def _get_detail_func_param(args: list[Any], kwargs: dict[str, Any]) -> str:
|
227
|
+
def _get_detail_func_param(args: list[Any] | tuple[Any], kwargs: dict[str, Any]) -> str:
|
187
228
|
markdown = "\n".join(
|
188
229
|
[_get_func_param_item(key, val) for key, val in kwargs.items()]
|
189
230
|
)
|
@@ -203,7 +244,9 @@ def _get_func_param_item(key: str, val: Any) -> str:
|
|
203
244
|
return "\n".join(lines)
|
204
245
|
|
205
246
|
|
206
|
-
def _get_func_call_str(
|
247
|
+
def _get_func_call_str(
|
248
|
+
func: Callable, args: list[Any] | tuple[Any], kwargs: dict[str, Any]
|
249
|
+
) -> str:
|
207
250
|
func_name = get_callable_name(func)
|
208
251
|
normalized_args = [stylize_green(_truncate_arg(arg)) for arg in args]
|
209
252
|
normalized_kwargs = []
|
@@ -230,9 +273,7 @@ async def _read_line():
|
|
230
273
|
return await reader.prompt_async()
|
231
274
|
|
232
275
|
|
233
|
-
def _adjust_signature(
|
234
|
-
wrapper: Callable, original_sig: inspect.Signature, takes_no_args: bool
|
235
|
-
):
|
276
|
+
def _adjust_signature(wrapper: Callable, original_sig: inspect.Signature):
|
236
277
|
"""Adjusts the wrapper function's signature for schema generation."""
|
237
278
|
# The wrapper's signature should represent the arguments the *LLM* needs to provide.
|
238
279
|
# The LLM does not provide RunContext (pydantic-ai injects it) or AnyContext
|
@@ -247,22 +288,4 @@ def _adjust_signature(
|
|
247
288
|
if not _is_annotated_with_context(param.annotation, RunContext)
|
248
289
|
and not _is_annotated_with_context(param.annotation, AnyContext)
|
249
290
|
]
|
250
|
-
|
251
|
-
# If after removing context parameters, there are no parameters left,
|
252
|
-
# and the original function took no args, keep the dummy.
|
253
|
-
# If after removing context parameters, there are no parameters left,
|
254
|
-
# but the original function *did* take args (only context), then the schema
|
255
|
-
# should have no parameters.
|
256
|
-
if not params_for_schema and takes_no_args:
|
257
|
-
# Keep the dummy if the original function truly had no parameters
|
258
|
-
new_sig = inspect.Signature(
|
259
|
-
parameters=[
|
260
|
-
inspect.Parameter(
|
261
|
-
"_dummy", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
|
262
|
-
)
|
263
|
-
]
|
264
|
-
)
|
265
|
-
else:
|
266
|
-
new_sig = inspect.Signature(parameters=params_for_schema)
|
267
|
-
|
268
|
-
wrapper.__signature__ = new_sig
|
291
|
+
wrapper.__signature__ = inspect.Signature(parameters=params_for_schema)
|
zrb/util/cli/text.py
ADDED
@@ -0,0 +1,28 @@
|
|
1
|
+
import os
|
2
|
+
import subprocess
|
3
|
+
import tempfile
|
4
|
+
|
5
|
+
from zrb.util.file import read_file
|
6
|
+
|
7
|
+
|
8
|
+
def edit_text(
|
9
|
+
prompt_message: str,
|
10
|
+
value: str,
|
11
|
+
editor: str = "vi",
|
12
|
+
extension: str = ".txt",
|
13
|
+
) -> str:
|
14
|
+
prompt_message_eol = f"{prompt_message}\n"
|
15
|
+
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as temp_file:
|
16
|
+
temp_file_name = temp_file.name
|
17
|
+
temp_file.write(prompt_message_eol.encode())
|
18
|
+
# Pre-fill with default content
|
19
|
+
if value:
|
20
|
+
temp_file.write(value.encode())
|
21
|
+
temp_file.flush()
|
22
|
+
subprocess.call([editor, temp_file_name])
|
23
|
+
# Read the edited content
|
24
|
+
edited_content = read_file(temp_file_name)
|
25
|
+
parts = [text.strip() for text in edited_content.split(prompt_message, 1)]
|
26
|
+
edited_content = "\n".join(parts).lstrip()
|
27
|
+
os.remove(temp_file_name)
|
28
|
+
return edited_content
|
@@ -15,13 +15,13 @@ zrb/builtin/llm/input.py,sha256=Nw-26uTWp2QhUgKJcP_IMHmtk-b542CCSQ_vCOjhvhM,877
|
|
15
15
|
zrb/builtin/llm/llm_ask.py,sha256=XtnSZoBvwHqnBUi8R0rt8VDfnBmWgwFlDuuo1WA1W_w,6209
|
16
16
|
zrb/builtin/llm/previous-session.js,sha256=xMKZvJoAbrwiyHS0OoPrWuaKxWYLoyR5sguePIoCjTY,816
|
17
17
|
zrb/builtin/llm/tool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
|
-
zrb/builtin/llm/tool/api.py,sha256=
|
19
|
-
zrb/builtin/llm/tool/cli.py,sha256=
|
18
|
+
zrb/builtin/llm/tool/api.py,sha256=T8NGhBe59sQiu8LfdPOIBmsTNMXWFEKaPPSY9bolsQ8,2401
|
19
|
+
zrb/builtin/llm/tool/cli.py,sha256=GCGB8GMFjvVcH0Ac-bD44VG6Bj3mQSuIcNHAwJbx4Ts,1210
|
20
20
|
zrb/builtin/llm/tool/code.py,sha256=fr9FbmtfwizQTyTztvuvwAb9MD_auRZhPZfoJVBlKT4,8777
|
21
|
-
zrb/builtin/llm/tool/file.py,sha256=
|
22
|
-
zrb/builtin/llm/tool/rag.py,sha256=
|
21
|
+
zrb/builtin/llm/tool/file.py,sha256=FPPvKUZY-w1XEa7EN6D6X4VQoQNQ9ggGWLj-xJt-Ysc,23524
|
22
|
+
zrb/builtin/llm/tool/rag.py,sha256=n4ATdr-2gCzPb7LnaBSD_TuAG4TUXKhE9ElKrSDHvFc,9763
|
23
23
|
zrb/builtin/llm/tool/sub_agent.py,sha256=qJTJ2GSH-2Cma2QyHEJm8l_VuDHMHwhAWGls217YA6A,5078
|
24
|
-
zrb/builtin/llm/tool/web.py,sha256=
|
24
|
+
zrb/builtin/llm/tool/web.py,sha256=2FgmiM2LIQfvMMoswidj9hVMav_t8QPG1LiyedL66dw,7349
|
25
25
|
zrb/builtin/md5.py,sha256=690RV2LbW7wQeTFxY-lmmqTSVEEZv3XZbjEUW1Q3XpE,1480
|
26
26
|
zrb/builtin/project/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
27
27
|
zrb/builtin/project/add/fastapp/fastapp_input.py,sha256=MKlWR_LxWhM_DcULCtLfL_IjTxpDnDBkn9KIqNmajFs,310
|
@@ -257,7 +257,7 @@ zrb/input/int_input.py,sha256=UhxCFYlZdJcgUSGGEkz301zOgRVpK0KDG_IxxWpQfMU,1457
|
|
257
257
|
zrb/input/option_input.py,sha256=TQB82ko5odgzkULEizBZi0e9TIHEbIgvdP0AR3RhA74,2135
|
258
258
|
zrb/input/password_input.py,sha256=szBojWxSP9QJecgsgA87OIYwQrY2AQ3USIKdDZY6snU,1465
|
259
259
|
zrb/input/str_input.py,sha256=NevZHX9rf1g8eMatPyy-kUX3DglrVAQpzvVpKAzf7bA,81
|
260
|
-
zrb/input/text_input.py,sha256=
|
260
|
+
zrb/input/text_input.py,sha256=NRM9FSS2pUFs7_R0KsBlu_CD8WLxbfbwxRkpaRoeCSY,3049
|
261
261
|
zrb/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
262
262
|
zrb/runner/cli.py,sha256=E5GGNJjCOBpFbhgnMM_iE1TVYhxMNDJKA6WxCRETTmA,6951
|
263
263
|
zrb/runner/common_util.py,sha256=yIJm9ivM7hvJ4Kb4Nt5RRE7oqAlt9EN89w6JDGyLkFE,1570
|
@@ -356,9 +356,9 @@ zrb/task/llm/default_workflow/researching.md,sha256=KD-aYHFHir6Ti-4FsBBtGwiI0seS
|
|
356
356
|
zrb/task/llm/error.py,sha256=QR-nIohS6pBpC_16cWR-fw7Mevo1sNYAiXMBsh_CJDE,4157
|
357
357
|
zrb/task/llm/history_summarization.py,sha256=UIT8bpdT3hy1xn559waDLFWZlNtIqdIpIvRGcZEpHm0,8057
|
358
358
|
zrb/task/llm/history_summarization_tool.py,sha256=Wazi4WMr3k1WJ1v7QgjAPbuY1JdBpHUsTWGt3DSTsLc,1706
|
359
|
-
zrb/task/llm/print_node.py,sha256=
|
359
|
+
zrb/task/llm/print_node.py,sha256=TG8i3MrAkIj3cLkU9_fSX-u49jlTdU8t9FpHGI_VtoM,8077
|
360
360
|
zrb/task/llm/prompt.py,sha256=FGXWYHecWtrNNkPnjg-uhnkqp7fYt8V91-AjFM_5fpA,11550
|
361
|
-
zrb/task/llm/tool_wrapper.py,sha256=
|
361
|
+
zrb/task/llm/tool_wrapper.py,sha256=v3y4FO14xStpq9K0lA3GIVv6-3dbq85I7xZqdtG-j9U,10243
|
362
362
|
zrb/task/llm/typing.py,sha256=c8VAuPBw_4A3DxfYdydkgedaP-LU61W9_wj3m3CAX1E,58
|
363
363
|
zrb/task/llm_task.py,sha256=OxJ9QpqjEyeOI1_zqzNZHtQlRHi0ANOvL9FYaWLzO3Y,14913
|
364
364
|
zrb/task/make_task.py,sha256=PD3b_aYazthS8LHeJsLAhwKDEgdurQZpymJDKeN60u0,2265
|
@@ -376,6 +376,7 @@ zrb/util/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
376
376
|
zrb/util/cli/markdown.py,sha256=Uhuw8XR-jAG9AG3oNK8VHJpYOdU40Q_8yVN74uu0RJ8,384
|
377
377
|
zrb/util/cli/style.py,sha256=D_548KG1gXEirQGdkAVTc81vBdCeInXtnG1gV1yabBA,6655
|
378
378
|
zrb/util/cli/subcommand.py,sha256=umTZIlrL-9g-qc_eRRgdaQgK-whvXK1roFfvnbuY7NQ,1753
|
379
|
+
zrb/util/cli/text.py,sha256=6r1NqvtjKXt-XVVURyBqYE9tZA2Bnr6u8h9Lopr-Gag,870
|
379
380
|
zrb/util/cmd/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
380
381
|
zrb/util/cmd/command.py,sha256=WpEMWVL9hBsxptvDHmRR93_cJ2zP05BJ2h9-tP93M1Y,7473
|
381
382
|
zrb/util/cmd/remote.py,sha256=NGQq2_IrUMDoZz3qmcgtnNYVGjMHaBKQpZxImf0yfXA,1296
|
@@ -409,7 +410,7 @@ zrb/util/todo_model.py,sha256=hhzAX-uFl5rsg7iVX1ULlJOfBtblwQ_ieNUxBWfc-Os,1670
|
|
409
410
|
zrb/util/truncate.py,sha256=eSzmjBpc1Qod3lM3M73snNbDOcARHukW_tq36dWdPvc,921
|
410
411
|
zrb/xcom/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
411
412
|
zrb/xcom/xcom.py,sha256=o79rxR9wphnShrcIushA0Qt71d_p3ZTxjNf7x9hJB78,1571
|
412
|
-
zrb-1.15.
|
413
|
-
zrb-1.15.
|
414
|
-
zrb-1.15.
|
415
|
-
zrb-1.15.
|
413
|
+
zrb-1.15.22.dist-info/METADATA,sha256=OHahYYqF0_2Z_Ht40qggq5Z538ITZNVtaC-UagYam6o,9892
|
414
|
+
zrb-1.15.22.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
415
|
+
zrb-1.15.22.dist-info/entry_points.txt,sha256=-Pg3ElWPfnaSM-XvXqCxEAa-wfVI6BEgcs386s8C8v8,46
|
416
|
+
zrb-1.15.22.dist-info/RECORD,,
|
File without changes
|
File without changes
|