vectara-agentic 0.1.23__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/test_agent.py +44 -1
- tests/test_tools.py +3 -3
- vectara_agentic/__init__.py +12 -19
- vectara_agentic/_observability.py +3 -3
- vectara_agentic/_prompts.py +5 -3
- vectara_agentic/_version.py +4 -0
- vectara_agentic/agent.py +126 -40
- vectara_agentic/agent_config.py +86 -0
- vectara_agentic/agent_endpoint.py +6 -7
- vectara_agentic/tools.py +275 -61
- vectara_agentic/tools_catalog.py +8 -1
- vectara_agentic/types.py +10 -0
- vectara_agentic/utils.py +50 -34
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/METADATA +122 -38
- vectara_agentic-0.1.25.dist-info/RECORD +21 -0
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/WHEEL +1 -1
- vectara_agentic-0.1.23.dist-info/RECORD +0 -19
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/LICENSE +0 -0
- {vectara_agentic-0.1.23.dist-info → vectara_agentic-0.1.25.dist-info}/top_level.txt +0 -0
vectara_agentic/tools.py
CHANGED
|
@@ -9,6 +9,7 @@ import os
|
|
|
9
9
|
|
|
10
10
|
from typing import Callable, List, Dict, Any, Optional, Type
|
|
11
11
|
from pydantic import BaseModel, Field
|
|
12
|
+
from pydantic_core import PydanticUndefined
|
|
12
13
|
|
|
13
14
|
from llama_index.core.tools import FunctionTool
|
|
14
15
|
from llama_index.core.tools.function_tool import AsyncCallable
|
|
@@ -125,6 +126,122 @@ class VectaraTool(FunctionTool):
|
|
|
125
126
|
break
|
|
126
127
|
return is_equal
|
|
127
128
|
|
|
129
|
+
def _build_filter_string(kwargs: Dict[str, Any], tool_args_type: Dict[str, str], fixed_filter: str) -> str:
|
|
130
|
+
"""
|
|
131
|
+
Build filter string for Vectara from kwargs
|
|
132
|
+
"""
|
|
133
|
+
filter_parts = []
|
|
134
|
+
comparison_operators = [">=", "<=", "!=", ">", "<", "="]
|
|
135
|
+
numeric_only_ops = {">", "<", ">=", "<="}
|
|
136
|
+
|
|
137
|
+
for key, value in kwargs.items():
|
|
138
|
+
if value is None or value == "":
|
|
139
|
+
continue
|
|
140
|
+
|
|
141
|
+
# Determine the prefix for the key. Valid values are "doc" or "part"
|
|
142
|
+
# default to 'doc' if not specified
|
|
143
|
+
prefix = tool_args_type.get(key, "doc")
|
|
144
|
+
|
|
145
|
+
if prefix not in ["doc", "part"]:
|
|
146
|
+
raise ValueError(
|
|
147
|
+
f'Unrecognized prefix {prefix}. Please make sure to use either "doc" or "part" for the prefix.'
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
if value is PydanticUndefined:
|
|
151
|
+
raise ValueError(
|
|
152
|
+
f"Value of argument {key} is undefined, and this is invalid. "
|
|
153
|
+
"Please form proper arguments and try again."
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
# value of the argument
|
|
157
|
+
val_str = str(value).strip()
|
|
158
|
+
|
|
159
|
+
# Special handling for range operator
|
|
160
|
+
if val_str.startswith(("[", "(")) and val_str.endswith(("]", ")")):
|
|
161
|
+
# Extract the boundary types
|
|
162
|
+
start_inclusive = val_str.startswith("[")
|
|
163
|
+
end_inclusive = val_str.endswith("]")
|
|
164
|
+
|
|
165
|
+
# Remove the boundaries and strip whitespace
|
|
166
|
+
val_str = val_str[1:-1].strip()
|
|
167
|
+
|
|
168
|
+
if "," in val_str:
|
|
169
|
+
val_str = val_str.split(",")
|
|
170
|
+
if len(val_str) != 2:
|
|
171
|
+
raise ValueError(
|
|
172
|
+
f"Range operator requires two values for {key}: {value}"
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
# Validate both bounds as numeric or empty (for unbounded ranges)
|
|
176
|
+
start_val, end_val = val_str[0].strip(), val_str[1].strip()
|
|
177
|
+
if start_val and not (start_val.isdigit() or is_float(start_val)):
|
|
178
|
+
raise ValueError(
|
|
179
|
+
f"Range operator requires numeric operands for {key}: {value}"
|
|
180
|
+
)
|
|
181
|
+
if end_val and not (end_val.isdigit() or is_float(end_val)):
|
|
182
|
+
raise ValueError(
|
|
183
|
+
f"Range operator requires numeric operands for {key}: {value}"
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
# Build the SQL condition
|
|
187
|
+
range_conditions = []
|
|
188
|
+
if start_val:
|
|
189
|
+
operator = ">=" if start_inclusive else ">"
|
|
190
|
+
range_conditions.append(f"{prefix}.{key} {operator} {start_val}")
|
|
191
|
+
if end_val:
|
|
192
|
+
operator = "<=" if end_inclusive else "<"
|
|
193
|
+
range_conditions.append(f"{prefix}.{key} {operator} {end_val}")
|
|
194
|
+
|
|
195
|
+
# Join the range conditions with AND
|
|
196
|
+
filter_parts.append('( ' + " AND ".join(range_conditions) + ' )')
|
|
197
|
+
continue
|
|
198
|
+
|
|
199
|
+
raise ValueError(
|
|
200
|
+
f"Range operator requires two values for {key}: {value}"
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
# Check if value contains a known comparison operator at the start
|
|
204
|
+
matched_operator = None
|
|
205
|
+
for op in comparison_operators:
|
|
206
|
+
if val_str.startswith(op):
|
|
207
|
+
matched_operator = op
|
|
208
|
+
break
|
|
209
|
+
|
|
210
|
+
# Break down operator from value
|
|
211
|
+
# e.g. val_str = ">2022" --> operator = ">", rhs = "2022"
|
|
212
|
+
if matched_operator:
|
|
213
|
+
rhs = val_str[len(matched_operator):].strip()
|
|
214
|
+
|
|
215
|
+
if matched_operator in numeric_only_ops:
|
|
216
|
+
# Must be numeric
|
|
217
|
+
if not (rhs.isdigit() or is_float(rhs)):
|
|
218
|
+
raise ValueError(
|
|
219
|
+
f"Operator {matched_operator} requires a numeric operand for {key}: {val_str}"
|
|
220
|
+
)
|
|
221
|
+
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
|
|
222
|
+
else:
|
|
223
|
+
# = and != operators can be numeric or string
|
|
224
|
+
if rhs.isdigit() or is_float(rhs):
|
|
225
|
+
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
|
|
226
|
+
elif rhs.lower() in ["true", "false"]:
|
|
227
|
+
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs.lower()}")
|
|
228
|
+
else:
|
|
229
|
+
# For string operands, wrap them in quotes
|
|
230
|
+
filter_parts.append(f"{prefix}.{key}{matched_operator}'{rhs}'")
|
|
231
|
+
else:
|
|
232
|
+
if val_str.isdigit() or is_float(val_str):
|
|
233
|
+
filter_parts.append(f"{prefix}.{key}={val_str}")
|
|
234
|
+
elif val_str.lower() in ["true", "false"]:
|
|
235
|
+
# This is to handle boolean values.
|
|
236
|
+
# This is not complete solution - the best solution would be to test if the field is boolean
|
|
237
|
+
# That can be done after we move to APIv2
|
|
238
|
+
filter_parts.append(f"{prefix}.{key}={val_str.lower()}")
|
|
239
|
+
else:
|
|
240
|
+
filter_parts.append(f"{prefix}.{key}='{val_str}'")
|
|
241
|
+
|
|
242
|
+
filter_str = " AND ".join(filter_parts)
|
|
243
|
+
return f"({fixed_filter}) AND ({filter_str})" if fixed_filter else filter_str
|
|
244
|
+
|
|
128
245
|
class VectaraToolFactory:
|
|
129
246
|
"""
|
|
130
247
|
A factory class for creating Vectara RAG tools.
|
|
@@ -148,13 +265,159 @@ class VectaraToolFactory:
|
|
|
148
265
|
self.vectara_api_key = vectara_api_key
|
|
149
266
|
self.num_corpora = len(vectara_corpus_id.split(","))
|
|
150
267
|
|
|
268
|
+
def create_search_tool(
    self,
    tool_name: str,
    tool_description: str,
    tool_args_schema: type[BaseModel],
    tool_args_type: Optional[Dict[str, str]] = None,
    fixed_filter: str = "",
    lambda_val: float = 0.005,
    reranker: str = "mmr",
    rerank_k: int = 50,
    mmr_diversity_bias: float = 0.2,
    udf_expression: Optional[str] = None,
    rerank_chain: Optional[List[Dict]] = None,
    verbose: bool = False,
) -> VectaraTool:
    """
    Creates a Vectara search/retrieval tool

    The returned tool wraps a dynamically generated function whose signature
    is derived from ``tool_args_schema``. The function pops ``query`` (and
    ``top_k``, defaulting to 10) from its arguments, turns the remaining
    arguments into a filter string, runs a retrieval with summarization
    disabled, and lists the id and metadata of each distinct document.

    Args:
        tool_name (str): The name of the tool.
        tool_description (str): The description of the tool.
        tool_args_schema (BaseModel): The schema for the tool arguments.
        tool_args_type (Dict[str, str], optional): The type of each argument (doc or part).
            ``None`` (the default) is treated as an empty mapping.
        fixed_filter (str, optional): A fixed Vectara filter condition to apply to all queries.
        lambda_val (float, optional): Lambda value for the Vectara query.
        reranker (str, optional): The reranker mode.
        rerank_k (int, optional): Number of top-k documents for reranking.
        mmr_diversity_bias (float, optional): MMR diversity bias.
        udf_expression (str, optional): the user defined expression for reranking results.
        rerank_chain (List[Dict], optional): A list of rerankers to be applied sequentially.
            Each dictionary should specify the "type" of reranker (mmr, slingshot, udf)
            and any other parameters (e.g. "limit" or "cutoff" for any type,
            "diversity_bias" for mmr, and "user_function" for udf).
            If using slingshot/multilingual_reranker_v1, it must be first in the list.
        verbose (bool, optional): Whether to print verbose output.

    Returns:
        VectaraTool: A VectaraTool object.
    """
    # A None default replaces the former `{}` default, which was a single
    # dict object shared by every call that omitted the argument.
    tool_args_type = {} if tool_args_type is None else tool_args_type

    vectara = VectaraIndex(
        vectara_api_key=self.vectara_api_key,
        vectara_customer_id=self.vectara_customer_id,
        vectara_corpus_id=self.vectara_corpus_id,
        x_source_str="vectara-agentic",
    )

    # Dynamically generate the search function
    def search_function(*args, **kwargs) -> ToolOutput:
        """
        Dynamically generated function for semantic search Vectara.
        """
        # Convert args to kwargs using the (overridden) function signature
        bound_args = inspect.signature(search_function).bind_partial(*args, **kwargs)
        bound_args.apply_defaults()
        kwargs = bound_args.arguments

        query = kwargs.pop("query")
        top_k = kwargs.pop("top_k", 10)
        try:
            filter_string = _build_filter_string(kwargs, tool_args_type, fixed_filter)
        except ValueError as e:
            # Report invalid filter arguments as tool output instead of raising.
            return ToolOutput(
                tool_name=search_function.__name__,
                content=str(e),
                raw_input={"args": args, "kwargs": kwargs},
                raw_output={"response": str(e)},
            )

        vectara_retriever = vectara.as_retriever(
            summary_enabled=False,
            similarity_top_k=top_k,
            reranker=reranker,
            # Keep rerank_k * num_corpora within 100.
            rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora),
            mmr_diversity_bias=mmr_diversity_bias,
            udf_expression=udf_expression,
            rerank_chain=rerank_chain,
            lambda_val=lambda_val,
            filter=filter_string,
            x_source_str="vectara-agentic",
            verbose=verbose,
        )
        response = vectara_retriever.retrieve(query)

        if len(response) == 0:
            msg = "Vectara Tool failed to retreive any results for the query."
            return ToolOutput(
                tool_name=search_function.__name__,
                content=msg,
                raw_input={"args": args, "kwargs": kwargs},
                raw_output={"response": msg},
            )

        # One line per distinct document id, in retrieval order.
        output_lines = ["Matching documents:\n"]
        seen_ids = set()
        for doc in response:
            if doc.id_ in seen_ids:
                continue
            seen_ids.add(doc.id_)
            output_lines.append(f"document '{doc.id_}' metadata: {doc.metadata}\n")

        return ToolOutput(
            tool_name=search_function.__name__,
            content="".join(output_lines),
            raw_input={"args": args, "kwargs": kwargs},
            raw_output=response,
        )

    schema_fields = tool_args_schema.model_fields
    params = [
        inspect.Parameter(
            name=field_name,
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=field_info.default,
            annotation=field_info,
        )
        for field_name, field_info in schema_fields.items()
    ]

    # Create a new signature using the extracted parameters
    search_function.__signature__ = inspect.Signature(params)
    search_function.__annotations__["return"] = dict[str, Any]
    search_function.__name__ = "_" + re.sub(r"[^A-Za-z0-9_]", "_", tool_name)

    # Create the tool function signature string shown in the description.
    # `model_fields` is reused here instead of the deprecated `__fields__`.
    arg_descriptions = []
    for name, field in schema_fields.items():
        annotation = field.annotation
        type_name = annotation.__name__ if hasattr(annotation, '__name__') else str(annotation)
        arg_descriptions.append(f"{name}: {type_name}")
    args_str = ", ".join(arg_descriptions)
    function_str = f"{tool_name}({args_str}) -> str"

    # Create the tool
    tool = VectaraTool.from_defaults(
        fn=search_function,
        name=tool_name,
        description=function_str + ". " + tool_description,
        fn_schema=tool_args_schema,
        tool_type=ToolType.QUERY,
    )
    return tool
|
|
411
|
+
|
|
151
412
|
def create_rag_tool(
|
|
152
413
|
self,
|
|
153
414
|
tool_name: str,
|
|
154
415
|
tool_description: str,
|
|
155
416
|
tool_args_schema: type[BaseModel],
|
|
156
417
|
tool_args_type: Dict[str, str] = {},
|
|
418
|
+
fixed_filter: str = "",
|
|
157
419
|
vectara_summarizer: str = "vectara-summary-ext-24-05-sml",
|
|
420
|
+
vectara_prompt_text: str = None,
|
|
158
421
|
summary_num_results: int = 5,
|
|
159
422
|
summary_response_lang: str = "eng",
|
|
160
423
|
n_sentences_before: int = 2,
|
|
@@ -177,7 +440,9 @@ class VectaraToolFactory:
|
|
|
177
440
|
tool_description (str): The description of the tool.
|
|
178
441
|
tool_args_schema (BaseModel): The schema for the tool arguments.
|
|
179
442
|
tool_args_type (Dict[str, str], optional): The type of each argument (doc or part).
|
|
443
|
+
fixed_filter (str, optional): A fixed Vectara filter condition to apply to all queries.
|
|
180
444
|
vectara_summarizer (str, optional): The Vectara summarizer to use.
|
|
445
|
+
vectara_prompt_text (str, optional): The prompt text for the Vectara summarizer.
|
|
181
446
|
summary_num_results (int, optional): The number of summary results.
|
|
182
447
|
summary_response_lang (str, optional): The response language for the summary.
|
|
183
448
|
n_sentences_before (int, optional): Number of sentences before the summary.
|
|
@@ -209,66 +474,6 @@ class VectaraToolFactory:
|
|
|
209
474
|
x_source_str="vectara-agentic",
|
|
210
475
|
)
|
|
211
476
|
|
|
212
|
-
def _build_filter_string(kwargs: Dict[str, Any], tool_args_type: Dict[str, str]) -> str:
|
|
213
|
-
filter_parts = []
|
|
214
|
-
comparison_operators = [">=", "<=", "!=", ">", "<", "="]
|
|
215
|
-
numeric_only_ops = {">", "<", ">=", "<="}
|
|
216
|
-
|
|
217
|
-
for key, value in kwargs.items():
|
|
218
|
-
if value is None or value == "":
|
|
219
|
-
continue
|
|
220
|
-
|
|
221
|
-
# Determine the prefix for the key. Valid values are "doc" or "part"
|
|
222
|
-
# default to 'doc' if not specified
|
|
223
|
-
prefix = tool_args_type.get(key, "doc")
|
|
224
|
-
|
|
225
|
-
if prefix not in ["doc", "part"]:
|
|
226
|
-
raise ValueError(
|
|
227
|
-
f'Unrecognized prefix {prefix}. Please make sure to use either "doc" or "part" for the prefix.'
|
|
228
|
-
)
|
|
229
|
-
|
|
230
|
-
# Check if value contains a known comparison operator at the start
|
|
231
|
-
val_str = str(value).strip()
|
|
232
|
-
matched_operator = None
|
|
233
|
-
for op in comparison_operators:
|
|
234
|
-
if val_str.startswith(op):
|
|
235
|
-
matched_operator = op
|
|
236
|
-
break
|
|
237
|
-
|
|
238
|
-
# Break down operator from value
|
|
239
|
-
# e.g. val_str = ">2022" --> operator = ">", rhs = "2022"
|
|
240
|
-
if matched_operator:
|
|
241
|
-
rhs = val_str[len(matched_operator):].strip()
|
|
242
|
-
|
|
243
|
-
if matched_operator in numeric_only_ops:
|
|
244
|
-
# Must be numeric
|
|
245
|
-
if not (rhs.isdigit() or is_float(rhs)):
|
|
246
|
-
raise ValueError(
|
|
247
|
-
f"Operator {matched_operator} requires a numeric operand for {key}: {val_str}"
|
|
248
|
-
)
|
|
249
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
|
|
250
|
-
else:
|
|
251
|
-
# = and != operators can be numeric or string
|
|
252
|
-
if rhs.isdigit() or is_float(rhs):
|
|
253
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
|
|
254
|
-
elif rhs.lower() in ["true", "false"]:
|
|
255
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs.lower()}")
|
|
256
|
-
else:
|
|
257
|
-
# For string operands, wrap them in quotes
|
|
258
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}'{rhs}'")
|
|
259
|
-
else:
|
|
260
|
-
if val_str.isdigit() or is_float(val_str):
|
|
261
|
-
filter_parts.append(f"{prefix}.{key}={val_str}")
|
|
262
|
-
elif val_str.lower() in ["true", "false"]:
|
|
263
|
-
# This is to handle boolean values.
|
|
264
|
-
# This is not complete solution - the best solution would be to test if the field is boolean
|
|
265
|
-
# That can be done after we move to APIv2
|
|
266
|
-
filter_parts.append(f"{prefix}.{key}={val_str.lower()}")
|
|
267
|
-
else:
|
|
268
|
-
filter_parts.append(f"{prefix}.{key}='{val_str}'")
|
|
269
|
-
|
|
270
|
-
return " AND ".join(filter_parts)
|
|
271
|
-
|
|
272
477
|
# Dynamically generate the RAG function
|
|
273
478
|
def rag_function(*args, **kwargs) -> ToolOutput:
|
|
274
479
|
"""
|
|
@@ -281,13 +486,22 @@ class VectaraToolFactory:
|
|
|
281
486
|
kwargs = bound_args.arguments
|
|
282
487
|
|
|
283
488
|
query = kwargs.pop("query")
|
|
284
|
-
|
|
489
|
+
try:
|
|
490
|
+
filter_string = _build_filter_string(kwargs, tool_args_type, fixed_filter)
|
|
491
|
+
except ValueError as e:
|
|
492
|
+
return ToolOutput(
|
|
493
|
+
tool_name=rag_function.__name__,
|
|
494
|
+
content=str(e),
|
|
495
|
+
raw_input={"args": args, "kwargs": kwargs},
|
|
496
|
+
raw_output={"response": str(e)},
|
|
497
|
+
)
|
|
285
498
|
|
|
286
499
|
vectara_query_engine = vectara.as_query_engine(
|
|
287
500
|
summary_enabled=True,
|
|
288
501
|
summary_num_results=summary_num_results,
|
|
289
502
|
summary_response_lang=summary_response_lang,
|
|
290
503
|
summary_prompt_name=vectara_summarizer,
|
|
504
|
+
prompt_text=vectara_prompt_text,
|
|
291
505
|
reranker=reranker,
|
|
292
506
|
rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora),
|
|
293
507
|
mmr_diversity_bias=mmr_diversity_bias,
|
vectara_agentic/tools_catalog.py
CHANGED
|
@@ -3,9 +3,11 @@ This module contains the tools catalog for the Vectara Agentic.
|
|
|
3
3
|
"""
|
|
4
4
|
from typing import List
|
|
5
5
|
from functools import lru_cache
|
|
6
|
-
from
|
|
6
|
+
from datetime import date
|
|
7
7
|
import requests
|
|
8
8
|
|
|
9
|
+
from pydantic import Field
|
|
10
|
+
|
|
9
11
|
from .types import LLMRole
|
|
10
12
|
from .utils import get_llm
|
|
11
13
|
|
|
@@ -19,6 +21,11 @@ get_headers = {
|
|
|
19
21
|
"Connection": "keep-alive",
|
|
20
22
|
}
|
|
21
23
|
|
|
24
|
+
def get_current_date() -> str:
    """
    Return today's date as a human-readable string.

    The value is ``date.today()`` formatted as
    "<weekday name>, <month name> <zero-padded day>, <year>",
    e.g. "Monday, January 06, 2025".
    """
    today = date.today()
    return f"{today:%A, %B %d, %Y}"
|
|
22
29
|
|
|
23
30
|
#
|
|
24
31
|
# Standard Tools
|
vectara_agentic/types.py
CHANGED
|
@@ -3,6 +3,9 @@ This module contains the types used in the Vectara Agentic.
|
|
|
3
3
|
"""
|
|
4
4
|
from enum import Enum
|
|
5
5
|
|
|
6
|
+
from llama_index.core.tools.types import ToolOutput as LI_ToolOutput
|
|
7
|
+
from llama_index.core.chat_engine.types import AgentChatResponse as LI_AgentChatResponse
|
|
8
|
+
from llama_index.core.chat_engine.types import StreamingAgentChatResponse as LI_StreamingAgentChatResponse
|
|
6
9
|
|
|
7
10
|
class AgentType(Enum):
|
|
8
11
|
"""Enumeration for different types of agents."""
|
|
@@ -29,6 +32,7 @@ class ModelProvider(Enum):
|
|
|
29
32
|
FIREWORKS = "FIREWORKS"
|
|
30
33
|
COHERE = "COHERE"
|
|
31
34
|
GEMINI = "GEMINI"
|
|
35
|
+
BEDROCK = "BEDROCK"
|
|
32
36
|
|
|
33
37
|
|
|
34
38
|
class AgentStatusType(Enum):
|
|
@@ -51,3 +55,9 @@ class ToolType(Enum):
|
|
|
51
55
|
"""Enumeration for different types of tools."""
|
|
52
56
|
QUERY = "query"
|
|
53
57
|
ACTION = "action"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# classes for Agent responses
|
|
61
|
+
ToolOutput = LI_ToolOutput
|
|
62
|
+
AgentResponse = LI_AgentChatResponse
|
|
63
|
+
AgentStreamingResponse = LI_StreamingAgentChatResponse
|
vectara_agentic/utils.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
Utilities for the Vectara agentic.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
from typing import Tuple, Callable, Optional
|
|
7
6
|
from functools import lru_cache
|
|
8
7
|
|
|
@@ -13,6 +12,7 @@ from llama_index.llms.openai import OpenAI
|
|
|
13
12
|
from llama_index.llms.anthropic import Anthropic
|
|
14
13
|
|
|
15
14
|
from .types import LLMRole, AgentType, ModelProvider
|
|
15
|
+
from .agent_config import AgentConfig
|
|
16
16
|
|
|
17
17
|
provider_to_default_model_name = {
|
|
18
18
|
ModelProvider.OPENAI: "gpt-4o",
|
|
@@ -20,6 +20,7 @@ provider_to_default_model_name = {
|
|
|
20
20
|
ModelProvider.TOGETHER: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
|
21
21
|
ModelProvider.GROQ: "llama-3.3-70b-versatile",
|
|
22
22
|
ModelProvider.FIREWORKS: "accounts/fireworks/models/firefunction-v2",
|
|
23
|
+
ModelProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
|
23
24
|
ModelProvider.COHERE: "command-r-plus",
|
|
24
25
|
ModelProvider.GEMINI: "models/gemini-1.5-flash",
|
|
25
26
|
}
|
|
@@ -27,44 +28,52 @@ provider_to_default_model_name = {
|
|
|
27
28
|
DEFAULT_MODEL_PROVIDER = ModelProvider.OPENAI
|
|
28
29
|
|
|
29
30
|
@lru_cache(maxsize=None)
def _get_llm_params_for_role(
    role: LLMRole,
    config: Optional[AgentConfig] = None
) -> Tuple[ModelProvider, str]:
    """
    Resolve the model provider and model name to use for the given role.

    For ``LLMRole.TOOL`` the tool LLM settings of ``config`` are used; for any
    other role the main LLM settings are used. When the configured model name
    is empty, the provider's entry in ``provider_to_default_model_name`` is
    used instead. If ``config`` is None (or falsy), a default ``AgentConfig()``
    is constructed.

    Results are memoized with ``lru_cache``, so ``role`` and ``config`` must
    be hashable.

    Raises:
        ValueError: If the role is MAIN, the agent type is OPENAI and the main
            provider is not OpenAI.
    """
    if not config:
        config = AgentConfig()  # fallback to default config

    if role == LLMRole.TOOL:
        model_provider = config.tool_llm_provider
        configured_name = config.tool_llm_model_name
    else:
        model_provider = config.main_llm_provider
        configured_name = config.main_llm_model_name

    # If the user hasn't explicitly set a model name, fall back to the
    # provider default from provider_to_default_model_name.
    model_name = configured_name or provider_to_default_model_name.get(model_provider)

    # An OpenAI agent type requires the main LLM provider to be OpenAI as well.
    openai_agent_mismatch = (
        role == LLMRole.MAIN
        and config.agent_type == AgentType.OPENAI
        and model_provider != ModelProvider.OPENAI
    )
    if openai_agent_mismatch:
        raise ValueError(
            "OpenAI agent requested but main model provider is not OpenAI."
        )

    return model_provider, model_name
|
|
62
65
|
|
|
63
66
|
@lru_cache(maxsize=None)
|
|
64
|
-
def get_tokenizer_for_model(
|
|
65
|
-
|
|
66
|
-
|
|
67
|
+
def get_tokenizer_for_model(
|
|
68
|
+
role: LLMRole,
|
|
69
|
+
config: Optional[AgentConfig] = None
|
|
70
|
+
) -> Optional[Callable]:
|
|
71
|
+
"""
|
|
72
|
+
Get the tokenizer for the specified model, as determined by the role & config.
|
|
73
|
+
"""
|
|
74
|
+
model_provider, model_name = _get_llm_params_for_role(role, config)
|
|
67
75
|
if model_provider == ModelProvider.OPENAI:
|
|
76
|
+
# This might raise an exception if the model_name is unknown to tiktoken
|
|
68
77
|
return tiktoken.encoding_for_model(model_name).encode
|
|
69
78
|
if model_provider == ModelProvider.ANTHROPIC:
|
|
70
79
|
return Anthropic().tokenizer
|
|
@@ -72,10 +81,15 @@ def get_tokenizer_for_model(role: LLMRole) -> Optional[Callable]:
|
|
|
72
81
|
|
|
73
82
|
|
|
74
83
|
@lru_cache(maxsize=None)
|
|
75
|
-
def get_llm(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
84
|
+
def get_llm(
|
|
85
|
+
role: LLMRole,
|
|
86
|
+
config: Optional[AgentConfig] = None
|
|
87
|
+
) -> LLM:
|
|
88
|
+
"""
|
|
89
|
+
Get the LLM for the specified role, using the provided config
|
|
90
|
+
or a default if none is provided.
|
|
91
|
+
"""
|
|
92
|
+
model_provider, model_name = _get_llm_params_for_role(role, config)
|
|
79
93
|
if model_provider == ModelProvider.OPENAI:
|
|
80
94
|
llm = OpenAI(model=model_name, temperature=0, is_function_calling_model=True)
|
|
81
95
|
elif model_provider == ModelProvider.ANTHROPIC:
|
|
@@ -92,12 +106,14 @@ def get_llm(role: LLMRole) -> LLM:
|
|
|
92
106
|
elif model_provider == ModelProvider.FIREWORKS:
|
|
93
107
|
from llama_index.llms.fireworks import Fireworks
|
|
94
108
|
llm = Fireworks(model=model_name, temperature=0)
|
|
109
|
+
elif model_provider == ModelProvider.BEDROCK:
|
|
110
|
+
from llama_index.llms.bedrock import Bedrock
|
|
111
|
+
llm = Bedrock(model=model_name, temperature=0)
|
|
95
112
|
elif model_provider == ModelProvider.COHERE:
|
|
96
113
|
from llama_index.llms.cohere import Cohere
|
|
97
114
|
llm = Cohere(model=model_name, temperature=0)
|
|
98
115
|
else:
|
|
99
116
|
raise ValueError(f"Unknown LLM provider: {model_provider}")
|
|
100
|
-
|
|
101
117
|
return llm
|
|
102
118
|
|
|
103
119
|
def is_float(value: str) -> bool:
|