zrb 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in its public registry. It is provided for informational purposes only.
- zrb/builtin/llm/llm_chat.py +26 -3
- zrb/builtin/llm/tool/api.py +4 -2
- zrb/builtin/llm/tool/file.py +39 -0
- zrb/builtin/llm/tool/rag.py +37 -22
- zrb/builtin/llm/tool/web.py +46 -20
- zrb/config.py +3 -1
- zrb/content_transformer/content_transformer.py +7 -1
- zrb/context/context.py +8 -2
- zrb/input/text_input.py +9 -5
- zrb/task/llm_task.py +103 -16
- {zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/METADATA +2 -2
- {zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/RECORD +14 -13
- {zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/WHEEL +0 -0
- {zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/entry_points.txt +0 -0
zrb/builtin/llm/llm_chat.py
CHANGED
@@ -5,13 +5,26 @@ from typing import Any
 from zrb.builtin.group import llm_group
 from zrb.builtin.llm.tool.api import get_current_location, get_current_weather
 from zrb.builtin.llm.tool.cli import run_shell_command
-from zrb.builtin.llm.tool.
+from zrb.builtin.llm.tool.file import (
+    list_file,
+    read_source_code,
+    read_text_file,
+    write_text_file,
+)
+from zrb.builtin.llm.tool.web import (
+    create_search_internet_tool,
+    open_web_page,
+    search_arxiv,
+    search_wikipedia,
+)
 from zrb.config import (
     LLM_ALLOW_ACCESS_INTERNET,
+    LLM_ALLOW_ACCESS_LOCAL_FILE,
     LLM_ALLOW_ACCESS_SHELL,
     LLM_HISTORY_DIR,
     LLM_MODEL,
     LLM_SYSTEM_PROMPT,
+    SERP_API_KEY,
 )
 from zrb.context.any_shared_context import AnySharedContext
 from zrb.input.bool_input import BoolInput
@@ -117,11 +130,21 @@ llm_chat: LLMTask = llm_group.add_task(
     alias="chat",
 )
 
+
+if LLM_ALLOW_ACCESS_LOCAL_FILE:
+    llm_chat.add_tool(read_source_code)
+    llm_chat.add_tool(list_file)
+    llm_chat.add_tool(read_text_file)
+    llm_chat.add_tool(write_text_file)
+
 if LLM_ALLOW_ACCESS_SHELL:
     llm_chat.add_tool(run_shell_command)
 
 if LLM_ALLOW_ACCESS_INTERNET:
-    llm_chat.add_tool(
-    llm_chat.add_tool(
+    llm_chat.add_tool(open_web_page)
+    llm_chat.add_tool(search_wikipedia)
+    llm_chat.add_tool(search_arxiv)
+    if SERP_API_KEY != "":
+        llm_chat.add_tool(create_search_internet_tool(SERP_API_KEY))
     llm_chat.add_tool(get_current_location)
     llm_chat.add_tool(get_current_weather)
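For context, `add_tool` accepts plain callables, which is how the registrations above work. A minimal sketch of registering an extra tool from user code; the `get_current_time` helper is hypothetical and not part of zrb:

from datetime import datetime

from zrb.builtin.llm.llm_chat import llm_chat


def get_current_time() -> str:
    """Get the current local time as an ISO-8601 string."""
    return datetime.now().isoformat()


# Same mechanism as the conditional registrations above.
llm_chat.add_tool(get_current_time)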
zrb/builtin/llm/tool/api.py
CHANGED
@@ -1,13 +1,13 @@
 import json
 from typing import Annotated, Literal
 
-import requests
-
 
 def get_current_location() -> (
     Annotated[str, "JSON string representing latitude and longitude"]
 ):  # noqa
     """Get the user's current location."""
+    import requests
+
     return json.dumps(requests.get("http://ip-api.com/json?fields=lat,lon").json())
 
 
@@ -17,6 +17,8 @@ def get_current_weather(
     temperature_unit: Literal["celsius", "fahrenheit"],
 ) -> str:
     """Get the current weather in a given location."""
+    import requests
+
     resp = requests.get(
         "https://api.open-meteo.com/v1/forecast",
         params={
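The change above moves `import requests` from module level into the tool bodies, so importing the module stays cheap. A hedged sketch of the same deferred-import pattern; the `fetch_json` helper is illustrative, not part of zrb:

def fetch_json(url: str) -> dict:
    """Fetch a URL and return its JSON payload."""
    # Imported lazily, as get_current_location and get_current_weather now do.
    import requests

    return requests.get(url, timeout=10).json()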
zrb/builtin/llm/tool/file.py
ADDED
@@ -0,0 +1,39 @@
+import os
+
+from zrb.util.file import read_file, write_file
+
+
+def list_file(
+    directory: str = ".",
+    extensions: list[str] = [".py", ".go", ".js", ".ts", ".java", ".c", ".cpp"],
+) -> list[str]:
+    """List all files in a directory"""
+    all_files: list[str] = []
+    for root, _, files in os.walk(directory):
+        for filename in files:
+            for extension in extensions:
+                if filename.lower().endswith(extension):
+                    all_files.append(os.path.join(root, filename))
+    return all_files
+
+
+def read_text_file(file: str) -> str:
+    """Read a text file"""
+    return read_file(os.path.abspath(file))
+
+
+def write_text_file(file: str, content: str):
+    """Write a text file"""
+    return write_file(os.path.abspath(file), content)
+
+
+def read_source_code(
+    directory: str = ".",
+    extensions: list[str] = [".py", ".go", ".js", ".ts", ".java", ".c", ".cpp"],
+) -> list[str]:
+    """Read source code in a directory"""
+    files = list_file(directory, extensions)
+    for index, file in enumerate(files):
+        content = read_text_file(file)
+        files[index] = f"# {file}\n```\n{content}\n```"
+    return files
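A hedged usage sketch for the new file tools; the directory path and extension filter are placeholders:

from zrb.builtin.llm.tool.file import list_file, read_source_code

# List only Python sources under ./src, then read them as annotated snippets.
python_files = list_file("./src", extensions=[".py"])
annotated = read_source_code("./src", extensions=[".py"])
print(f"{len(python_files)} files found")
if annotated:
    # Each entry is "# <path>" followed by the fenced file content.
    print(annotated[0][:200])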
zrb/builtin/llm/tool/rag.py
CHANGED
@@ -1,7 +1,9 @@
+import fnmatch
 import hashlib
 import json
 import os
 import sys
+from collections.abc import Callable
 
 import ulid
 
@@ -15,6 +17,20 @@ from zrb.util.cli.style import stylize_error, stylize_faint
 from zrb.util.file import read_file
 
 
+class RAGFileReader:
+    def __init__(self, glob_pattern: str, read: Callable[[str], str]):
+        self.glob_pattern = glob_pattern
+        self.read = read
+
+    def is_match(self, file_name: str):
+        if os.sep not in self.glob_pattern and (
+            os.altsep is None or os.altsep not in self.glob_pattern
+        ):
+            # Pattern like "*.txt" – match only the basename.
+            return fnmatch.fnmatch(os.path.basename(file_name), self.glob_pattern)
+        return fnmatch.fnmatch(file_name, self.glob_pattern)
+
+
 def create_rag_from_directory(
     tool_name: str,
     tool_description: str,
@@ -25,6 +41,7 @@ def create_rag_from_directory(
     chunk_size: int = RAG_CHUNK_SIZE,
     overlap: int = RAG_OVERLAP,
     max_result_count: int = RAG_MAX_RESULT_COUNT,
+    file_reader: list[RAGFileReader] = [],
 ):
     async def retrieve(query: str) -> str:
         from chromadb import PersistentClient
@@ -36,35 +53,31 @@ def create_rag_from_directory(
             path=vector_db_path, settings=Settings(allow_reset=True)
         )
         collection = client.get_or_create_collection(vector_db_collection)
-
         # Track file changes using a hash-based approach
         hash_file_path = os.path.join(vector_db_path, "file_hashes.json")
         previous_hashes = _load_hashes(hash_file_path)
         current_hashes = {}
-
+        # Get updated_files
        updated_files = []
-
         for root, _, files in os.walk(document_dir_path):
             for file in files:
                 file_path = os.path.join(root, file)
                 file_hash = _compute_file_hash(file_path)
                 relative_path = os.path.relpath(file_path, document_dir_path)
                 current_hashes[relative_path] = file_hash
-
                 if previous_hashes.get(relative_path) != file_hash:
                     updated_files.append(file_path)
-
+        # Upsert updated_files to vector db
         if updated_files:
             print(
                 stylize_faint(f"Updating {len(updated_files)} changed files"),
                 file=sys.stderr,
             )
-
             for file_path in updated_files:
                 try:
                     relative_path = os.path.relpath(file_path, document_dir_path)
                     collection.delete(where={"file_path": relative_path})
-                    content =
+                    content = _read_txt_content(file_path, file_reader)
                     file_id = ulid.new().str
                     for i in range(0, len(content), chunk_size - overlap):
                         chunk = content[i : i + chunk_size]
@@ -92,14 +105,13 @@ def create_rag_from_directory(
                         stylize_error(f"Error processing {file_path}: {e}"),
                         file=sys.stderr,
                     )
-
             _save_hashes(hash_file_path, current_hashes)
         else:
             print(
                 stylize_faint("No changes detected. Skipping database update."),
                 file=sys.stderr,
             )
-
+        # Vectorize query and get related document chunks
         print(stylize_faint("Vectorizing query"), file=sys.stderr)
         embedding_result = list(embedding_model.embed([query]))
         query_vector = embedding_result[0]
@@ -123,7 +135,22 @@ def _compute_file_hash(file_path: str) -> str:
     return hash_md5.hexdigest()
 
 
-def
+def _load_hashes(file_path: str) -> dict:
+    if os.path.exists(file_path):
+        with open(file_path, "r") as f:
+            return json.load(f)
+    return {}
+
+
+def _save_hashes(file_path: str, hashes: dict):
+    with open(file_path, "w") as f:
+        json.dump(hashes, f)
+
+
+def _read_txt_content(file_path: str, file_reader: list[RAGFileReader]):
+    for reader in file_reader:
+        if reader.is_match(file_path):
+            return reader.read(file_path)
     if file_path.lower().endswith(".pdf"):
         return _read_pdf(file_path)
     return read_file(file_path)
@@ -136,15 +163,3 @@ def _read_pdf(file_path: str) -> str:
     return "\n".join(
         page.extract_text() for page in pdf.pages if page.extract_text()
     )
-
-
-def _load_hashes(file_path: str) -> dict:
-    if os.path.exists(file_path):
-        with open(file_path, "r") as f:
-            return json.load(f)
-    return {}
-
-
-def _save_hashes(file_path: str, hashes: dict):
-    with open(file_path, "w") as f:
-        json.dump(hashes, f)
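A hedged sketch of the new `file_reader` hook: a `RAGFileReader` pairs a glob pattern with a custom reader, and `_read_txt_content` consults a matching reader before falling back to the PDF/plain-text path. The JSON reader below is illustrative; passing it would look like `create_rag_from_directory(..., file_reader=[json_reader])`:

import json

from zrb.builtin.llm.tool.rag import RAGFileReader


def read_json_as_text(file_path: str) -> str:
    # Flatten JSON documents into indented text before chunking.
    with open(file_path, "r") as f:
        return json.dumps(json.load(f), indent=2)


json_reader = RAGFileReader(glob_pattern="*.json", read=read_json_as_text)
print(json_reader.is_match("notes/config.json"))  # True: a bare pattern matches the basename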
zrb/builtin/llm/tool/web.py
CHANGED
@@ -1,8 +1,9 @@
 import json
+from collections.abc import Callable
 from typing import Annotated
 
 
-def
+def open_web_page(url: str) -> str:
     """Get content from a web page."""
     import requests
 
@@ -19,30 +20,55 @@ def open_web_route(url: str) -> str:
     return json.dumps(parse_html_text(response.text))
 
 
-def
+def create_search_internet_tool(serp_api_key: str) -> Callable[[str, int], str]:
+    def search_internet(
+        query: Annotated[str, "Search query"],
+        num_results: Annotated[int, "Search result count, by default 10"] = 10,
+    ) -> str:
+        """Search factual information from the internet by using Google."""
+        import requests
+
+        response = requests.get(
+            "https://serpapi.com/search",
+            headers={
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"  # noqa
+            },
+            params={
+                "q": query,
+                "num": num_results,
+                "hl": "en",
+                "safe": "off",
+                "api_key": serp_api_key,
+            },
+        )
+        if response.status_code != 200:
+            raise Exception(
+                f"Error: Unable to retrieve search results (status code: {response.status_code})"  # noqa
+            )
+        return json.dumps(parse_html_text(response.text))
+
+    return search_internet
+
+
+def search_wikipedia(query: Annotated[str, "Search query"]) -> str:
+    """Search on wikipedia"""
+    import requests
+
+    params = {"action": "query", "list": "search", "srsearch": query, "format": "json"}
+    response = requests.get("https://en.wikipedia.org/w/api.php", params=params)
+    return response.json()
+
+
+def search_arxiv(
     query: Annotated[str, "Search query"],
     num_results: Annotated[int, "Search result count, by default 10"] = 10,
 ) -> str:
-    """Search
+    """Search on Arxiv"""
     import requests
 
-
-
-
-            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"  # noqa
-        },
-        params={
-            "q": query,
-            "num": num_results,
-            "hl": "en",
-            "safe": "off",
-        },
-    )
-    if response.status_code != 200:
-        raise Exception(
-            f"Error: Unable to retrieve search results (status code: {response.status_code})"  # noqa
-        )
-    return json.dumps(parse_html_text(response.text))
+    params = {"search_query": f"all:{query}", "start": 0, "max_results": num_results}
+    response = requests.get("http://export.arxiv.org/api/query", params=params)
+    return response.content
 
 
 def parse_html_text(html_text: str) -> dict[str, str]:
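A hedged usage sketch for the new web tools: `create_search_internet_tool` closes over the SerpAPI key, so the returned callable can be registered as a tool without exposing the key as an argument. The key and queries below are placeholders, and the calls perform real HTTP requests:

from zrb.builtin.llm.tool.web import create_search_internet_tool, search_wikipedia

search_internet = create_search_internet_tool("my-serpapi-key")  # placeholder key
print(search_internet("zrb automation", num_results=5))
print(search_wikipedia("task automation"))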
zrb/config.py
CHANGED
@@ -92,7 +92,8 @@ LLM_HISTORY_DIR = os.getenv(
 LLM_HISTORY_FILE = os.getenv(
     "ZRB_LLM_HISTORY_FILE", os.path.join(LLM_HISTORY_DIR, "history.json")
 )
-
+LLM_ALLOW_ACCESS_LOCAL_FILE = to_boolean(os.getenv("ZRB_LLM_ACCESS_LOCAL_FILE", "1"))
+LLM_ALLOW_ACCESS_SHELL = to_boolean(os.getenv("ZRB_LLM_ACCESS_SHELL", "1"))
 LLM_ALLOW_ACCESS_INTERNET = to_boolean(os.getenv("ZRB_LLM_ACCESS_INTERNET", "1"))
 # noqa See: https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models
 RAG_EMBEDDING_MODEL = os.getenv(
@@ -101,6 +102,7 @@ RAG_EMBEDDING_MODEL = os.getenv(
 RAG_CHUNK_SIZE = int(os.getenv("ZRB_RAG_CHUNK_SIZE", "1024"))
 RAG_OVERLAP = int(os.getenv("ZRB_RAG_OVERLAP", "128"))
 RAG_MAX_RESULT_COUNT = int(os.getenv("ZRB_RAG_MAX_RESULT_COUNT", "5"))
+SERP_API_KEY = os.getenv("SERP_API_KEY", "")
 
 
 BANNER = f"""
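A hedged sketch of how the new settings are picked up: they are read from the environment at import time, so set them before zrb is imported (values here are placeholders):

import os

os.environ["ZRB_LLM_ACCESS_LOCAL_FILE"] = "0"  # hide the file tools from the chat task
os.environ["SERP_API_KEY"] = "my-serpapi-key"  # enables create_search_internet_tool

from zrb.config import LLM_ALLOW_ACCESS_LOCAL_FILE, SERP_API_KEY  # noqa: E402

print(LLM_ALLOW_ACCESS_LOCAL_FILE, SERP_API_KEY)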
zrb/content_transformer/content_transformer.py
CHANGED
@@ -1,4 +1,5 @@
 import fnmatch
+import os
 import re
 from collections.abc import Callable
 
@@ -40,7 +41,12 @@ class ContentTransformer(AnyContentTransformer):
             return True
         except re.error:
             pass
-
+        if os.sep not in pattern and (
+            os.altsep is None or os.altsep not in pattern
+        ):
+            # Pattern like "*.txt" – match only the basename.
+            return fnmatch.fnmatch(file_path, os.path.basename(file_path))
+        return fnmatch.fnmatch(file_path, file_path)
 
     def transform_file(self, ctx: AnyContext, file_path: str):
         if callable(self._transform_file):
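For reference, a hedged, standalone illustration of the separator-aware glob matching this release introduces (the same idea as `RAGFileReader.is_match`; the `match_path` helper is hypothetical, not zrb API):

import fnmatch
import os


def match_path(file_path: str, pattern: str) -> bool:
    # A pattern without a path separator is treated as a basename pattern.
    if os.sep not in pattern and (os.altsep is None or os.altsep not in pattern):
        return fnmatch.fnmatch(os.path.basename(file_path), pattern)
    return fnmatch.fnmatch(file_path, pattern)


print(match_path("src/app/main.py", "*.py"))        # True, basename match
print(match_path("src/app/main.py", "src/*/*.py"))  # True, full-path match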
zrb/context/context.py
CHANGED
@@ -88,7 +88,7 @@ class Context(AnyContext):
             return template
         return int(self.render(template))
 
-    def render_float(self, template: str) -> float:
+    def render_float(self, template: str | float) -> float:
         if isinstance(template, float):
             return template
         return float(self.render(template))
@@ -102,9 +102,10 @@ class Context(AnyContext):
         flush: bool = True,
         plain: bool = False,
     ):
+        sep = " " if sep is None else sep
         message = sep.join([f"{value}" for value in values])
         if plain:
-            self.append_to_shared_log(remove_style(message))
+            # self.append_to_shared_log(remove_style(message))
             print(message, sep=sep, end=end, file=file, flush=flush)
             return
         color = self._color
@@ -132,6 +133,7 @@ class Context(AnyContext):
         flush: bool = True,
     ):
         if self._shared_ctx.get_logging_level() <= logging.DEBUG:
+            sep = " " if sep is None else sep
             message = sep.join([f"{value}" for value in values])
             stylized_message = stylize_log(f"[DEBUG] {message}")
             self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -145,6 +147,7 @@ class Context(AnyContext):
         flush: bool = True,
     ):
         if self._shared_ctx.get_logging_level() <= logging.INFO:
+            sep = " " if sep is None else sep
             message = sep.join([f"{value}" for value in values])
             stylized_message = stylize_log(f"[INFO] {message}")
             self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -158,6 +161,7 @@ class Context(AnyContext):
         flush: bool = True,
     ):
         if self._shared_ctx.get_logging_level() <= logging.INFO:
+            sep = " " if sep is None else sep
             message = sep.join([f"{value}" for value in values])
             stylized_message = stylize_warning(f"[WARNING] {message}")
             self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -171,6 +175,7 @@ class Context(AnyContext):
         flush: bool = True,
     ):
         if self._shared_ctx.get_logging_level() <= logging.ERROR:
+            sep = " " if sep is None else sep
             message = sep.join([f"{value}" for value in values])
             stylized_message = stylize_error(f"[ERROR] {message}")
             self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -184,6 +189,7 @@ class Context(AnyContext):
         flush: bool = True,
     ):
         if self._shared_ctx.get_logging_level() <= logging.CRITICAL:
+            sep = " " if sep is None else sep
             message = sep.join([f"{value}" for value in values])
             stylized_message = stylize_error(f"[CRITICAL] {message}")
             self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
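The repeated `sep = " " if sep is None else sep` guard lets the print/log helpers accept `sep=None` (the usual `print()` convention) instead of failing on `None.join`. A hedged, standalone illustration of the same defaulting:

def join_values(*values, sep: str | None = None) -> str:
    # Mirror of the guard added above: None behaves like a single space.
    sep = " " if sep is None else sep
    return sep.join(f"{value}" for value in values)


print(join_values("build", "ok"))             # build ok
print(join_values("build", "ok", sep=" | "))  # build | ok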
zrb/input/text_input.py
CHANGED
@@ -69,16 +69,17 @@ class TextInput(BaseInput):
         )
 
     def _prompt_cli_str(self, shared_ctx: AnySharedContext) -> str:
-        prompt_message = (
-
+        prompt_message = super().prompt_message
+        comment_prompt_message = (
+            f"{self.comment_start}{prompt_message}{self.comment_end}"
         )
-
+        comment_prompt_message_eol = f"{comment_prompt_message}\n"
         default_value = self.get_default_str(shared_ctx)
         with tempfile.NamedTemporaryFile(
             delete=False, suffix=self._extension
         ) as temp_file:
             temp_file_name = temp_file.name
-            temp_file.write(
+            temp_file.write(comment_prompt_message_eol.encode())
             # Pre-fill with default content
             if default_value:
                 temp_file.write(default_value.encode())
@@ -87,7 +88,10 @@ class TextInput(BaseInput):
             subprocess.call([self._editor, temp_file_name])
             # Read the edited content
             edited_content = read_file(temp_file_name)
-        parts = [
+        parts = [
+            text.strip() for text in edited_content.split(comment_prompt_message, 1)
+        ]
         edited_content = "\n".join(parts).lstrip()
         os.remove(temp_file_name)
+        print(f"{prompt_message}: {edited_content}")
         return edited_content
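A hedged, standalone illustration of the new prompt handling: the prompt is written into the temp file as a comment header, and after editing it is stripped back out by splitting on that header once. The comment markers here are placeholders (TextInput takes them from `comment_start`/`comment_end`):

comment_start, comment_end = "# ", ""  # placeholder markers
prompt_message = "Enter commit message"
comment_prompt_message = f"{comment_start}{prompt_message}{comment_end}"

edited_content = f"{comment_prompt_message}\nfix: handle empty input\n"
parts = [text.strip() for text in edited_content.split(comment_prompt_message, 1)]
print("\n".join(parts).lstrip())  # fix: handle empty input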
zrb/task/llm_task.py
CHANGED
@@ -4,10 +4,20 @@ from collections.abc import Callable
 from typing import Any
 
 from pydantic_ai import Agent, Tool
-from pydantic_ai.messages import
+from pydantic_ai.messages import (
+    FinalResultEvent,
+    FunctionToolCallEvent,
+    FunctionToolResultEvent,
+    ModelMessagesTypeAdapter,
+    PartDeltaEvent,
+    PartStartEvent,
+    TextPartDelta,
+    ToolCallPartDelta,
+)
+from pydantic_ai.models import Model
 from pydantic_ai.settings import ModelSettings
 
-from zrb.attr.type import StrAttr
+from zrb.attr.type import StrAttr, fstring
 from zrb.config import LLM_MODEL, LLM_SYSTEM_PROMPT
 from zrb.context.any_context import AnyContext
 from zrb.context.any_shared_context import AnySharedContext
@@ -15,7 +25,7 @@ from zrb.env.any_env import AnyEnv
 from zrb.input.any_input import AnyInput
 from zrb.task.any_task import AnyTask
 from zrb.task.base_task import BaseTask
-from zrb.util.attr import get_str_attr
+from zrb.util.attr import get_attr, get_str_attr
 from zrb.util.cli.style import stylize_faint
 from zrb.util.file import read_file, write_file
 from zrb.util.run import run_async
@@ -34,7 +44,9 @@ class LLMTask(BaseTask):
         cli_only: bool = False,
         input: list[AnyInput | None] | AnyInput | None = None,
         env: list[AnyEnv | None] | AnyEnv | None = None,
-        model:
+        model: (
+            Callable[[AnySharedContext], Model | str | fstring] | Model | None
+        ) = LLM_MODEL,
         model_settings: (
             ModelSettings | Callable[[AnySharedContext], ModelSettings] | None
         ) = None,
@@ -93,7 +105,7 @@ class LLMTask(BaseTask):
             successor=successor,
         )
         self._model = model
-        self._model_settings =
+        self._model_settings = model_settings
         self._agent = agent
         self._render_model = render_model
         self._system_prompt = system_prompt
@@ -108,6 +120,9 @@ class LLMTask(BaseTask):
         self._render_history_file = render_history_file
         self._max_call_iteration = max_call_iteration
 
+    def set_model(self, model: Model | str):
+        self._model = model
+
     def add_tool(self, tool: ToolOrCallable):
         self._additional_tools.append(tool)
 
@@ -115,15 +130,85 @@
         history = await self._read_conversation_history(ctx)
         user_prompt = self._get_message(ctx)
         agent = self._get_agent(ctx)
-
+        async with agent.iter(
             user_prompt=user_prompt,
             message_history=ModelMessagesTypeAdapter.validate_python(history),
-        )
-
-
-
+        ) as agent_run:
+            async for node in agent_run:
+                # Each node represents a step in the agent's execution
+                await self._print_node(ctx, agent_run, node)
+        new_history = json.loads(agent_run.result.all_messages_json())
         await self._write_conversation_history(ctx, new_history)
-        return result.data
+        return agent_run.result.data
+
+    async def _print_node(self, ctx: AnyContext, agent_run: Any, node: Any):
+        if Agent.is_user_prompt_node(node):
+            # A user prompt node => The user has provided input
+            ctx.print(stylize_faint(f">> UserPromptNode: {node.user_prompt}"))
+        elif Agent.is_model_request_node(node):
+            # A model request node => We can stream tokens from the model"s request
+            ctx.print(
+                stylize_faint(">> ModelRequestNode: streaming partial request tokens")
+            )
+            async with node.stream(agent_run.ctx) as request_stream:
+                is_streaming = False
+                async for event in request_stream:
+                    if isinstance(event, PartStartEvent):
+                        if is_streaming:
+                            ctx.print("", plain=True)
+                        ctx.print(
+                            stylize_faint(
+                                f"[Request] Starting part {event.index}: {event.part!r}"
+                            ),
+                        )
+                        is_streaming = False
+                    elif isinstance(event, PartDeltaEvent):
+                        if isinstance(event.delta, TextPartDelta):
+                            ctx.print(
+                                stylize_faint(f"{event.delta.content_delta}"),
+                                end="",
+                                plain=is_streaming,
+                            )
+                        elif isinstance(event.delta, ToolCallPartDelta):
+                            ctx.print(
+                                stylize_faint(f"{event.delta.args_delta}"),
+                                end="",
+                                plain=is_streaming,
+                            )
+                        is_streaming = True
+                    elif isinstance(event, FinalResultEvent):
+                        if is_streaming:
+                            ctx.print("", plain=True)
+                        ctx.print(
+                            stylize_faint(f"[Result] tool_name={event.tool_name}"),
+                        )
+                        is_streaming = False
+                if is_streaming:
+                    ctx.print("", plain=True)
+        elif Agent.is_call_tools_node(node):
+            # A handle-response node => The model returned some data, potentially calls a tool
+            ctx.print(
+                stylize_faint(
+                    ">> CallToolsNode: streaming partial response & tool usage"
+                )
+            )
+            async with node.stream(agent_run.ctx) as handle_stream:
+                async for event in handle_stream:
+                    if isinstance(event, FunctionToolCallEvent):
+                        ctx.print(
+                            stylize_faint(
                                f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})"  # noqa
+                            )
+                        )
+                    elif isinstance(event, FunctionToolResultEvent):
+                        ctx.print(
+                            stylize_faint(
+                                f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}"  # noqa
+                            )
+                        )
+        elif Agent.is_end_node(node):
+            # Once an End node is reached, the agent run is complete
+            ctx.print(stylize_faint(f"{agent_run.result.data}"))
 
     async def _write_conversation_history(
         self, ctx: AnyContext, conversations: list[Any]
@@ -135,11 +220,9 @@ class LLMTask(BaseTask):
         write_file(history_file, json.dumps(conversations, indent=2))
 
     def _get_model_settings(self, ctx: AnyContext) -> ModelSettings | None:
-        if isinstance(self._model_settings, ModelSettings):
-            return self._model_settings
         if callable(self._model_settings):
            return self._model_settings(ctx)
-        return
+        return self._model_settings
 
     def _get_agent(self, ctx: AnyContext) -> Agent:
         if isinstance(self._agent, Agent):
@@ -158,12 +241,16 @@ class LLMTask(BaseTask):
             self._get_model(ctx),
             system_prompt=self._get_system_prompt(ctx),
             tools=tools,
+            model_settings=self._get_model_settings(ctx),
         )
 
-    def _get_model(self, ctx: AnyContext) -> str:
-
+    def _get_model(self, ctx: AnyContext) -> str | Model | None:
+        model = get_attr(
             ctx, self._model, "ollama_chat/llama3.1", auto_render=self._render_model
         )
+        if isinstance(model, (Model, str)) or model is None:
+            return model
+        raise ValueError("Invalid model")
 
     def _get_system_prompt(self, ctx: AnyContext) -> str:
         return get_str_attr(
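A hedged sketch of the new `set_model` hook, which accepts either a model name string or a `pydantic_ai` `Model` instance; the identifier below is a placeholder and depends on which providers are configured:

from zrb.builtin.llm.llm_chat import llm_chat

# Override the configured model at runtime instead of relying on ZRB_LLM_MODEL.
llm_chat.set_model("openai:gpt-4o-mini")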
{zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: zrb
-Version: 1.2.1
+Version: 1.2.2
 Summary: Your Automation Powerhouse
 Home-page: https://github.com/state-alchemists/zrb
 License: AGPL-3.0-or-later
@@ -24,7 +24,7 @@ Requires-Dist: isort (>=5.13.2,<5.14.0)
 Requires-Dist: libcst (>=1.5.0,<2.0.0)
 Requires-Dist: pdfplumber (>=0.11.4,<0.12.0) ; extra == "rag"
 Requires-Dist: psutil (>=6.1.1,<7.0.0)
-Requires-Dist: pydantic-ai (>=0.0.
+Requires-Dist: pydantic-ai (>=0.0.31,<0.0.32)
 Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
 Requires-Dist: python-jose[cryptography] (>=3.4.0,<4.0.0)
 Requires-Dist: requests (>=2.32.3,<3.0.0)
{zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/RECORD
CHANGED
@@ -7,12 +7,13 @@ zrb/builtin/base64.py,sha256=1YnSwASp7OEAvQcsnHZGpJEvYoI1Z2zTIJ1bCDHfcPQ,921
 zrb/builtin/git.py,sha256=8_qVE_2lVQEVXQ9vhiw8Tn4Prj1VZB78ZjEJJS5Ab3M,5461
 zrb/builtin/git_subtree.py,sha256=7BKwOkVTWDrR0DXXQ4iJyHqeR6sV5VYRt8y_rEB0EHg,3505
 zrb/builtin/group.py,sha256=-phJfVpTX3_gUwS1u8-RbZUHe-X41kxDBSmrVh4rq8E,1682
-zrb/builtin/llm/llm_chat.py,sha256=
+zrb/builtin/llm/llm_chat.py,sha256=cNAS_AS-Q5NW6qe8dJZh12b6c0zFKdNEvGfwJxNAUmw,5047
 zrb/builtin/llm/previous-session.js,sha256=xMKZvJoAbrwiyHS0OoPrWuaKxWYLoyR5sguePIoCjTY,816
-zrb/builtin/llm/tool/api.py,sha256=
+zrb/builtin/llm/tool/api.py,sha256=bXFE7jihdhUscxJH8lu5imwlYH735AalbCyUTl28BaQ,826
 zrb/builtin/llm/tool/cli.py,sha256=to_IjkfrMGs6eLfG0cpVN9oyADWYsJQCtyluUhUdBww,253
-zrb/builtin/llm/tool/
-zrb/builtin/llm/tool/
+zrb/builtin/llm/tool/file.py,sha256=ibvh0zrsnponwyZvw6bWMUbpwSv5S5WUWCDfQ6BjVwk,1160
+zrb/builtin/llm/tool/rag.py,sha256=vEIThEy0JGwXEiNRLOEJAHAE0l1Qie2qvU3ryioeYMk,6066
+zrb/builtin/llm/tool/web.py,sha256=SDnCtYHZ0Q4DtLbIhc11a0UyyKbTTeW60UfeIKzK35k,3204
 zrb/builtin/md5.py,sha256=0pNlrfZA0wlZlHvFHLgyqN0JZJWGKQIF5oXxO44_OJk,949
 zrb/builtin/project/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/builtin/project/add/fastapp/fastapp_input.py,sha256=MKlWR_LxWhM_DcULCtLfL_IjTxpDnDBkn9KIqNmajFs,310
@@ -207,14 +208,14 @@ zrb/callback/callback.py,sha256=hKefB_Jd1XGjPSLQdMKDsGLHPzEGO2dqrIArLl_EmD0,848
 zrb/cmd/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/cmd/cmd_result.py,sha256=L8bQJzWCpcYexIxHBNsXj2pT3BtLmWex0iJSMkvimOA,597
 zrb/cmd/cmd_val.py,sha256=7Doowyg6BK3ISSGBLt-PmlhzaEkBjWWm51cED6fAUOQ,1014
-zrb/config.py,sha256=
+zrb/config.py,sha256=MfHwcQ4OhCmCw6jXpFI8483Ase6YrqNGBvqYzwnwopw,4753
 zrb/content_transformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/content_transformer/any_content_transformer.py,sha256=v8ZUbcix1GGeDQwB6OKX_1TjpY__ksxWVeqibwa_iZA,850
-zrb/content_transformer/content_transformer.py,sha256=
+zrb/content_transformer/content_transformer.py,sha256=STl77wW-I69QaGzCXjvkppngYFLufow8ybPLSyAvlHs,2404
 zrb/context/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/context/any_context.py,sha256=2hgVKbbDwmwrEl1h1L1FaTUjuUYaDd_b7YRGkaorW6Q,6362
 zrb/context/any_shared_context.py,sha256=p1i9af_CUDz5Mf1h1kBZMAa2AEhf17I3O5IgAcjRLoY,1768
-zrb/context/context.py,sha256=
+zrb/context/context.py,sha256=VGoUwoWyL9d4QqJEhg41S-X8T2jlssGpiC9YSc3Gjqk,6601
 zrb/context/shared_context.py,sha256=47Tnnor1ttpwpe_N07rMNM1jgIYPY9abMe1Q5igkMtE,2735
 zrb/dot_dict/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/dot_dict/dot_dict.py,sha256=ubw_x8I7AOJ59xxtFVJ00VGmq_IYdZP3mUhNlO4nEK0,556
@@ -235,7 +236,7 @@ zrb/input/int_input.py,sha256=w5ewSxstUYv5LBAzvX_E0jIueXXdmnY2ehoQMTtg-EA,1380
 zrb/input/option_input.py,sha256=IrpF0XvFbH5G-IEAnoQ4QOvq7gn2wyT4jKwAdMKwV0s,2058
 zrb/input/password_input.py,sha256=Tu8TZx95717YsHICZ0zBzTUPKPf-K9vGlvRyaOTrFEM,1388
 zrb/input/str_input.py,sha256=NevZHX9rf1g8eMatPyy-kUX3DglrVAQpzvVpKAzf7bA,81
-zrb/input/text_input.py,sha256=
+zrb/input/text_input.py,sha256=gK5LGa9uEBBLwGqhTVmCVgsAuBYEQt3ANDQaSNtSc78,3300
 zrb/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/runner/cli.py,sha256=G_ILZCFzpV-kRE3dm1kq6BorB51TLJ34Qmhgy5SIMlU,6734
 zrb/runner/common_util.py,sha256=mjEBSmfdY2Sun2U5-8y8gGwF82OiRM8sgiYDOdW9NA4,1338
@@ -298,7 +299,7 @@ zrb/task/base_task.py,sha256=SQRf37bylS586KwyW0eYDe9JZ5Hl18FP8kScHae6y3A,21251
 zrb/task/base_trigger.py,sha256=jC722rDvodaBLeNaFghkTyv1u0QXrK6BLZUUqcmBJ7Q,4581
 zrb/task/cmd_task.py,sha256=pUKRSR4DZKjbmluB6vi7cxqyhxOLfJ2czSpYeQbiDvo,10705
 zrb/task/http_check.py,sha256=Gf5rOB2Se2EdizuN9rp65HpGmfZkGc-clIAlHmPVehs,2565
-zrb/task/llm_task.py,sha256=
+zrb/task/llm_task.py,sha256=B4qhza-4fk7odI7-rv2rLYvBLt1dmZMNgKu8OK7rajM,11849
 zrb/task/make_task.py,sha256=PD3b_aYazthS8LHeJsLAhwKDEgdurQZpymJDKeN60u0,2265
 zrb/task/rsync_task.py,sha256=GSL9144bmp6F0EckT6m-2a1xG25AzrrWYzH4k3SVUKM,6370
 zrb/task/scaffolder.py,sha256=rME18w1HJUHXgi9eTYXx_T2G4JdqDYzBoNOkdOOo5-o,6806
@@ -339,7 +340,7 @@ zrb/util/string/name.py,sha256=8picJfUBXNpdh64GNaHv3om23QHhUZux7DguFLrXHp8,1163
 zrb/util/todo.py,sha256=1nDdwPc22oFoK_1ZTXyf3638Bg6sqE2yp_U4_-frHoc,16015
 zrb/xcom/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 zrb/xcom/xcom.py,sha256=o79rxR9wphnShrcIushA0Qt71d_p3ZTxjNf7x9hJB78,1571
-zrb-1.2.
-zrb-1.2.
-zrb-1.2.
-zrb-1.2.
+zrb-1.2.2.dist-info/METADATA,sha256=GAm6vQds-lw7StRPEkrq3oTUSBO1jIYpXi9s1-KSfjw,4198
+zrb-1.2.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+zrb-1.2.2.dist-info/entry_points.txt,sha256=-Pg3ElWPfnaSM-XvXqCxEAa-wfVI6BEgcs386s8C8v8,46
+zrb-1.2.2.dist-info/RECORD,,
{zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/WHEEL
File without changes
{zrb-1.2.1.dist-info → zrb-1.2.2.dist-info}/entry_points.txt
File without changes