zrb 1.4.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,8 @@ import ulid
9
9
 
10
10
  from zrb.config import (
11
11
  RAG_CHUNK_SIZE,
12
+ RAG_EMBEDDING_API_KEY,
13
+ RAG_EMBEDDING_BASE_URL,
12
14
  RAG_EMBEDDING_MODEL,
13
15
  RAG_MAX_RESULT_COUNT,
14
16
  RAG_OVERLAP,
@@ -35,24 +37,34 @@ def create_rag_from_directory(
35
37
  tool_name: str,
36
38
  tool_description: str,
37
39
  document_dir_path: str = "./documents",
38
- model: str = RAG_EMBEDDING_MODEL,
39
40
  vector_db_path: str = "./chroma",
40
41
  vector_db_collection: str = "documents",
41
42
  chunk_size: int = RAG_CHUNK_SIZE,
42
43
  overlap: int = RAG_OVERLAP,
43
44
  max_result_count: int = RAG_MAX_RESULT_COUNT,
44
45
  file_reader: list[RAGFileReader] = [],
46
+ openai_api_key: str = RAG_EMBEDDING_API_KEY,
47
+ openai_base_url: str = RAG_EMBEDDING_BASE_URL,
48
+ openai_embedding_model: str = RAG_EMBEDDING_MODEL,
45
49
  ):
46
50
  async def retrieve(query: str) -> str:
47
51
  from chromadb import PersistentClient
48
52
  from chromadb.config import Settings
49
- from fastembed import TextEmbedding
50
-
51
- embedding_model = TextEmbedding(model_name=model)
52
- client = PersistentClient(
53
+ from openai import OpenAI
54
+
55
+ # Initialize OpenAI client with custom URL if provided
56
+ client_args = {}
57
+ if openai_api_key:
58
+ client_args["api_key"] = openai_api_key
59
+ if openai_base_url:
60
+ client_args["base_url"] = openai_base_url
61
+ # Initialize OpenAI client for embeddings
62
+ openai_client = OpenAI(**client_args)
63
+ # Initialize ChromaDB client
64
+ chroma_client = PersistentClient(
53
65
  path=vector_db_path, settings=Settings(allow_reset=True)
54
66
  )
55
- collection = client.get_or_create_collection(vector_db_collection)
67
+ collection = chroma_client.get_or_create_collection(vector_db_collection)
56
68
  # Track file changes using a hash-based approach
57
69
  hash_file_path = os.path.join(vector_db_path, "file_hashes.json")
58
70
  previous_hashes = _load_hashes(hash_file_path)
@@ -89,8 +101,11 @@ def create_rag_from_directory(
89
101
  ),
90
102
  file=sys.stderr,
91
103
  )
92
- embedding_result = list(embedding_model.embed([chunk]))
93
- vector = embedding_result[0]
104
+ # Get embeddings using OpenAI
105
+ embedding_response = openai_client.embeddings.create(
106
+ input=chunk, model=openai_embedding_model
107
+ )
108
+ vector = embedding_response.data[0].embedding
94
109
  collection.upsert(
95
110
  ids=[chunk_id],
96
111
  embeddings=[vector],
@@ -113,8 +128,11 @@ def create_rag_from_directory(
113
128
  )
114
129
  # Vectorize query and get related document chunks
115
130
  print(stylize_faint("Vectorizing query"), file=sys.stderr)
116
- embedding_result = list(embedding_model.embed([query]))
117
- query_vector = embedding_result[0]
131
+ # Get embeddings using OpenAI
132
+ embedding_response = openai_client.embeddings.create(
133
+ input=query, model=openai_embedding_model
134
+ )
135
+ query_vector = embedding_response.data[0].embedding
118
136
  print(stylize_faint("Searching documents"), file=sys.stderr)
119
137
  results = collection.query(
120
138
  query_embeddings=query_vector,
@@ -3,21 +3,53 @@ from collections.abc import Callable
3
3
  from typing import Annotated
4
4
 
5
5
 
6
- def open_web_page(url: str) -> str:
7
- """Get content from a web page."""
8
- import requests
9
-
10
- response = requests.get(
11
- url,
12
- headers={
13
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" # noqa
14
- },
15
- )
16
- if response.status_code != 200:
17
- raise Exception(
18
- f"Error: Unable to retrieve search results (status code: {response.status_code})" # noqa
19
- )
20
- return json.dumps(parse_html_text(response.text))
6
async def open_web_page(url: str) -> str:
    """Get content from a web page using a headless browser."""
    # NOTE(review): the docstring text is left untouched in case it is
    # surfaced elsewhere (e.g. as a tool description) — confirm before
    # rewording it.
    #
    # Strategy: render the page with Playwright's headless Chromium when the
    # package is importable; otherwise fall back to a plain HTTP GET through
    # `requests`. Either way the HTML is handed to `parse_html_text` and the
    # result is returned as a JSON string.
    import asyncio

    user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"  # noqa

    def fetch_with_requests(page_url: str) -> dict:
        # Plain HTTP fallback used when Playwright is not installed.
        import requests

        # Timeout mirrors the 30 s navigation timeout of the browser path so
        # an unresponsive server cannot hang the call forever.
        response = requests.get(
            page_url, headers={"User-Agent": user_agent}, timeout=30
        )
        if response.status_code != 200:
            msg = f"Unable to retrieve search results. Status code: {response.status_code}"  # noqa
            raise Exception(msg)
        return {"content": response.text, "links_on_page": []}

    async def get_page_content(page_url: str) -> dict:
        # Only the import is guarded, so ImportError can only mean
        # "Playwright is unavailable", never a failure while browsing.
        try:
            from playwright.async_api import async_playwright
        except ImportError:
            # requests is blocking; run it in a worker thread so the event
            # loop stays responsive.
            return await asyncio.to_thread(fetch_with_requests, page_url)

        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                page = await browser.new_page()
                await page.set_extra_http_headers({"User-Agent": user_agent})
                # Navigate to the URL with a timeout of 30 seconds
                await page.goto(page_url, wait_until="networkidle", timeout=30000)
                # Wait for the content to load
                await page.wait_for_load_state("domcontentloaded")
                # Get the page content
                content = await page.content()
                # Extract all links that are neither in-page anchors nor
                # root-relative paths
                links = await page.eval_on_selector_all(
                    "a[href]",
                    """
                    (elements) => elements.map(el => {
                        const href = el.getAttribute('href');
                        if (href && !href.startsWith('#') && !href.startsWith('/')) {
                            return href;
                        }
                        return null;
                    }).filter(href => href !== null)
                    """,
                )
                return {"content": content, "links_on_page": links}
            finally:
                # Close the browser even if page creation or navigation fails.
                await browser.close()

    result = await get_page_content(url)
    # Parse the HTML content
    return json.dumps(parse_html_text(result["content"]))
21
53
 
22
54
 
23
55
  def create_search_internet_tool(serp_api_key: str) -> Callable[[str, int], str]:
zrb/builtin/todo.py CHANGED
@@ -25,6 +25,18 @@ from zrb.util.todo import (
25
25
  )
26
26
 
27
27
 
28
def _get_filter_input(allow_positional_parsing: bool = False) -> StrInput:
    """Build the ``filter`` input shared by the todo tasks.

    Args:
        allow_positional_parsing: Forwarded unchanged to ``StrInput``.

    Returns:
        A ``StrInput`` named ``filter`` that may be empty, is not always
        prompted, and defaults to ``TODO_VISUAL_FILTER``.
    """
    filter_options = {
        "name": "filter",
        "description": "Visual filter",
        "prompt": "Visual Filter",
        "allow_empty": True,
        "allow_positional_parsing": allow_positional_parsing,
        "always_prompt": False,
        "default": TODO_VISUAL_FILTER,
    }
    return StrInput(**filter_options)
38
+
39
+
28
40
  @make_task(
29
41
  name="add-todo",
30
42
  input=[
@@ -51,6 +63,7 @@ from zrb.util.todo import (
51
63
  prompt="Task context (space separated)",
52
64
  allow_empty=True,
53
65
  ),
66
+ _get_filter_input(),
54
67
  ],
55
68
  description="➕ Add todo",
56
69
  group=todo_group,
@@ -82,16 +95,22 @@ def add_todo(ctx: AnyContext):
82
95
  )
83
96
  )
84
97
  save_todo_list(todo_file_path, todo_list)
85
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
98
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
86
99
 
87
100
 
88
- @make_task(name="list-todo", description="📋 List todo", group=todo_group, alias="list")
101
@make_task(
    name="list-todo",
    input=_get_filter_input(allow_positional_parsing=True),
    description="📋 List todo",
    group=todo_group,
    alias="list",
)
def list_todo(ctx: AnyContext):
    # Show the todo list, narrowed by the task's `filter` input.
    todo_file_path = os.path.join(TODO_DIR, "todo.txt")
    # A missing todo file is rendered as an empty list.
    if os.path.isfile(todo_file_path):
        todo_list = load_todo_list(todo_file_path)
    else:
        todo_list = []
    return get_visual_todo_list(todo_list, filter=ctx.input.filter)
95
114
 
96
115
 
97
116
  @make_task(
@@ -127,7 +146,10 @@ def show_todo(ctx: AnyContext):
127
146
 
128
147
  @make_task(
129
148
  name="complete-todo",
130
- input=StrInput(name="keyword", prompt="Task keyword", description="Task Keyword"),
149
+ input=[
150
+ StrInput(name="keyword", prompt="Task keyword", description="Task Keyword"),
151
+ _get_filter_input(),
152
+ ],
131
153
  description="✅ Complete todo",
132
154
  group=todo_group,
133
155
  alias="complete",
@@ -141,10 +163,10 @@ def complete_todo(ctx: AnyContext):
141
163
  todo_task = select_todo_task(todo_list, ctx.input.keyword)
142
164
  if todo_task is None:
143
165
  ctx.log_error("Task not found")
144
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
166
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
145
167
  if todo_task.completed:
146
168
  ctx.log_error("Task already completed")
147
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
169
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
148
170
  # Update todo task
149
171
  todo_task = cascade_todo_task(todo_task)
150
172
  if todo_task.creation_date is not None:
@@ -152,11 +174,12 @@ def complete_todo(ctx: AnyContext):
152
174
  todo_task.completed = True
153
175
  # Save todo list
154
176
  save_todo_list(todo_file_path, todo_list)
155
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
177
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
156
178
 
157
179
 
158
180
  @make_task(
159
181
  name="archive-todo",
182
+ input=_get_filter_input(),
160
183
  description="📚 Archive todo",
161
184
  group=todo_group,
162
185
  alias="archive",
@@ -180,7 +203,7 @@ def archive_todo(ctx: AnyContext):
180
203
  ]
181
204
  if len(new_archived_todo_list) == 0:
182
205
  ctx.print("No completed task to archive")
183
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
206
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
184
207
  archive_file_path = os.path.join(TODO_DIR, "archive.txt")
185
208
  if not os.path.isdir(TODO_DIR):
186
209
  os.make_dirs(TODO_DIR, exist_ok=True)
@@ -192,7 +215,7 @@ def archive_todo(ctx: AnyContext):
192
215
  # Save the new todo list and add the archived ones
193
216
  save_todo_list(archive_file_path, archived_todo_list)
194
217
  save_todo_list(todo_file_path, working_todo_list)
195
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
218
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
196
219
 
197
220
 
198
221
  @make_task(
@@ -216,6 +239,7 @@ def archive_todo(ctx: AnyContext):
216
239
  description="Working stop time",
217
240
  default=lambda _: _get_default_stop_work_time_str(),
218
241
  ),
242
+ _get_filter_input(),
219
243
  ],
220
244
  description="🕒 Log work todo",
221
245
  group=todo_group,
@@ -230,7 +254,7 @@ def log_todo(ctx: AnyContext):
230
254
  todo_task = select_todo_task(todo_list, ctx.input.keyword)
231
255
  if todo_task is None:
232
256
  ctx.log_error("Task not found")
233
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
257
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
234
258
  # Update todo task
235
259
  todo_task = cascade_todo_task(todo_task)
236
260
  current_duration_str = todo_task.keyval.get("duration", "0")
@@ -268,7 +292,7 @@ def log_todo(ctx: AnyContext):
268
292
  log_work_list = json.loads(read_file(log_work_path))
269
293
  return "\n".join(
270
294
  [
271
- get_visual_todo_list(todo_list, TODO_VISUAL_FILTER),
295
+ get_visual_todo_list(todo_list, filter=ctx.input.filter),
272
296
  "",
273
297
  get_visual_todo_card(todo_task, log_work_list),
274
298
  ]
@@ -296,6 +320,7 @@ def _get_default_stop_work_time_str() -> str:
296
320
  default=lambda _: _get_todo_txt_content(),
297
321
  allow_positional_parsing=False,
298
322
  ),
323
+ _get_filter_input(),
299
324
  ],
300
325
  description="📝 Edit todo",
301
326
  group=todo_group,
@@ -311,7 +336,7 @@ def edit_todo(ctx: AnyContext):
311
336
  todo_file_path = os.path.join(TODO_DIR, "todo.txt")
312
337
  write_file(todo_file_path, new_content)
313
338
  todo_list = load_todo_list(todo_file_path)
314
- return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
339
+ return get_visual_todo_list(todo_list, filter=ctx.input.filter)
315
340
 
316
341
 
317
342
  def _get_todo_txt_content() -> str:
zrb/config.py CHANGED
@@ -85,10 +85,10 @@ LLM_HISTORY_FILE = os.getenv(
85
85
  LLM_ALLOW_ACCESS_LOCAL_FILE = to_boolean(os.getenv("ZRB_LLM_ACCESS_LOCAL_FILE", "1"))
86
86
  LLM_ALLOW_ACCESS_SHELL = to_boolean(os.getenv("ZRB_LLM_ACCESS_SHELL", "1"))
87
87
  LLM_ALLOW_ACCESS_INTERNET = to_boolean(os.getenv("ZRB_LLM_ACCESS_INTERNET", "1"))
88
- # noqa See: https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models
89
- RAG_EMBEDDING_MODEL = os.getenv(
90
- "ZRB_RAG_EMBEDDING_MODEL", "nomic-ai/nomic-embed-text-v1.5-Q"
91
- )
88
+ # RAG Configuration
89
+ RAG_EMBEDDING_API_KEY = os.getenv("ZRB_RAG_EMBEDDING_API_KEY", None)
90
+ RAG_EMBEDDING_BASE_URL = os.getenv("ZRB_RAG_EMBEDDING_BASE_URL", None)
91
+ RAG_EMBEDDING_MODEL = os.getenv("ZRB_RAG_EMBEDDING_MODEL", "text-embedding-ada-002")
92
92
  RAG_CHUNK_SIZE = int(os.getenv("ZRB_RAG_CHUNK_SIZE", "1024"))
93
93
  RAG_OVERLAP = int(os.getenv("ZRB_RAG_OVERLAP", "128"))
94
94
  RAG_MAX_RESULT_COUNT = int(os.getenv("ZRB_RAG_MAX_RESULT_COUNT", "5"))
zrb/llm_config.py CHANGED
@@ -2,20 +2,29 @@ import os
2
2
 
3
3
  from pydantic_ai.models import Model
4
4
  from pydantic_ai.models.openai import OpenAIModel
5
+ from pydantic_ai.providers import Provider
5
6
  from pydantic_ai.providers.openai import OpenAIProvider
6
7
 
7
8
  DEFAULT_SYSTEM_PROMPT = """
8
9
  You have access to tools.
9
- Your goal to to answer user queries accurately.
10
+ Your goal is to provide insightful and accurate information based on user queries.
10
11
  Follow these instructions precisely:
11
- 1. ALWAYS use available tools to gather information BEFORE asking the user questions
12
- 2. For tools that require arguments: provide arguments in valid JSON format
13
- 3. For tools that require NO arguments: call with empty JSON object ({}) NOT empty string ('')
14
- 4. NEVER pass arguments to tools that don't accept parameters
15
- 5. NEVER ask users for information obtainable through tools
16
- 6. Use tools in logical sequence until you have sufficient information
17
- 7. If a tool call fails, check if you're passing arguments in the correct format
18
- 8. Only after exhausting relevant tools should you request clarification
12
+ 1. ALWAYS use available tools to gather information BEFORE asking the user questions.
13
+ 2. For tools that require arguments: provide arguments in valid JSON format.
14
+ 3. For tools with no args: call the tool without args. Do NOT pass "" or {}.
15
+ 4. NEVER pass arguments to tools that don't accept parameters.
16
+ 5. NEVER ask users for information obtainable through tools.
17
+ 6. Use tools in a logical sequence until you have sufficient information.
18
+ 7. If a tool call fails, check if you're passing arguments in the correct format.
19
+ Consider alternative strategies if the issue persists.
20
+ 8. Only after exhausting relevant tools should you request clarification.
21
+ 9. Understand the context of user queries to provide relevant and accurate responses.
22
+ 10. Engage with users in a conversational manner once the necessary information is gathered.
23
+ 11. Adapt to different query types or scenarios to improve flexibility and effectiveness.
24
+ """.strip()
25
+
26
+ DEFAULT_PERSONA = """
27
+ You are an expert in various fields including technology, science, history, and more.
19
28
  """.strip()
20
29
 
21
30
 
@@ -26,6 +35,7 @@ class LLMConfig:
26
35
  default_model_name: str | None = None,
27
36
  default_base_url: str | None = None,
28
37
  default_api_key: str | None = None,
38
+ default_persona: str | None = None,
29
39
  default_system_prompt: str | None = None,
30
40
  ):
31
41
  self._model_name = (
@@ -48,12 +58,20 @@ class LLMConfig:
48
58
  if default_system_prompt is not None
49
59
  else os.getenv("ZRB_LLM_SYSTEM_PROMPT", None)
50
60
  )
61
+ self._persona = (
62
+ default_persona
63
+ if default_persona is not None
64
+ else os.getenv("ZRB_LLM_PERSONA", None)
65
+ )
66
+ self._default_provider = None
51
67
  self._default_model = None
52
68
 
53
69
  def _get_model_name(self) -> str | None:
54
70
  return self._model_name if self._model_name is not None else None
55
71
 
56
- def _get_model_provider(self) -> OpenAIProvider:
72
+ def get_default_model_provider(self) -> Provider | str:
73
+ if self._default_provider is not None:
74
+ return self._default_provider
57
75
  if self._model_base_url is None and self._model_api_key is None:
58
76
  return "openai"
59
77
  return OpenAIProvider(
@@ -61,9 +79,15 @@ class LLMConfig:
61
79
  )
62
80
 
63
81
  def get_default_system_prompt(self) -> str:
64
- if self._system_prompt is not None:
65
- return self._system_prompt
66
- return DEFAULT_SYSTEM_PROMPT
82
+ system_prompt = (
83
+ DEFAULT_SYSTEM_PROMPT
84
+ if self._system_prompt is None
85
+ else self._system_prompt
86
+ )
87
+ persona = DEFAULT_PERSONA if self._persona is None else self._persona
88
+ if persona is not None:
89
+ return f"{persona}\n{system_prompt}"
90
+ return system_prompt
67
91
 
68
92
  def get_default_model(self) -> Model | str | None:
69
93
  if self._default_model is not None:
@@ -73,7 +97,7 @@ class LLMConfig:
73
97
  return None
74
98
  return OpenAIModel(
75
99
  model_name=model_name,
76
- provider=self._get_model_provider(),
100
+ provider=self.get_default_model_provider(),
77
101
  )
78
102
 
79
103
  def set_default_system_prompt(self, system_prompt: str):
@@ -88,6 +112,9 @@ class LLMConfig:
88
112
  def set_default_model_base_url(self, model_base_url: str):
89
113
  self._model_base_url = model_base_url
90
114
 
115
+ def set_default_provider(self, provider: Provider | str):
116
+ self._default_provider = provider
117
+
91
118
  def set_default_model(self, model: Model | str | None):
92
119
  self._default_model = model
93
120
 
zrb/task/llm_task.py CHANGED
@@ -1,5 +1,8 @@
1
+ import functools
2
+ import inspect
1
3
  import json
2
4
  import os
5
+ import traceback
3
6
  from collections.abc import Callable
4
7
  from typing import Any
5
8
 
@@ -201,6 +204,9 @@ class LLMTask(BaseTask):
201
204
  async with node.stream(agent_run.ctx) as handle_stream:
202
205
  async for event in handle_stream:
203
206
  if isinstance(event, FunctionToolCallEvent):
207
+ # Work around Anthropic Claude passing an empty string as args when it calls a tool that takes no parameters
208
+ if event.part.args == "":
209
+ event.part.args = {}
204
210
  ctx.print(
205
211
  stylize_faint(
206
212
  f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})" # noqa
@@ -240,7 +246,7 @@ class LLMTask(BaseTask):
240
246
  )
241
247
  tools_or_callables.extend(self._additional_tools)
242
248
  tools = [
243
- tool if isinstance(tool, Tool) else Tool(tool, takes_ctx=False)
249
+ tool if isinstance(tool, Tool) else Tool(_wrap_tool(tool), takes_ctx=False)
244
250
  for tool in tools_or_callables
245
251
  ]
246
252
  return Agent(
@@ -256,21 +262,17 @@ class LLMTask(BaseTask):
256
262
  if model is None:
257
263
  return default_llm_config.get_default_model()
258
264
  if isinstance(model, str):
265
+ model_base_url = self._get_model_base_url(ctx)
266
+ model_api_key = self._get_model_api_key(ctx)
259
267
  llm_config = LLMConfig(
260
268
  default_model_name=model,
261
- default_base_url=get_attr(
262
- ctx,
263
- self._get_model_base_url(ctx),
264
- None,
265
- auto_render=self._render_model_base_url,
266
- ),
267
- default_api_key=get_attr(
268
- ctx,
269
- self._get_model_api_key(ctx),
270
- None,
271
- auto_render=self._render_model_api_key,
272
- ),
269
+ default_base_url=model_base_url,
270
+ default_api_key=model_api_key,
273
271
  )
272
+ if model_base_url is None and model_api_key is None:
273
+ default_model_provider = default_llm_config.get_default_model_provider()
274
+ if default_model_provider is not None:
275
+ llm_config.set_default_provider(default_model_provider)
274
276
  return llm_config.get_default_model()
275
277
  raise ValueError(f"Invalid model: {model}")
276
278
 
@@ -288,7 +290,7 @@ class LLMTask(BaseTask):
288
290
  )
289
291
  if isinstance(api_key, str) or api_key is None:
290
292
  return api_key
291
- raise ValueError(f"Invalid model base URL: {api_key}")
293
+ raise ValueError(f"Invalid model API key: {api_key}")
292
294
 
293
295
  def _get_system_prompt(self, ctx: AnyContext) -> str:
294
296
  system_prompt = get_attr(
@@ -325,3 +327,16 @@ class LLMTask(BaseTask):
325
327
  "",
326
328
  auto_render=self._render_history_file,
327
329
  )
330
+
331
+
332
def _wrap_tool(func):
    """Wrap a tool callable so that failures are returned instead of raised.

    The returned coroutine function calls ``func`` with the given arguments,
    passes the result through ``run_async`` and awaits it. Any ``Exception``
    is converted into a string holding the error message and the formatted
    traceback.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
            return await run_async(outcome)
        except Exception as error:
            # Report the failure as the tool's result, traceback included.
            return f"Error: {error}\nDetails: {traceback.format_exc()}"

    return wrapper