zrb 1.4.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/llm_chat.py +8 -6
- zrb/builtin/llm/tool/api.py +1 -1
- zrb/builtin/llm/tool/file.py +471 -113
- zrb/builtin/llm/tool/rag.py +28 -10
- zrb/builtin/llm/tool/web.py +47 -15
- zrb/builtin/todo.py +37 -12
- zrb/config.py +4 -4
- zrb/llm_config.py +41 -14
- zrb/task/llm_task.py +29 -14
- {zrb-1.4.2.dist-info → zrb-1.5.0.dist-info}/METADATA +64 -41
- {zrb-1.4.2.dist-info → zrb-1.5.0.dist-info}/RECORD +13 -13
- {zrb-1.4.2.dist-info → zrb-1.5.0.dist-info}/WHEEL +0 -0
- {zrb-1.4.2.dist-info → zrb-1.5.0.dist-info}/entry_points.txt +0 -0
zrb/builtin/llm/tool/rag.py
CHANGED
@@ -9,6 +9,8 @@ import ulid
|
|
9
9
|
|
10
10
|
from zrb.config import (
|
11
11
|
RAG_CHUNK_SIZE,
|
12
|
+
RAG_EMBEDDING_API_KEY,
|
13
|
+
RAG_EMBEDDING_BASE_URL,
|
12
14
|
RAG_EMBEDDING_MODEL,
|
13
15
|
RAG_MAX_RESULT_COUNT,
|
14
16
|
RAG_OVERLAP,
|
@@ -35,24 +37,34 @@ def create_rag_from_directory(
|
|
35
37
|
tool_name: str,
|
36
38
|
tool_description: str,
|
37
39
|
document_dir_path: str = "./documents",
|
38
|
-
model: str = RAG_EMBEDDING_MODEL,
|
39
40
|
vector_db_path: str = "./chroma",
|
40
41
|
vector_db_collection: str = "documents",
|
41
42
|
chunk_size: int = RAG_CHUNK_SIZE,
|
42
43
|
overlap: int = RAG_OVERLAP,
|
43
44
|
max_result_count: int = RAG_MAX_RESULT_COUNT,
|
44
45
|
file_reader: list[RAGFileReader] = [],
|
46
|
+
openai_api_key: str = RAG_EMBEDDING_API_KEY,
|
47
|
+
openai_base_url: str = RAG_EMBEDDING_BASE_URL,
|
48
|
+
openai_embedding_model: str = RAG_EMBEDDING_MODEL,
|
45
49
|
):
|
46
50
|
async def retrieve(query: str) -> str:
|
47
51
|
from chromadb import PersistentClient
|
48
52
|
from chromadb.config import Settings
|
49
|
-
from
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
+
from openai import OpenAI
|
54
|
+
|
55
|
+
# Initialize OpenAI client with custom URL if provided
|
56
|
+
client_args = {}
|
57
|
+
if openai_api_key:
|
58
|
+
client_args["api_key"] = openai_api_key
|
59
|
+
if openai_base_url:
|
60
|
+
client_args["base_url"] = openai_base_url
|
61
|
+
# Initialize OpenAI client for embeddings
|
62
|
+
openai_client = OpenAI(**client_args)
|
63
|
+
# Initialize ChromaDB client
|
64
|
+
chroma_client = PersistentClient(
|
53
65
|
path=vector_db_path, settings=Settings(allow_reset=True)
|
54
66
|
)
|
55
|
-
collection =
|
67
|
+
collection = chroma_client.get_or_create_collection(vector_db_collection)
|
56
68
|
# Track file changes using a hash-based approach
|
57
69
|
hash_file_path = os.path.join(vector_db_path, "file_hashes.json")
|
58
70
|
previous_hashes = _load_hashes(hash_file_path)
|
@@ -89,8 +101,11 @@ def create_rag_from_directory(
|
|
89
101
|
),
|
90
102
|
file=sys.stderr,
|
91
103
|
)
|
92
|
-
|
93
|
-
|
104
|
+
# Get embeddings using OpenAI
|
105
|
+
embedding_response = openai_client.embeddings.create(
|
106
|
+
input=chunk, model=openai_embedding_model
|
107
|
+
)
|
108
|
+
vector = embedding_response.data[0].embedding
|
94
109
|
collection.upsert(
|
95
110
|
ids=[chunk_id],
|
96
111
|
embeddings=[vector],
|
@@ -113,8 +128,11 @@ def create_rag_from_directory(
|
|
113
128
|
)
|
114
129
|
# Vectorize query and get related document chunks
|
115
130
|
print(stylize_faint("Vectorizing query"), file=sys.stderr)
|
116
|
-
|
117
|
-
|
131
|
+
# Get embeddings using OpenAI
|
132
|
+
embedding_response = openai_client.embeddings.create(
|
133
|
+
input=query, model=openai_embedding_model
|
134
|
+
)
|
135
|
+
query_vector = embedding_response.data[0].embedding
|
118
136
|
print(stylize_faint("Searching documents"), file=sys.stderr)
|
119
137
|
results = collection.query(
|
120
138
|
query_embeddings=query_vector,
|
zrb/builtin/llm/tool/web.py
CHANGED
@@ -3,21 +3,53 @@ from collections.abc import Callable
|
|
3
3
|
from typing import Annotated
|
4
4
|
|
5
5
|
|
6
|
-
def open_web_page(url: str) -> str:
|
7
|
-
"""Get content from a web page."""
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
6
|
+
async def open_web_page(url: str) -> str:
|
7
|
+
"""Get content from a web page using a headless browser."""
|
8
|
+
|
9
|
+
async def get_page_content(page_url: str):
|
10
|
+
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" # noqa
|
11
|
+
try:
|
12
|
+
from playwright.async_api import async_playwright
|
13
|
+
|
14
|
+
async with async_playwright() as p:
|
15
|
+
browser = await p.chromium.launch(headless=True)
|
16
|
+
page = await browser.new_page()
|
17
|
+
await page.set_extra_http_headers({"User-Agent": user_agent})
|
18
|
+
try:
|
19
|
+
# Navigate to the URL with a timeout of 30 seconds
|
20
|
+
await page.goto(page_url, wait_until="networkidle", timeout=30000)
|
21
|
+
# Wait for the content to load
|
22
|
+
await page.wait_for_load_state("domcontentloaded")
|
23
|
+
# Get the page content
|
24
|
+
content = await page.content()
|
25
|
+
# Extract all links from the page
|
26
|
+
links = await page.eval_on_selector_all(
|
27
|
+
"a[href]",
|
28
|
+
"""
|
29
|
+
(elements) => elements.map(el => {
|
30
|
+
const href = el.getAttribute('href');
|
31
|
+
if (href && !href.startsWith('#') && !href.startsWith('/')) {
|
32
|
+
return href;
|
33
|
+
}
|
34
|
+
return null;
|
35
|
+
}).filter(href => href !== null)
|
36
|
+
""",
|
37
|
+
)
|
38
|
+
return {"content": content, "links_on_page": links}
|
39
|
+
finally:
|
40
|
+
await browser.close()
|
41
|
+
except ImportError:
|
42
|
+
import requests
|
43
|
+
|
44
|
+
response = requests.get(url, headers={"User-Agent": user_agent})
|
45
|
+
if response.status_code != 200:
|
46
|
+
msg = f"Unable to retrieve search results. Status code: {response.status_code}"
|
47
|
+
raise Exception(msg)
|
48
|
+
return {"content": response.text, "links_on_page": []}
|
49
|
+
|
50
|
+
result = await get_page_content(url)
|
51
|
+
# Parse the HTML content
|
52
|
+
return json.dumps(parse_html_text(result["content"]))
|
21
53
|
|
22
54
|
|
23
55
|
def create_search_internet_tool(serp_api_key: str) -> Callable[[str, int], str]:
|
zrb/builtin/todo.py
CHANGED
@@ -25,6 +25,18 @@ from zrb.util.todo import (
|
|
25
25
|
)
|
26
26
|
|
27
27
|
|
28
|
+
def _get_filter_input(allow_positional_parsing: bool = False) -> StrInput:
|
29
|
+
return StrInput(
|
30
|
+
name="filter",
|
31
|
+
description="Visual filter",
|
32
|
+
prompt="Visual Filter",
|
33
|
+
allow_empty=True,
|
34
|
+
allow_positional_parsing=allow_positional_parsing,
|
35
|
+
always_prompt=False,
|
36
|
+
default=TODO_VISUAL_FILTER,
|
37
|
+
)
|
38
|
+
|
39
|
+
|
28
40
|
@make_task(
|
29
41
|
name="add-todo",
|
30
42
|
input=[
|
@@ -51,6 +63,7 @@ from zrb.util.todo import (
|
|
51
63
|
prompt="Task context (space separated)",
|
52
64
|
allow_empty=True,
|
53
65
|
),
|
66
|
+
_get_filter_input(),
|
54
67
|
],
|
55
68
|
description="➕ Add todo",
|
56
69
|
group=todo_group,
|
@@ -82,16 +95,22 @@ def add_todo(ctx: AnyContext):
|
|
82
95
|
)
|
83
96
|
)
|
84
97
|
save_todo_list(todo_file_path, todo_list)
|
85
|
-
return get_visual_todo_list(todo_list,
|
98
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
86
99
|
|
87
100
|
|
88
|
-
@make_task(
|
101
|
+
@make_task(
|
102
|
+
name="list-todo",
|
103
|
+
input=_get_filter_input(allow_positional_parsing=True),
|
104
|
+
description="📋 List todo",
|
105
|
+
group=todo_group,
|
106
|
+
alias="list",
|
107
|
+
)
|
89
108
|
def list_todo(ctx: AnyContext):
|
90
109
|
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
91
110
|
todo_list: list[TodoTaskModel] = []
|
92
111
|
if os.path.isfile(todo_file_path):
|
93
112
|
todo_list = load_todo_list(todo_file_path)
|
94
|
-
return get_visual_todo_list(todo_list,
|
113
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
95
114
|
|
96
115
|
|
97
116
|
@make_task(
|
@@ -127,7 +146,10 @@ def show_todo(ctx: AnyContext):
|
|
127
146
|
|
128
147
|
@make_task(
|
129
148
|
name="complete-todo",
|
130
|
-
input=
|
149
|
+
input=[
|
150
|
+
StrInput(name="keyword", prompt="Task keyword", description="Task Keyword"),
|
151
|
+
_get_filter_input(),
|
152
|
+
],
|
131
153
|
description="✅ Complete todo",
|
132
154
|
group=todo_group,
|
133
155
|
alias="complete",
|
@@ -141,10 +163,10 @@ def complete_todo(ctx: AnyContext):
|
|
141
163
|
todo_task = select_todo_task(todo_list, ctx.input.keyword)
|
142
164
|
if todo_task is None:
|
143
165
|
ctx.log_error("Task not found")
|
144
|
-
return get_visual_todo_list(todo_list,
|
166
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
145
167
|
if todo_task.completed:
|
146
168
|
ctx.log_error("Task already completed")
|
147
|
-
return get_visual_todo_list(todo_list,
|
169
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
148
170
|
# Update todo task
|
149
171
|
todo_task = cascade_todo_task(todo_task)
|
150
172
|
if todo_task.creation_date is not None:
|
@@ -152,11 +174,12 @@ def complete_todo(ctx: AnyContext):
|
|
152
174
|
todo_task.completed = True
|
153
175
|
# Save todo list
|
154
176
|
save_todo_list(todo_file_path, todo_list)
|
155
|
-
return get_visual_todo_list(todo_list,
|
177
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
156
178
|
|
157
179
|
|
158
180
|
@make_task(
|
159
181
|
name="archive-todo",
|
182
|
+
input=_get_filter_input(),
|
160
183
|
description="📚 Archive todo",
|
161
184
|
group=todo_group,
|
162
185
|
alias="archive",
|
@@ -180,7 +203,7 @@ def archive_todo(ctx: AnyContext):
|
|
180
203
|
]
|
181
204
|
if len(new_archived_todo_list) == 0:
|
182
205
|
ctx.print("No completed task to archive")
|
183
|
-
return get_visual_todo_list(todo_list,
|
206
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
184
207
|
archive_file_path = os.path.join(TODO_DIR, "archive.txt")
|
185
208
|
if not os.path.isdir(TODO_DIR):
|
186
209
|
os.make_dirs(TODO_DIR, exist_ok=True)
|
@@ -192,7 +215,7 @@ def archive_todo(ctx: AnyContext):
|
|
192
215
|
# Save the new todo list and add the archived ones
|
193
216
|
save_todo_list(archive_file_path, archived_todo_list)
|
194
217
|
save_todo_list(todo_file_path, working_todo_list)
|
195
|
-
return get_visual_todo_list(todo_list,
|
218
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
196
219
|
|
197
220
|
|
198
221
|
@make_task(
|
@@ -216,6 +239,7 @@ def archive_todo(ctx: AnyContext):
|
|
216
239
|
description="Working stop time",
|
217
240
|
default=lambda _: _get_default_stop_work_time_str(),
|
218
241
|
),
|
242
|
+
_get_filter_input(),
|
219
243
|
],
|
220
244
|
description="🕒 Log work todo",
|
221
245
|
group=todo_group,
|
@@ -230,7 +254,7 @@ def log_todo(ctx: AnyContext):
|
|
230
254
|
todo_task = select_todo_task(todo_list, ctx.input.keyword)
|
231
255
|
if todo_task is None:
|
232
256
|
ctx.log_error("Task not found")
|
233
|
-
return get_visual_todo_list(todo_list,
|
257
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
234
258
|
# Update todo task
|
235
259
|
todo_task = cascade_todo_task(todo_task)
|
236
260
|
current_duration_str = todo_task.keyval.get("duration", "0")
|
@@ -268,7 +292,7 @@ def log_todo(ctx: AnyContext):
|
|
268
292
|
log_work_list = json.loads(read_file(log_work_path))
|
269
293
|
return "\n".join(
|
270
294
|
[
|
271
|
-
get_visual_todo_list(todo_list,
|
295
|
+
get_visual_todo_list(todo_list, filter=ctx.input.filter),
|
272
296
|
"",
|
273
297
|
get_visual_todo_card(todo_task, log_work_list),
|
274
298
|
]
|
@@ -296,6 +320,7 @@ def _get_default_stop_work_time_str() -> str:
|
|
296
320
|
default=lambda _: _get_todo_txt_content(),
|
297
321
|
allow_positional_parsing=False,
|
298
322
|
),
|
323
|
+
_get_filter_input(),
|
299
324
|
],
|
300
325
|
description="📝 Edit todo",
|
301
326
|
group=todo_group,
|
@@ -311,7 +336,7 @@ def edit_todo(ctx: AnyContext):
|
|
311
336
|
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
312
337
|
write_file(todo_file_path, new_content)
|
313
338
|
todo_list = load_todo_list(todo_file_path)
|
314
|
-
return get_visual_todo_list(todo_list,
|
339
|
+
return get_visual_todo_list(todo_list, filter=ctx.input.filter)
|
315
340
|
|
316
341
|
|
317
342
|
def _get_todo_txt_content() -> str:
|
zrb/config.py
CHANGED
@@ -85,10 +85,10 @@ LLM_HISTORY_FILE = os.getenv(
|
|
85
85
|
LLM_ALLOW_ACCESS_LOCAL_FILE = to_boolean(os.getenv("ZRB_LLM_ACCESS_LOCAL_FILE", "1"))
|
86
86
|
LLM_ALLOW_ACCESS_SHELL = to_boolean(os.getenv("ZRB_LLM_ACCESS_SHELL", "1"))
|
87
87
|
LLM_ALLOW_ACCESS_INTERNET = to_boolean(os.getenv("ZRB_LLM_ACCESS_INTERNET", "1"))
|
88
|
-
#
|
89
|
-
|
90
|
-
|
91
|
-
)
|
88
|
+
# RAG Configuration
|
89
|
+
RAG_EMBEDDING_API_KEY = os.getenv("ZRB_RAG_EMBEDDING_API_KEY", None)
|
90
|
+
RAG_EMBEDDING_BASE_URL = os.getenv("ZRB_RAG_EMBEDDING_BASE_URL", None)
|
91
|
+
RAG_EMBEDDING_MODEL = os.getenv("ZRB_RAG_EMBEDDING_MODEL", "text-embedding-ada-002")
|
92
92
|
RAG_CHUNK_SIZE = int(os.getenv("ZRB_RAG_CHUNK_SIZE", "1024"))
|
93
93
|
RAG_OVERLAP = int(os.getenv("ZRB_RAG_OVERLAP", "128"))
|
94
94
|
RAG_MAX_RESULT_COUNT = int(os.getenv("ZRB_RAG_MAX_RESULT_COUNT", "5"))
|
zrb/llm_config.py
CHANGED
@@ -2,20 +2,29 @@ import os
|
|
2
2
|
|
3
3
|
from pydantic_ai.models import Model
|
4
4
|
from pydantic_ai.models.openai import OpenAIModel
|
5
|
+
from pydantic_ai.providers import Provider
|
5
6
|
from pydantic_ai.providers.openai import OpenAIProvider
|
6
7
|
|
7
8
|
DEFAULT_SYSTEM_PROMPT = """
|
8
9
|
You have access to tools.
|
9
|
-
Your goal
|
10
|
+
Your goal is to provide insightful and accurate information based on user queries.
|
10
11
|
Follow these instructions precisely:
|
11
|
-
1. ALWAYS use available tools to gather information BEFORE asking the user questions
|
12
|
-
2. For tools that require arguments: provide arguments in valid JSON format
|
13
|
-
3. For tools
|
14
|
-
4. NEVER pass arguments to tools that don't accept parameters
|
15
|
-
5. NEVER ask users for information obtainable through tools
|
16
|
-
6. Use tools in logical sequence until you have sufficient information
|
17
|
-
7. If a tool call fails, check if you're passing arguments in the correct format
|
18
|
-
|
12
|
+
1. ALWAYS use available tools to gather information BEFORE asking the user questions.
|
13
|
+
2. For tools that require arguments: provide arguments in valid JSON format.
|
14
|
+
3. For tools with no args: call the tool without args. Do NOT pass "" or {}.
|
15
|
+
4. NEVER pass arguments to tools that don't accept parameters.
|
16
|
+
5. NEVER ask users for information obtainable through tools.
|
17
|
+
6. Use tools in a logical sequence until you have sufficient information.
|
18
|
+
7. If a tool call fails, check if you're passing arguments in the correct format.
|
19
|
+
Consider alternative strategies if the issue persists.
|
20
|
+
8. Only after exhausting relevant tools should you request clarification.
|
21
|
+
9. Understand the context of user queries to provide relevant and accurate responses.
|
22
|
+
10. Engage with users in a conversational manner once the necessary information is gathered.
|
23
|
+
11. Adapt to different query types or scenarios to improve flexibility and effectiveness.
|
24
|
+
""".strip()
|
25
|
+
|
26
|
+
DEFAULT_PERSONA = """
|
27
|
+
You are an expert in various fields including technology, science, history, and more.
|
19
28
|
""".strip()
|
20
29
|
|
21
30
|
|
@@ -26,6 +35,7 @@ class LLMConfig:
|
|
26
35
|
default_model_name: str | None = None,
|
27
36
|
default_base_url: str | None = None,
|
28
37
|
default_api_key: str | None = None,
|
38
|
+
default_persona: str | None = None,
|
29
39
|
default_system_prompt: str | None = None,
|
30
40
|
):
|
31
41
|
self._model_name = (
|
@@ -48,12 +58,20 @@ class LLMConfig:
|
|
48
58
|
if default_system_prompt is not None
|
49
59
|
else os.getenv("ZRB_LLM_SYSTEM_PROMPT", None)
|
50
60
|
)
|
61
|
+
self._persona = (
|
62
|
+
default_persona
|
63
|
+
if default_persona is not None
|
64
|
+
else os.getenv("ZRB_LLM_PERSONA", None)
|
65
|
+
)
|
66
|
+
self._default_provider = None
|
51
67
|
self._default_model = None
|
52
68
|
|
53
69
|
def _get_model_name(self) -> str | None:
|
54
70
|
return self._model_name if self._model_name is not None else None
|
55
71
|
|
56
|
-
def
|
72
|
+
def get_default_model_provider(self) -> Provider | str:
|
73
|
+
if self._default_provider is not None:
|
74
|
+
return self._default_provider
|
57
75
|
if self._model_base_url is None and self._model_api_key is None:
|
58
76
|
return "openai"
|
59
77
|
return OpenAIProvider(
|
@@ -61,9 +79,15 @@ class LLMConfig:
|
|
61
79
|
)
|
62
80
|
|
63
81
|
def get_default_system_prompt(self) -> str:
|
64
|
-
|
65
|
-
|
66
|
-
|
82
|
+
system_prompt = (
|
83
|
+
DEFAULT_SYSTEM_PROMPT
|
84
|
+
if self._system_prompt is None
|
85
|
+
else self._system_prompt
|
86
|
+
)
|
87
|
+
persona = DEFAULT_PERSONA if self._persona is None else self._persona
|
88
|
+
if persona is not None:
|
89
|
+
return f"{persona}\n{system_prompt}"
|
90
|
+
return system_prompt
|
67
91
|
|
68
92
|
def get_default_model(self) -> Model | str | None:
|
69
93
|
if self._default_model is not None:
|
@@ -73,7 +97,7 @@ class LLMConfig:
|
|
73
97
|
return None
|
74
98
|
return OpenAIModel(
|
75
99
|
model_name=model_name,
|
76
|
-
provider=self.
|
100
|
+
provider=self.get_default_model_provider(),
|
77
101
|
)
|
78
102
|
|
79
103
|
def set_default_system_prompt(self, system_prompt: str):
|
@@ -88,6 +112,9 @@ class LLMConfig:
|
|
88
112
|
def set_default_model_base_url(self, model_base_url: str):
|
89
113
|
self._model_base_url = model_base_url
|
90
114
|
|
115
|
+
def set_default_provider(self, provider: Provider | str):
|
116
|
+
self._default_provider = provider
|
117
|
+
|
91
118
|
def set_default_model(self, model: Model | str | None):
|
92
119
|
self._default_model = model
|
93
120
|
|
zrb/task/llm_task.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1
|
+
import functools
|
2
|
+
import inspect
|
1
3
|
import json
|
2
4
|
import os
|
5
|
+
import traceback
|
3
6
|
from collections.abc import Callable
|
4
7
|
from typing import Any
|
5
8
|
|
@@ -201,6 +204,9 @@ class LLMTask(BaseTask):
|
|
201
204
|
async with node.stream(agent_run.ctx) as handle_stream:
|
202
205
|
async for event in handle_stream:
|
203
206
|
if isinstance(event, FunctionToolCallEvent):
|
207
|
+
# Fixing anthrophic claude when call function with empty parameter
|
208
|
+
if event.part.args == "":
|
209
|
+
event.part.args = {}
|
204
210
|
ctx.print(
|
205
211
|
stylize_faint(
|
206
212
|
f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})" # noqa
|
@@ -240,7 +246,7 @@ class LLMTask(BaseTask):
|
|
240
246
|
)
|
241
247
|
tools_or_callables.extend(self._additional_tools)
|
242
248
|
tools = [
|
243
|
-
tool if isinstance(tool, Tool) else Tool(tool, takes_ctx=False)
|
249
|
+
tool if isinstance(tool, Tool) else Tool(_wrap_tool(tool), takes_ctx=False)
|
244
250
|
for tool in tools_or_callables
|
245
251
|
]
|
246
252
|
return Agent(
|
@@ -256,21 +262,17 @@ class LLMTask(BaseTask):
|
|
256
262
|
if model is None:
|
257
263
|
return default_llm_config.get_default_model()
|
258
264
|
if isinstance(model, str):
|
265
|
+
model_base_url = self._get_model_base_url(ctx)
|
266
|
+
model_api_key = self._get_model_api_key(ctx)
|
259
267
|
llm_config = LLMConfig(
|
260
268
|
default_model_name=model,
|
261
|
-
default_base_url=
|
262
|
-
|
263
|
-
self._get_model_base_url(ctx),
|
264
|
-
None,
|
265
|
-
auto_render=self._render_model_base_url,
|
266
|
-
),
|
267
|
-
default_api_key=get_attr(
|
268
|
-
ctx,
|
269
|
-
self._get_model_api_key(ctx),
|
270
|
-
None,
|
271
|
-
auto_render=self._render_model_api_key,
|
272
|
-
),
|
269
|
+
default_base_url=model_base_url,
|
270
|
+
default_api_key=model_api_key,
|
273
271
|
)
|
272
|
+
if model_base_url is None and model_api_key is None:
|
273
|
+
default_model_provider = default_llm_config.get_default_model_provider()
|
274
|
+
if default_model_provider is not None:
|
275
|
+
llm_config.set_default_provider(default_model_provider)
|
274
276
|
return llm_config.get_default_model()
|
275
277
|
raise ValueError(f"Invalid model: {model}")
|
276
278
|
|
@@ -288,7 +290,7 @@ class LLMTask(BaseTask):
|
|
288
290
|
)
|
289
291
|
if isinstance(api_key, str) or api_key is None:
|
290
292
|
return api_key
|
291
|
-
raise ValueError(f"Invalid model
|
293
|
+
raise ValueError(f"Invalid model API key: {api_key}")
|
292
294
|
|
293
295
|
def _get_system_prompt(self, ctx: AnyContext) -> str:
|
294
296
|
system_prompt = get_attr(
|
@@ -325,3 +327,16 @@ class LLMTask(BaseTask):
|
|
325
327
|
"",
|
326
328
|
auto_render=self._render_history_file,
|
327
329
|
)
|
330
|
+
|
331
|
+
|
332
|
+
def _wrap_tool(func):
|
333
|
+
@functools.wraps(func)
|
334
|
+
async def wrapper(*args, **kwargs):
|
335
|
+
try:
|
336
|
+
return await run_async(func(*args, **kwargs))
|
337
|
+
except Exception as e:
|
338
|
+
# Optionally, you can include more details from traceback if needed.
|
339
|
+
error_details = traceback.format_exc()
|
340
|
+
return f"Error: {e}\nDetails: {error_details}"
|
341
|
+
|
342
|
+
return wrapper
|