zrb 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,13 +5,26 @@ from typing import Any
5
5
  from zrb.builtin.group import llm_group
6
6
  from zrb.builtin.llm.tool.api import get_current_location, get_current_weather
7
7
  from zrb.builtin.llm.tool.cli import run_shell_command
8
- from zrb.builtin.llm.tool.web import open_web_route, query_internet
8
+ from zrb.builtin.llm.tool.file import (
9
+ list_file,
10
+ read_source_code,
11
+ read_text_file,
12
+ write_text_file,
13
+ )
14
+ from zrb.builtin.llm.tool.web import (
15
+ create_search_internet_tool,
16
+ open_web_page,
17
+ search_arxiv,
18
+ search_wikipedia,
19
+ )
9
20
  from zrb.config import (
10
21
  LLM_ALLOW_ACCESS_INTERNET,
22
+ LLM_ALLOW_ACCESS_LOCAL_FILE,
11
23
  LLM_ALLOW_ACCESS_SHELL,
12
24
  LLM_HISTORY_DIR,
13
25
  LLM_MODEL,
14
26
  LLM_SYSTEM_PROMPT,
27
+ SERP_API_KEY,
15
28
  )
16
29
  from zrb.context.any_shared_context import AnySharedContext
17
30
  from zrb.input.bool_input import BoolInput
@@ -117,11 +130,21 @@ llm_chat: LLMTask = llm_group.add_task(
117
130
  alias="chat",
118
131
  )
119
132
 
133
+
134
+ if LLM_ALLOW_ACCESS_LOCAL_FILE:
135
+ llm_chat.add_tool(read_source_code)
136
+ llm_chat.add_tool(list_file)
137
+ llm_chat.add_tool(read_text_file)
138
+ llm_chat.add_tool(write_text_file)
139
+
120
140
  if LLM_ALLOW_ACCESS_SHELL:
121
141
  llm_chat.add_tool(run_shell_command)
122
142
 
123
143
  if LLM_ALLOW_ACCESS_INTERNET:
124
- llm_chat.add_tool(open_web_route)
125
- llm_chat.add_tool(query_internet)
144
+ llm_chat.add_tool(open_web_page)
145
+ llm_chat.add_tool(search_wikipedia)
146
+ llm_chat.add_tool(search_arxiv)
147
+ if SERP_API_KEY != "":
148
+ llm_chat.add_tool(create_search_internet_tool(SERP_API_KEY))
126
149
  llm_chat.add_tool(get_current_location)
127
150
  llm_chat.add_tool(get_current_weather)
@@ -1,13 +1,13 @@
1
1
  import json
2
2
  from typing import Annotated, Literal
3
3
 
4
- import requests
5
-
6
4
 
7
5
  def get_current_location() -> (
8
6
  Annotated[str, "JSON string representing latitude and longitude"]
9
7
  ): # noqa
10
8
  """Get the user's current location."""
9
+ import requests
10
+
11
11
  return json.dumps(requests.get("http://ip-api.com/json?fields=lat,lon").json())
12
12
 
13
13
 
@@ -17,6 +17,8 @@ def get_current_weather(
17
17
  temperature_unit: Literal["celsius", "fahrenheit"],
18
18
  ) -> str:
19
19
  """Get the current weather in a given location."""
20
+ import requests
21
+
20
22
  resp = requests.get(
21
23
  "https://api.open-meteo.com/v1/forecast",
22
24
  params={
@@ -0,0 +1,39 @@
1
+ import os
2
+
3
+ from zrb.util.file import read_file, write_file
4
+
5
+
6
+ def list_file(
7
+ directory: str = ".",
8
+ extensions: list[str] = [".py", ".go", ".js", ".ts", ".java", ".c", ".cpp"],
9
+ ) -> list[str]:
10
+ """List all files in a directory"""
11
+ all_files: list[str] = []
12
+ for root, _, files in os.walk(directory):
13
+ for filename in files:
14
+ for extension in extensions:
15
+ if filename.lower().endswith(extension):
16
+ all_files.append(os.path.join(root, filename))
17
+ return all_files
18
+
19
+
20
+ def read_text_file(file: str) -> str:
21
+ """Read a text file"""
22
+ return read_file(os.path.abspath(file))
23
+
24
+
25
+ def write_text_file(file: str, content: str):
26
+ """Write a text file"""
27
+ return write_file(os.path.abspath(file), content)
28
+
29
+
30
+ def read_source_code(
31
+ directory: str = ".",
32
+ extensions: list[str] = [".py", ".go", ".js", ".ts", ".java", ".c", ".cpp"],
33
+ ) -> list[str]:
34
+ """Read source code in a directory"""
35
+ files = list_file(directory, extensions)
36
+ for index, file in enumerate(files):
37
+ content = read_text_file(file)
38
+ files[index] = f"# {file}\n```\n{content}\n```"
39
+ return files
@@ -1,7 +1,9 @@
1
+ import fnmatch
1
2
  import hashlib
2
3
  import json
3
4
  import os
4
5
  import sys
6
+ from collections.abc import Callable
5
7
 
6
8
  import ulid
7
9
 
@@ -15,6 +17,20 @@ from zrb.util.cli.style import stylize_error, stylize_faint
15
17
  from zrb.util.file import read_file
16
18
 
17
19
 
20
+ class RAGFileReader:
21
+ def __init__(self, glob_pattern: str, read: Callable[[str], str]):
22
+ self.glob_pattern = glob_pattern
23
+ self.read = read
24
+
25
+ def is_match(self, file_name: str):
26
+ if os.sep not in self.glob_pattern and (
27
+ os.altsep is None or os.altsep not in self.glob_pattern
28
+ ):
29
+ # Pattern like "*.txt" – match only the basename.
30
+ return fnmatch.fnmatch(os.path.basename(file_name), self.glob_pattern)
31
+ return fnmatch.fnmatch(file_name, self.glob_pattern)
32
+
33
+
18
34
  def create_rag_from_directory(
19
35
  tool_name: str,
20
36
  tool_description: str,
@@ -25,6 +41,7 @@ def create_rag_from_directory(
25
41
  chunk_size: int = RAG_CHUNK_SIZE,
26
42
  overlap: int = RAG_OVERLAP,
27
43
  max_result_count: int = RAG_MAX_RESULT_COUNT,
44
+ file_reader: list[RAGFileReader] = [],
28
45
  ):
29
46
  async def retrieve(query: str) -> str:
30
47
  from chromadb import PersistentClient
@@ -36,35 +53,31 @@ def create_rag_from_directory(
36
53
  path=vector_db_path, settings=Settings(allow_reset=True)
37
54
  )
38
55
  collection = client.get_or_create_collection(vector_db_collection)
39
-
40
56
  # Track file changes using a hash-based approach
41
57
  hash_file_path = os.path.join(vector_db_path, "file_hashes.json")
42
58
  previous_hashes = _load_hashes(hash_file_path)
43
59
  current_hashes = {}
44
-
60
+ # Get updated_files
45
61
  updated_files = []
46
-
47
62
  for root, _, files in os.walk(document_dir_path):
48
63
  for file in files:
49
64
  file_path = os.path.join(root, file)
50
65
  file_hash = _compute_file_hash(file_path)
51
66
  relative_path = os.path.relpath(file_path, document_dir_path)
52
67
  current_hashes[relative_path] = file_hash
53
-
54
68
  if previous_hashes.get(relative_path) != file_hash:
55
69
  updated_files.append(file_path)
56
-
70
+ # Upsert updated_files to vector db
57
71
  if updated_files:
58
72
  print(
59
73
  stylize_faint(f"Updating {len(updated_files)} changed files"),
60
74
  file=sys.stderr,
61
75
  )
62
-
63
76
  for file_path in updated_files:
64
77
  try:
65
78
  relative_path = os.path.relpath(file_path, document_dir_path)
66
79
  collection.delete(where={"file_path": relative_path})
67
- content = _read_file_content(file_path)
80
+ content = _read_txt_content(file_path, file_reader)
68
81
  file_id = ulid.new().str
69
82
  for i in range(0, len(content), chunk_size - overlap):
70
83
  chunk = content[i : i + chunk_size]
@@ -92,14 +105,13 @@ def create_rag_from_directory(
92
105
  stylize_error(f"Error processing {file_path}: {e}"),
93
106
  file=sys.stderr,
94
107
  )
95
-
96
108
  _save_hashes(hash_file_path, current_hashes)
97
109
  else:
98
110
  print(
99
111
  stylize_faint("No changes detected. Skipping database update."),
100
112
  file=sys.stderr,
101
113
  )
102
-
114
+ # Vectorize query and get related document chunks
103
115
  print(stylize_faint("Vectorizing query"), file=sys.stderr)
104
116
  embedding_result = list(embedding_model.embed([query]))
105
117
  query_vector = embedding_result[0]
@@ -123,7 +135,22 @@ def _compute_file_hash(file_path: str) -> str:
123
135
  return hash_md5.hexdigest()
124
136
 
125
137
 
126
- def _read_file_content(file_path: str) -> str:
138
+ def _load_hashes(file_path: str) -> dict:
139
+ if os.path.exists(file_path):
140
+ with open(file_path, "r") as f:
141
+ return json.load(f)
142
+ return {}
143
+
144
+
145
+ def _save_hashes(file_path: str, hashes: dict):
146
+ with open(file_path, "w") as f:
147
+ json.dump(hashes, f)
148
+
149
+
150
+ def _read_txt_content(file_path: str, file_reader: list[RAGFileReader]):
151
+ for reader in file_reader:
152
+ if reader.is_match(file_path):
153
+ return reader.read(file_path)
127
154
  if file_path.lower().endswith(".pdf"):
128
155
  return _read_pdf(file_path)
129
156
  return read_file(file_path)
@@ -136,15 +163,3 @@ def _read_pdf(file_path: str) -> str:
136
163
  return "\n".join(
137
164
  page.extract_text() for page in pdf.pages if page.extract_text()
138
165
  )
139
-
140
-
141
- def _load_hashes(file_path: str) -> dict:
142
- if os.path.exists(file_path):
143
- with open(file_path, "r") as f:
144
- return json.load(f)
145
- return {}
146
-
147
-
148
- def _save_hashes(file_path: str, hashes: dict):
149
- with open(file_path, "w") as f:
150
- json.dump(hashes, f)
@@ -1,8 +1,9 @@
1
1
  import json
2
+ from collections.abc import Callable
2
3
  from typing import Annotated
3
4
 
4
5
 
5
- def open_web_route(url: str) -> str:
6
+ def open_web_page(url: str) -> str:
6
7
  """Get content from a web page."""
7
8
  import requests
8
9
 
@@ -19,30 +20,55 @@ def open_web_route(url: str) -> str:
19
20
  return json.dumps(parse_html_text(response.text))
20
21
 
21
22
 
22
- def query_internet(
23
+ def create_search_internet_tool(serp_api_key: str) -> Callable[[str, int], str]:
24
+ def search_internet(
25
+ query: Annotated[str, "Search query"],
26
+ num_results: Annotated[int, "Search result count, by default 10"] = 10,
27
+ ) -> str:
28
+ """Search factual information from the internet by using Google."""
29
+ import requests
30
+
31
+ response = requests.get(
32
+ "https://serpapi.com/search",
33
+ headers={
34
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" # noqa
35
+ },
36
+ params={
37
+ "q": query,
38
+ "num": num_results,
39
+ "hl": "en",
40
+ "safe": "off",
41
+ "api_key": serp_api_key,
42
+ },
43
+ )
44
+ if response.status_code != 200:
45
+ raise Exception(
46
+ f"Error: Unable to retrieve search results (status code: {response.status_code})" # noqa
47
+ )
48
+ return json.dumps(parse_html_text(response.text))
49
+
50
+ return search_internet
51
+
52
+
53
+ def search_wikipedia(query: Annotated[str, "Search query"]) -> str:
54
+ """Search on wikipedia"""
55
+ import requests
56
+
57
+ params = {"action": "query", "list": "search", "srsearch": query, "format": "json"}
58
+ response = requests.get("https://en.wikipedia.org/w/api.php", params=params)
59
+ return response.json()
60
+
61
+
62
+ def search_arxiv(
23
63
  query: Annotated[str, "Search query"],
24
64
  num_results: Annotated[int, "Search result count, by default 10"] = 10,
25
65
  ) -> str:
26
- """Search factual information from the internet by using Google."""
66
+ """Search on Arxiv"""
27
67
  import requests
28
68
 
29
- response = requests.get(
30
- "https://google.com/search",
31
- headers={
32
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" # noqa
33
- },
34
- params={
35
- "q": query,
36
- "num": num_results,
37
- "hl": "en",
38
- "safe": "off",
39
- },
40
- )
41
- if response.status_code != 200:
42
- raise Exception(
43
- f"Error: Unable to retrieve search results (status code: {response.status_code})" # noqa
44
- )
45
- return json.dumps(parse_html_text(response.text))
69
+ params = {"search_query": f"all:{query}", "start": 0, "max_results": num_results}
70
+ response = requests.get("http://export.arxiv.org/api/query", params=params)
71
+ return response.content
46
72
 
47
73
 
48
74
  def parse_html_text(html_text: str) -> dict[str, str]:
zrb/config.py CHANGED
@@ -92,7 +92,8 @@ LLM_HISTORY_DIR = os.getenv(
92
92
  LLM_HISTORY_FILE = os.getenv(
93
93
  "ZRB_LLM_HISTORY_FILE", os.path.join(LLM_HISTORY_DIR, "history.json")
94
94
  )
95
- LLM_ALLOW_ACCESS_SHELL = to_boolean(os.getenv("ZRB_LLM_ACCESS_FILE", "1"))
95
+ LLM_ALLOW_ACCESS_LOCAL_FILE = to_boolean(os.getenv("ZRB_LLM_ACCESS_LOCAL_FILE", "1"))
96
+ LLM_ALLOW_ACCESS_SHELL = to_boolean(os.getenv("ZRB_LLM_ACCESS_SHELL", "1"))
96
97
  LLM_ALLOW_ACCESS_INTERNET = to_boolean(os.getenv("ZRB_LLM_ACCESS_INTERNET", "1"))
97
98
  # noqa See: https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models
98
99
  RAG_EMBEDDING_MODEL = os.getenv(
@@ -101,6 +102,7 @@ RAG_EMBEDDING_MODEL = os.getenv(
101
102
  RAG_CHUNK_SIZE = int(os.getenv("ZRB_RAG_CHUNK_SIZE", "1024"))
102
103
  RAG_OVERLAP = int(os.getenv("ZRB_RAG_OVERLAP", "128"))
103
104
  RAG_MAX_RESULT_COUNT = int(os.getenv("ZRB_RAG_MAX_RESULT_COUNT", "5"))
105
+ SERP_API_KEY = os.getenv("SERP_API_KEY", "")
104
106
 
105
107
 
106
108
  BANNER = f"""
@@ -1,4 +1,5 @@
1
1
  import fnmatch
2
+ import os
2
3
  import re
3
4
  from collections.abc import Callable
4
5
 
@@ -40,7 +41,12 @@ class ContentTransformer(AnyContentTransformer):
40
41
  return True
41
42
  except re.error:
42
43
  pass
43
- return fnmatch.fnmatch(file_path, pattern)
44
+ if os.sep not in pattern and (
45
+ os.altsep is None or os.altsep not in pattern
46
+ ):
47
+ # Pattern like "*.txt" – match only the basename.
48
+ return fnmatch.fnmatch(file_path, os.path.basename(file_path))
49
+ return fnmatch.fnmatch(file_path, file_path)
44
50
 
45
51
  def transform_file(self, ctx: AnyContext, file_path: str):
46
52
  if callable(self._transform_file):
zrb/context/context.py CHANGED
@@ -88,7 +88,7 @@ class Context(AnyContext):
88
88
  return template
89
89
  return int(self.render(template))
90
90
 
91
- def render_float(self, template: str) -> float:
91
+ def render_float(self, template: str | float) -> float:
92
92
  if isinstance(template, float):
93
93
  return template
94
94
  return float(self.render(template))
@@ -102,9 +102,10 @@ class Context(AnyContext):
102
102
  flush: bool = True,
103
103
  plain: bool = False,
104
104
  ):
105
+ sep = " " if sep is None else sep
105
106
  message = sep.join([f"{value}" for value in values])
106
107
  if plain:
107
- self.append_to_shared_log(remove_style(message))
108
+ # self.append_to_shared_log(remove_style(message))
108
109
  print(message, sep=sep, end=end, file=file, flush=flush)
109
110
  return
110
111
  color = self._color
@@ -132,6 +133,7 @@ class Context(AnyContext):
132
133
  flush: bool = True,
133
134
  ):
134
135
  if self._shared_ctx.get_logging_level() <= logging.DEBUG:
136
+ sep = " " if sep is None else sep
135
137
  message = sep.join([f"{value}" for value in values])
136
138
  stylized_message = stylize_log(f"[DEBUG] {message}")
137
139
  self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -145,6 +147,7 @@ class Context(AnyContext):
145
147
  flush: bool = True,
146
148
  ):
147
149
  if self._shared_ctx.get_logging_level() <= logging.INFO:
150
+ sep = " " if sep is None else sep
148
151
  message = sep.join([f"{value}" for value in values])
149
152
  stylized_message = stylize_log(f"[INFO] {message}")
150
153
  self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -158,6 +161,7 @@ class Context(AnyContext):
158
161
  flush: bool = True,
159
162
  ):
160
163
  if self._shared_ctx.get_logging_level() <= logging.INFO:
164
+ sep = " " if sep is None else sep
161
165
  message = sep.join([f"{value}" for value in values])
162
166
  stylized_message = stylize_warning(f"[WARNING] {message}")
163
167
  self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -171,6 +175,7 @@ class Context(AnyContext):
171
175
  flush: bool = True,
172
176
  ):
173
177
  if self._shared_ctx.get_logging_level() <= logging.ERROR:
178
+ sep = " " if sep is None else sep
174
179
  message = sep.join([f"{value}" for value in values])
175
180
  stylized_message = stylize_error(f"[ERROR] {message}")
176
181
  self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
@@ -184,6 +189,7 @@ class Context(AnyContext):
184
189
  flush: bool = True,
185
190
  ):
186
191
  if self._shared_ctx.get_logging_level() <= logging.CRITICAL:
192
+ sep = " " if sep is None else sep
187
193
  message = sep.join([f"{value}" for value in values])
188
194
  stylized_message = stylize_error(f"[CRITICAL] {message}")
189
195
  self.print(stylized_message, sep=sep, end=end, file=file, flush=flush)
zrb/input/text_input.py CHANGED
@@ -69,16 +69,17 @@ class TextInput(BaseInput):
69
69
  )
70
70
 
71
71
  def _prompt_cli_str(self, shared_ctx: AnySharedContext) -> str:
72
- prompt_message = (
73
- f"{self.comment_start}{super().prompt_message}{self.comment_end}"
72
+ prompt_message = super().prompt_message
73
+ comment_prompt_message = (
74
+ f"{self.comment_start}{prompt_message}{self.comment_end}"
74
75
  )
75
- prompt_message_eol = f"{prompt_message}\n"
76
+ comment_prompt_message_eol = f"{comment_prompt_message}\n"
76
77
  default_value = self.get_default_str(shared_ctx)
77
78
  with tempfile.NamedTemporaryFile(
78
79
  delete=False, suffix=self._extension
79
80
  ) as temp_file:
80
81
  temp_file_name = temp_file.name
81
- temp_file.write(prompt_message_eol.encode())
82
+ temp_file.write(comment_prompt_message_eol.encode())
82
83
  # Pre-fill with default content
83
84
  if default_value:
84
85
  temp_file.write(default_value.encode())
@@ -87,7 +88,10 @@ class TextInput(BaseInput):
87
88
  subprocess.call([self._editor, temp_file_name])
88
89
  # Read the edited content
89
90
  edited_content = read_file(temp_file_name)
90
- parts = [text.strip() for text in edited_content.split(prompt_message, 1)]
91
+ parts = [
92
+ text.strip() for text in edited_content.split(comment_prompt_message, 1)
93
+ ]
91
94
  edited_content = "\n".join(parts).lstrip()
92
95
  os.remove(temp_file_name)
96
+ print(f"{prompt_message}: {edited_content}")
93
97
  return edited_content
zrb/task/llm_task.py CHANGED
@@ -4,10 +4,20 @@ from collections.abc import Callable
4
4
  from typing import Any
5
5
 
6
6
  from pydantic_ai import Agent, Tool
7
- from pydantic_ai.messages import ModelMessagesTypeAdapter
7
+ from pydantic_ai.messages import (
8
+ FinalResultEvent,
9
+ FunctionToolCallEvent,
10
+ FunctionToolResultEvent,
11
+ ModelMessagesTypeAdapter,
12
+ PartDeltaEvent,
13
+ PartStartEvent,
14
+ TextPartDelta,
15
+ ToolCallPartDelta,
16
+ )
17
+ from pydantic_ai.models import Model
8
18
  from pydantic_ai.settings import ModelSettings
9
19
 
10
- from zrb.attr.type import StrAttr
20
+ from zrb.attr.type import StrAttr, fstring
11
21
  from zrb.config import LLM_MODEL, LLM_SYSTEM_PROMPT
12
22
  from zrb.context.any_context import AnyContext
13
23
  from zrb.context.any_shared_context import AnySharedContext
@@ -15,7 +25,7 @@ from zrb.env.any_env import AnyEnv
15
25
  from zrb.input.any_input import AnyInput
16
26
  from zrb.task.any_task import AnyTask
17
27
  from zrb.task.base_task import BaseTask
18
- from zrb.util.attr import get_str_attr
28
+ from zrb.util.attr import get_attr, get_str_attr
19
29
  from zrb.util.cli.style import stylize_faint
20
30
  from zrb.util.file import read_file, write_file
21
31
  from zrb.util.run import run_async
@@ -34,7 +44,9 @@ class LLMTask(BaseTask):
34
44
  cli_only: bool = False,
35
45
  input: list[AnyInput | None] | AnyInput | None = None,
36
46
  env: list[AnyEnv | None] | AnyEnv | None = None,
37
- model: StrAttr | None = LLM_MODEL,
47
+ model: (
48
+ Callable[[AnySharedContext], Model | str | fstring] | Model | None
49
+ ) = LLM_MODEL,
38
50
  model_settings: (
39
51
  ModelSettings | Callable[[AnySharedContext], ModelSettings] | None
40
52
  ) = None,
@@ -93,7 +105,7 @@ class LLMTask(BaseTask):
93
105
  successor=successor,
94
106
  )
95
107
  self._model = model
96
- self._model_settings = (model_settings,)
108
+ self._model_settings = model_settings
97
109
  self._agent = agent
98
110
  self._render_model = render_model
99
111
  self._system_prompt = system_prompt
@@ -108,6 +120,9 @@ class LLMTask(BaseTask):
108
120
  self._render_history_file = render_history_file
109
121
  self._max_call_iteration = max_call_iteration
110
122
 
123
+ def set_model(self, model: Model | str):
124
+ self._model = model
125
+
111
126
  def add_tool(self, tool: ToolOrCallable):
112
127
  self._additional_tools.append(tool)
113
128
 
@@ -115,15 +130,85 @@ class LLMTask(BaseTask):
115
130
  history = await self._read_conversation_history(ctx)
116
131
  user_prompt = self._get_message(ctx)
117
132
  agent = self._get_agent(ctx)
118
- result = await agent.run(
133
+ async with agent.iter(
119
134
  user_prompt=user_prompt,
120
135
  message_history=ModelMessagesTypeAdapter.validate_python(history),
121
- )
122
- new_history = json.loads(result.all_messages_json())
123
- for history in new_history:
124
- ctx.print(stylize_faint(json.dumps(history)))
136
+ ) as agent_run:
137
+ async for node in agent_run:
138
+ # Each node represents a step in the agent's execution
139
+ await self._print_node(ctx, agent_run, node)
140
+ new_history = json.loads(agent_run.result.all_messages_json())
125
141
  await self._write_conversation_history(ctx, new_history)
126
- return result.data
142
+ return agent_run.result.data
143
+
144
+ async def _print_node(self, ctx: AnyContext, agent_run: Any, node: Any):
145
+ if Agent.is_user_prompt_node(node):
146
+ # A user prompt node => The user has provided input
147
+ ctx.print(stylize_faint(f">> UserPromptNode: {node.user_prompt}"))
148
+ elif Agent.is_model_request_node(node):
149
+ # A model request node => We can stream tokens from the model"s request
150
+ ctx.print(
151
+ stylize_faint(">> ModelRequestNode: streaming partial request tokens")
152
+ )
153
+ async with node.stream(agent_run.ctx) as request_stream:
154
+ is_streaming = False
155
+ async for event in request_stream:
156
+ if isinstance(event, PartStartEvent):
157
+ if is_streaming:
158
+ ctx.print("", plain=True)
159
+ ctx.print(
160
+ stylize_faint(
161
+ f"[Request] Starting part {event.index}: {event.part!r}"
162
+ ),
163
+ )
164
+ is_streaming = False
165
+ elif isinstance(event, PartDeltaEvent):
166
+ if isinstance(event.delta, TextPartDelta):
167
+ ctx.print(
168
+ stylize_faint(f"{event.delta.content_delta}"),
169
+ end="",
170
+ plain=is_streaming,
171
+ )
172
+ elif isinstance(event.delta, ToolCallPartDelta):
173
+ ctx.print(
174
+ stylize_faint(f"{event.delta.args_delta}"),
175
+ end="",
176
+ plain=is_streaming,
177
+ )
178
+ is_streaming = True
179
+ elif isinstance(event, FinalResultEvent):
180
+ if is_streaming:
181
+ ctx.print("", plain=True)
182
+ ctx.print(
183
+ stylize_faint(f"[Result] tool_name={event.tool_name}"),
184
+ )
185
+ is_streaming = False
186
+ if is_streaming:
187
+ ctx.print("", plain=True)
188
+ elif Agent.is_call_tools_node(node):
189
+ # A handle-response node => The model returned some data, potentially calls a tool
190
+ ctx.print(
191
+ stylize_faint(
192
+ ">> CallToolsNode: streaming partial response & tool usage"
193
+ )
194
+ )
195
+ async with node.stream(agent_run.ctx) as handle_stream:
196
+ async for event in handle_stream:
197
+ if isinstance(event, FunctionToolCallEvent):
198
+ ctx.print(
199
+ stylize_faint(
200
+ f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})" # noqa
201
+ )
202
+ )
203
+ elif isinstance(event, FunctionToolResultEvent):
204
+ ctx.print(
205
+ stylize_faint(
206
+ f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}" # noqa
207
+ )
208
+ )
209
+ elif Agent.is_end_node(node):
210
+ # Once an End node is reached, the agent run is complete
211
+ ctx.print(stylize_faint(f"{agent_run.result.data}"))
127
212
 
128
213
  async def _write_conversation_history(
129
214
  self, ctx: AnyContext, conversations: list[Any]
@@ -135,11 +220,9 @@ class LLMTask(BaseTask):
135
220
  write_file(history_file, json.dumps(conversations, indent=2))
136
221
 
137
222
  def _get_model_settings(self, ctx: AnyContext) -> ModelSettings | None:
138
- if isinstance(self._model_settings, ModelSettings):
139
- return self._model_settings
140
223
  if callable(self._model_settings):
141
224
  return self._model_settings(ctx)
142
- return None
225
+ return self._model_settings
143
226
 
144
227
  def _get_agent(self, ctx: AnyContext) -> Agent:
145
228
  if isinstance(self._agent, Agent):
@@ -158,12 +241,16 @@ class LLMTask(BaseTask):
158
241
  self._get_model(ctx),
159
242
  system_prompt=self._get_system_prompt(ctx),
160
243
  tools=tools,
244
+ model_settings=self._get_model_settings(ctx),
161
245
  )
162
246
 
163
- def _get_model(self, ctx: AnyContext) -> str:
164
- return get_str_attr(
247
+ def _get_model(self, ctx: AnyContext) -> str | Model | None:
248
+ model = get_attr(
165
249
  ctx, self._model, "ollama_chat/llama3.1", auto_render=self._render_model
166
250
  )
251
+ if isinstance(model, (Model, str)) or model is None:
252
+ return model
253
+ raise ValueError("Invalid model")
167
254
 
168
255
  def _get_system_prompt(self, ctx: AnyContext) -> str:
169
256
  return get_str_attr(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: zrb
3
- Version: 1.2.1
3
+ Version: 1.2.2
4
4
  Summary: Your Automation Powerhouse
5
5
  Home-page: https://github.com/state-alchemists/zrb
6
6
  License: AGPL-3.0-or-later
@@ -24,7 +24,7 @@ Requires-Dist: isort (>=5.13.2,<5.14.0)
24
24
  Requires-Dist: libcst (>=1.5.0,<2.0.0)
25
25
  Requires-Dist: pdfplumber (>=0.11.4,<0.12.0) ; extra == "rag"
26
26
  Requires-Dist: psutil (>=6.1.1,<7.0.0)
27
- Requires-Dist: pydantic-ai (>=0.0.19,<0.0.20)
27
+ Requires-Dist: pydantic-ai (>=0.0.31,<0.0.32)
28
28
  Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
29
29
  Requires-Dist: python-jose[cryptography] (>=3.4.0,<4.0.0)
30
30
  Requires-Dist: requests (>=2.32.3,<3.0.0)
@@ -7,12 +7,13 @@ zrb/builtin/base64.py,sha256=1YnSwASp7OEAvQcsnHZGpJEvYoI1Z2zTIJ1bCDHfcPQ,921
7
7
  zrb/builtin/git.py,sha256=8_qVE_2lVQEVXQ9vhiw8Tn4Prj1VZB78ZjEJJS5Ab3M,5461
8
8
  zrb/builtin/git_subtree.py,sha256=7BKwOkVTWDrR0DXXQ4iJyHqeR6sV5VYRt8y_rEB0EHg,3505
9
9
  zrb/builtin/group.py,sha256=-phJfVpTX3_gUwS1u8-RbZUHe-X41kxDBSmrVh4rq8E,1682
10
- zrb/builtin/llm/llm_chat.py,sha256=UVhfJR0APXRC3tEv5i5vYbEANWqi04QD_WsAiARJ7j4,4494
10
+ zrb/builtin/llm/llm_chat.py,sha256=cNAS_AS-Q5NW6qe8dJZh12b6c0zFKdNEvGfwJxNAUmw,5047
11
11
  zrb/builtin/llm/previous-session.js,sha256=xMKZvJoAbrwiyHS0OoPrWuaKxWYLoyR5sguePIoCjTY,816
12
- zrb/builtin/llm/tool/api.py,sha256=yQ3XV8O7Fx7hHssLSOcmiHDnevPhz9ktWi44HK7zTls,801
12
+ zrb/builtin/llm/tool/api.py,sha256=bXFE7jihdhUscxJH8lu5imwlYH735AalbCyUTl28BaQ,826
13
13
  zrb/builtin/llm/tool/cli.py,sha256=to_IjkfrMGs6eLfG0cpVN9oyADWYsJQCtyluUhUdBww,253
14
- zrb/builtin/llm/tool/rag.py,sha256=PawaLZL-ThctxtBtsQuP3XsgTxQKyCGFqrudCANPJKk,5162
15
- zrb/builtin/llm/tool/web.py,sha256=N2HYuXbKPUpjVAq_UnQMbUrTIE8u0Ut3TeQadZ7_NJc,2217
14
+ zrb/builtin/llm/tool/file.py,sha256=ibvh0zrsnponwyZvw6bWMUbpwSv5S5WUWCDfQ6BjVwk,1160
15
+ zrb/builtin/llm/tool/rag.py,sha256=vEIThEy0JGwXEiNRLOEJAHAE0l1Qie2qvU3ryioeYMk,6066
16
+ zrb/builtin/llm/tool/web.py,sha256=SDnCtYHZ0Q4DtLbIhc11a0UyyKbTTeW60UfeIKzK35k,3204
16
17
  zrb/builtin/md5.py,sha256=0pNlrfZA0wlZlHvFHLgyqN0JZJWGKQIF5oXxO44_OJk,949
17
18
  zrb/builtin/project/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
19
  zrb/builtin/project/add/fastapp/fastapp_input.py,sha256=MKlWR_LxWhM_DcULCtLfL_IjTxpDnDBkn9KIqNmajFs,310
@@ -207,14 +208,14 @@ zrb/callback/callback.py,sha256=hKefB_Jd1XGjPSLQdMKDsGLHPzEGO2dqrIArLl_EmD0,848
207
208
  zrb/cmd/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
208
209
  zrb/cmd/cmd_result.py,sha256=L8bQJzWCpcYexIxHBNsXj2pT3BtLmWex0iJSMkvimOA,597
209
210
  zrb/cmd/cmd_val.py,sha256=7Doowyg6BK3ISSGBLt-PmlhzaEkBjWWm51cED6fAUOQ,1014
210
- zrb/config.py,sha256=Kb-GOsUS4poSCds4Wqg9LkscpS7BHXSy3dQmqvsFm2Q,4621
211
+ zrb/config.py,sha256=MfHwcQ4OhCmCw6jXpFI8483Ase6YrqNGBvqYzwnwopw,4753
211
212
  zrb/content_transformer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
212
213
  zrb/content_transformer/any_content_transformer.py,sha256=v8ZUbcix1GGeDQwB6OKX_1TjpY__ksxWVeqibwa_iZA,850
213
- zrb/content_transformer/content_transformer.py,sha256=YU6Xr3G_IaCWKQGsf9z9YlCclbiwcJ7ytQv3wKpPIiI,2125
214
+ zrb/content_transformer/content_transformer.py,sha256=STl77wW-I69QaGzCXjvkppngYFLufow8ybPLSyAvlHs,2404
214
215
  zrb/context/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
215
216
  zrb/context/any_context.py,sha256=2hgVKbbDwmwrEl1h1L1FaTUjuUYaDd_b7YRGkaorW6Q,6362
216
217
  zrb/context/any_shared_context.py,sha256=p1i9af_CUDz5Mf1h1kBZMAa2AEhf17I3O5IgAcjRLoY,1768
217
- zrb/context/context.py,sha256=qVMqt2tkLEFSI81mLYb_OSD615KH5jP685aUmHEm3XQ,6319
218
+ zrb/context/context.py,sha256=VGoUwoWyL9d4QqJEhg41S-X8T2jlssGpiC9YSc3Gjqk,6601
218
219
  zrb/context/shared_context.py,sha256=47Tnnor1ttpwpe_N07rMNM1jgIYPY9abMe1Q5igkMtE,2735
219
220
  zrb/dot_dict/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
220
221
  zrb/dot_dict/dot_dict.py,sha256=ubw_x8I7AOJ59xxtFVJ00VGmq_IYdZP3mUhNlO4nEK0,556
@@ -235,7 +236,7 @@ zrb/input/int_input.py,sha256=w5ewSxstUYv5LBAzvX_E0jIueXXdmnY2ehoQMTtg-EA,1380
235
236
  zrb/input/option_input.py,sha256=IrpF0XvFbH5G-IEAnoQ4QOvq7gn2wyT4jKwAdMKwV0s,2058
236
237
  zrb/input/password_input.py,sha256=Tu8TZx95717YsHICZ0zBzTUPKPf-K9vGlvRyaOTrFEM,1388
237
238
  zrb/input/str_input.py,sha256=NevZHX9rf1g8eMatPyy-kUX3DglrVAQpzvVpKAzf7bA,81
238
- zrb/input/text_input.py,sha256=wSNiYAx2xYPtl09Dfh_uHws9WG2dRqkS0Jnlm9HvD3s,3145
239
+ zrb/input/text_input.py,sha256=gK5LGa9uEBBLwGqhTVmCVgsAuBYEQt3ANDQaSNtSc78,3300
239
240
  zrb/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
240
241
  zrb/runner/cli.py,sha256=G_ILZCFzpV-kRE3dm1kq6BorB51TLJ34Qmhgy5SIMlU,6734
241
242
  zrb/runner/common_util.py,sha256=mjEBSmfdY2Sun2U5-8y8gGwF82OiRM8sgiYDOdW9NA4,1338
@@ -298,7 +299,7 @@ zrb/task/base_task.py,sha256=SQRf37bylS586KwyW0eYDe9JZ5Hl18FP8kScHae6y3A,21251
298
299
  zrb/task/base_trigger.py,sha256=jC722rDvodaBLeNaFghkTyv1u0QXrK6BLZUUqcmBJ7Q,4581
299
300
  zrb/task/cmd_task.py,sha256=pUKRSR4DZKjbmluB6vi7cxqyhxOLfJ2czSpYeQbiDvo,10705
300
301
  zrb/task/http_check.py,sha256=Gf5rOB2Se2EdizuN9rp65HpGmfZkGc-clIAlHmPVehs,2565
301
- zrb/task/llm_task.py,sha256=ptXC3x9Dwn7-4JrGQyEtzOXZ4dNQATDgCeowkvwAu9U,7723
302
+ zrb/task/llm_task.py,sha256=B4qhza-4fk7odI7-rv2rLYvBLt1dmZMNgKu8OK7rajM,11849
302
303
  zrb/task/make_task.py,sha256=PD3b_aYazthS8LHeJsLAhwKDEgdurQZpymJDKeN60u0,2265
303
304
  zrb/task/rsync_task.py,sha256=GSL9144bmp6F0EckT6m-2a1xG25AzrrWYzH4k3SVUKM,6370
304
305
  zrb/task/scaffolder.py,sha256=rME18w1HJUHXgi9eTYXx_T2G4JdqDYzBoNOkdOOo5-o,6806
@@ -339,7 +340,7 @@ zrb/util/string/name.py,sha256=8picJfUBXNpdh64GNaHv3om23QHhUZux7DguFLrXHp8,1163
339
340
  zrb/util/todo.py,sha256=1nDdwPc22oFoK_1ZTXyf3638Bg6sqE2yp_U4_-frHoc,16015
340
341
  zrb/xcom/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
341
342
  zrb/xcom/xcom.py,sha256=o79rxR9wphnShrcIushA0Qt71d_p3ZTxjNf7x9hJB78,1571
342
- zrb-1.2.1.dist-info/METADATA,sha256=Oj5aFm5hZeFXkWttWm6MNwk5uBCwVIGPHj2Y5IdoQyo,4198
343
- zrb-1.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
344
- zrb-1.2.1.dist-info/entry_points.txt,sha256=-Pg3ElWPfnaSM-XvXqCxEAa-wfVI6BEgcs386s8C8v8,46
345
- zrb-1.2.1.dist-info/RECORD,,
343
+ zrb-1.2.2.dist-info/METADATA,sha256=GAm6vQds-lw7StRPEkrq3oTUSBO1jIYpXi9s1-KSfjw,4198
344
+ zrb-1.2.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
345
+ zrb-1.2.2.dist-info/entry_points.txt,sha256=-Pg3ElWPfnaSM-XvXqCxEAa-wfVI6BEgcs386s8C8v8,46
346
+ zrb-1.2.2.dist-info/RECORD,,
File without changes