vectara-agentic 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/test_agent.py +18 -1
- tests/test_agent_planning.py +0 -9
- tests/test_agent_type.py +40 -0
- tests/test_tools.py +139 -41
- tests/test_vectara_llms.py +77 -0
- vectara_agentic/_prompts.py +6 -8
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +239 -78
- vectara_agentic/tools.py +209 -140
- vectara_agentic/utils.py +74 -46
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.13.dist-info}/METADATA +335 -230
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.13.dist-info}/RECORD +15 -14
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.13.dist-info}/WHEEL +1 -1
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.2.12.dist-info → vectara_agentic-0.2.13.dist-info}/top_level.txt +0 -0
vectara_agentic/tools.py
CHANGED
|
@@ -8,7 +8,7 @@ import importlib
|
|
|
8
8
|
import os
|
|
9
9
|
import asyncio
|
|
10
10
|
|
|
11
|
-
from typing import Callable, List, Dict, Any, Optional, Union, Type
|
|
11
|
+
from typing import Callable, List, Dict, Any, Optional, Union, Type, Tuple
|
|
12
12
|
from pydantic import BaseModel, Field, create_model
|
|
13
13
|
from pydantic_core import PydanticUndefined
|
|
14
14
|
|
|
@@ -22,7 +22,7 @@ from llama_index.core.workflow.context import Context
|
|
|
22
22
|
from .types import ToolType
|
|
23
23
|
from .tools_catalog import ToolsCatalog, get_bad_topics
|
|
24
24
|
from .db_tools import DatabaseTools
|
|
25
|
-
from .utils import
|
|
25
|
+
from .utils import summarize_documents, is_float
|
|
26
26
|
from .agent_config import AgentConfig
|
|
27
27
|
|
|
28
28
|
LI_packages = {
|
|
@@ -30,9 +30,11 @@ LI_packages = {
|
|
|
30
30
|
"arxiv": ToolType.QUERY,
|
|
31
31
|
"tavily_research": ToolType.QUERY,
|
|
32
32
|
"exa": ToolType.QUERY,
|
|
33
|
-
"
|
|
33
|
+
"brave_search": ToolType.QUERY,
|
|
34
|
+
"bing_search": ToolType.QUERY,
|
|
34
35
|
"neo4j": ToolType.QUERY,
|
|
35
36
|
"kuzu": ToolType.QUERY,
|
|
37
|
+
"wikipedia": ToolType.QUERY,
|
|
36
38
|
"google": {
|
|
37
39
|
"GmailToolSpec": {
|
|
38
40
|
"load_data": ToolType.QUERY,
|
|
@@ -233,10 +235,18 @@ def _create_tool_from_dynamic_function(
|
|
|
233
235
|
tool_description: str,
|
|
234
236
|
base_params_model: Type[BaseModel], # Now a Pydantic BaseModel
|
|
235
237
|
tool_args_schema: Type[BaseModel],
|
|
238
|
+
compact_docstring: bool = False,
|
|
236
239
|
) -> VectaraTool:
|
|
237
240
|
fields = {}
|
|
238
241
|
base_params = []
|
|
239
242
|
|
|
243
|
+
if tool_args_schema is None:
|
|
244
|
+
|
|
245
|
+
class EmptyBaseModel(BaseModel):
|
|
246
|
+
"""empty base model"""
|
|
247
|
+
|
|
248
|
+
tool_args_schema = EmptyBaseModel
|
|
249
|
+
|
|
240
250
|
# Create inspect.Parameter objects for base_params_model fields.
|
|
241
251
|
for param_name, model_field in base_params_model.model_fields.items():
|
|
242
252
|
field_type = base_params_model.__annotations__.get(
|
|
@@ -297,15 +307,22 @@ def _create_tool_from_dynamic_function(
|
|
|
297
307
|
function.__name__ = re.sub(r"[^A-Za-z0-9_]", "_", tool_name)
|
|
298
308
|
|
|
299
309
|
# Build a docstring using parameter descriptions from the BaseModels.
|
|
300
|
-
params_str = "
|
|
310
|
+
params_str = ", ".join(
|
|
301
311
|
f"{p.name}: {p.annotation.__name__ if hasattr(p.annotation, '__name__') else p.annotation}"
|
|
302
312
|
for p in all_params
|
|
303
313
|
)
|
|
304
|
-
signature_line = f"{tool_name}(
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
314
|
+
signature_line = f"{tool_name}({params_str}) -> dict[str, Any]"
|
|
315
|
+
if compact_docstring:
|
|
316
|
+
doc_lines = [
|
|
317
|
+
tool_description.strip(),
|
|
318
|
+
]
|
|
319
|
+
else:
|
|
320
|
+
doc_lines = [
|
|
321
|
+
signature_line,
|
|
322
|
+
"",
|
|
323
|
+
tool_description.strip(),
|
|
324
|
+
]
|
|
325
|
+
doc_lines += [
|
|
309
326
|
"",
|
|
310
327
|
"Args:",
|
|
311
328
|
]
|
|
@@ -316,25 +333,31 @@ def _create_tool_from_dynamic_function(
|
|
|
316
333
|
elif param.name in tool_args_schema.model_fields:
|
|
317
334
|
description = tool_args_schema.model_fields[param.name].description
|
|
318
335
|
if not description:
|
|
319
|
-
description = "
|
|
336
|
+
description = ""
|
|
320
337
|
type_name = (
|
|
321
338
|
param.annotation.__name__
|
|
322
339
|
if hasattr(param.annotation, "__name__")
|
|
323
340
|
else str(param.annotation)
|
|
324
341
|
)
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
342
|
+
if (
|
|
343
|
+
param.default is not inspect.Parameter.empty
|
|
344
|
+
and param.default is not PydanticUndefined
|
|
345
|
+
):
|
|
346
|
+
default_text = f", default={param.default!r}"
|
|
347
|
+
else:
|
|
348
|
+
default_text = ""
|
|
349
|
+
doc_lines.append(f" - {param.name} ({type_name}){default_text}: {description}")
|
|
331
350
|
doc_lines.append("")
|
|
332
351
|
doc_lines.append("Returns:")
|
|
333
352
|
return_desc = getattr(
|
|
334
353
|
function, "__return_description__", "A dictionary containing the result data."
|
|
335
354
|
)
|
|
336
355
|
doc_lines.append(f" dict[str, Any]: {return_desc}")
|
|
337
|
-
|
|
356
|
+
|
|
357
|
+
initial_docstring = "\n".join(doc_lines)
|
|
358
|
+
collapsed_spaces = re.sub(r' {2,}', ' ', initial_docstring)
|
|
359
|
+
final_docstring = re.sub(r'\n{2,}', '\n', collapsed_spaces).strip()
|
|
360
|
+
function.__doc__ = final_docstring
|
|
338
361
|
|
|
339
362
|
tool = VectaraTool.from_defaults(
|
|
340
363
|
fn=function,
|
|
@@ -346,6 +369,65 @@ def _create_tool_from_dynamic_function(
|
|
|
346
369
|
return tool
|
|
347
370
|
|
|
348
371
|
|
|
372
|
+
Range = Tuple[float, float, bool, bool] # (min, max, min_inclusive, max_inclusive)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _parse_range(val_str: str) -> Range:
|
|
376
|
+
"""
|
|
377
|
+
Parses '[1,10)' or '(0.5, 5]' etc.
|
|
378
|
+
Returns (start, end, start_incl, end_incl) or raises ValueError.
|
|
379
|
+
"""
|
|
380
|
+
m = re.match(
|
|
381
|
+
r"""
|
|
382
|
+
^([\[\(])\s* # opening bracket
|
|
383
|
+
([+-]?\d+(\.\d*)?)\s*, # first number
|
|
384
|
+
\s*([+-]?\d+(\.\d*)?) # second number
|
|
385
|
+
\s*([\]\)])$ # closing bracket
|
|
386
|
+
""",
|
|
387
|
+
val_str,
|
|
388
|
+
re.VERBOSE,
|
|
389
|
+
)
|
|
390
|
+
if not m:
|
|
391
|
+
raise ValueError(f"Invalid range syntax: {val_str!r}")
|
|
392
|
+
start_inc = m.group(1) == "["
|
|
393
|
+
end_inc = m.group(7) == "]"
|
|
394
|
+
start = float(m.group(2))
|
|
395
|
+
end = float(m.group(4))
|
|
396
|
+
if start > end:
|
|
397
|
+
raise ValueError(f"Range lower bound greater than upper bound: {val_str!r}")
|
|
398
|
+
return start, end, start_inc, end_inc
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _parse_comparison(val_str: str) -> Tuple[str, Union[float, str, bool]]:
|
|
402
|
+
"""
|
|
403
|
+
Parses '>10', '<=3.14', '!=foo', \"='bar'\" etc.
|
|
404
|
+
Returns (operator, rhs) or raises ValueError.
|
|
405
|
+
"""
|
|
406
|
+
# pick off the operator
|
|
407
|
+
comparison_operators = [">=", "<=", "!=", ">", "<", "="]
|
|
408
|
+
numeric_only_operators = {">", "<", ">=", "<="}
|
|
409
|
+
for op in comparison_operators:
|
|
410
|
+
if val_str.startswith(op):
|
|
411
|
+
rhs = val_str[len(op) :].strip()
|
|
412
|
+
if op in numeric_only_operators:
|
|
413
|
+
try:
|
|
414
|
+
rhs_val = float(rhs)
|
|
415
|
+
except ValueError as e:
|
|
416
|
+
raise ValueError(
|
|
417
|
+
f"Numeric comparison {op!r} must have a number, got {rhs!r}"
|
|
418
|
+
) from e
|
|
419
|
+
return op, rhs_val
|
|
420
|
+
# = and != can be bool, numeric, or string
|
|
421
|
+
low = rhs.lower()
|
|
422
|
+
if low in ("true", "false"):
|
|
423
|
+
return op, (low == "true")
|
|
424
|
+
try:
|
|
425
|
+
return op, float(rhs)
|
|
426
|
+
except ValueError:
|
|
427
|
+
return op, rhs
|
|
428
|
+
raise ValueError(f"No valid comparison operator at start of {val_str!r}")
|
|
429
|
+
|
|
430
|
+
|
|
349
431
|
def _build_filter_string(
|
|
350
432
|
kwargs: Dict[str, Any], tool_args_type: Dict[str, dict], fixed_filter: str
|
|
351
433
|
) -> str:
|
|
@@ -353,130 +435,84 @@ def _build_filter_string(
|
|
|
353
435
|
Build filter string for Vectara from kwargs
|
|
354
436
|
"""
|
|
355
437
|
filter_parts = []
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
for key, value in kwargs.items():
|
|
360
|
-
if value is None or value == "":
|
|
438
|
+
for key, raw in kwargs.items():
|
|
439
|
+
if raw is None or raw == "":
|
|
361
440
|
continue
|
|
362
441
|
|
|
363
|
-
|
|
364
|
-
# default to 'doc' if not specified
|
|
365
|
-
tool_args_dict = tool_args_type.get(key, {"type": "doc", "is_list": False})
|
|
366
|
-
prefix = tool_args_dict.get(key, "doc")
|
|
367
|
-
is_list = tool_args_dict.get("is_list", False)
|
|
368
|
-
|
|
369
|
-
if prefix not in ["doc", "part"]:
|
|
370
|
-
raise ValueError(
|
|
371
|
-
f'Unrecognized prefix {prefix}. Please make sure to use either "doc" or "part" for the prefix.'
|
|
372
|
-
)
|
|
373
|
-
|
|
374
|
-
if value is PydanticUndefined:
|
|
442
|
+
if raw is PydanticUndefined:
|
|
375
443
|
raise ValueError(
|
|
376
|
-
f"Value of argument {key} is undefined, and this is invalid."
|
|
444
|
+
f"Value of argument {key!r} is undefined, and this is invalid. "
|
|
377
445
|
"Please form proper arguments and try again."
|
|
378
446
|
)
|
|
379
447
|
|
|
380
|
-
|
|
381
|
-
|
|
448
|
+
tool_args_dict = tool_args_type.get(key, {"type": "doc", "is_list": False})
|
|
449
|
+
prefix = tool_args_dict.get("type", "doc")
|
|
450
|
+
is_list = tool_args_dict.get("is_list", False)
|
|
382
451
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
end_inclusive = val_str.endswith("]")
|
|
452
|
+
if prefix not in ("doc", "part"):
|
|
453
|
+
raise ValueError(
|
|
454
|
+
f'Unrecognized prefix {prefix!r}. Please make sure to use either "doc" or "part" for the prefix.'
|
|
455
|
+
)
|
|
388
456
|
|
|
389
|
-
|
|
390
|
-
|
|
457
|
+
# 1) native numeric
|
|
458
|
+
if isinstance(raw, (int, float)) or is_float(str(raw)):
|
|
459
|
+
val = str(raw)
|
|
460
|
+
if is_list:
|
|
461
|
+
filter_parts.append(f"({val} IN {prefix}.{key})")
|
|
462
|
+
else:
|
|
463
|
+
filter_parts.append(f"{prefix}.{key}={val}")
|
|
464
|
+
continue
|
|
391
465
|
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
466
|
+
# 2) native boolean
|
|
467
|
+
if isinstance(raw, bool):
|
|
468
|
+
val = "true" if raw else "false"
|
|
469
|
+
if is_list:
|
|
470
|
+
filter_parts.append(f"({val} IN {prefix}.{key})")
|
|
471
|
+
else:
|
|
472
|
+
filter_parts.append(f"{prefix}.{key}={val}")
|
|
473
|
+
continue
|
|
398
474
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
475
|
+
if not isinstance(raw, str):
|
|
476
|
+
raise ValueError(f"Unsupported type for {key!r}: {type(raw).__name__}")
|
|
477
|
+
|
|
478
|
+
val_str = raw.strip()
|
|
479
|
+
|
|
480
|
+
# 3) Range operator
|
|
481
|
+
if (val_str.startswith("[") or val_str.startswith("(")) and (
|
|
482
|
+
val_str.endswith("]") or val_str.endswith(")")
|
|
483
|
+
):
|
|
484
|
+
start, end, start_incl, end_incl = _parse_range(val_str)
|
|
485
|
+
conds = []
|
|
486
|
+
op1 = ">=" if start_incl else ">"
|
|
487
|
+
op2 = "<=" if end_incl else "<"
|
|
488
|
+
conds.append(f"{prefix}.{key} {op1} {start}")
|
|
489
|
+
conds.append(f"{prefix}.{key} {op2} {end}")
|
|
490
|
+
filter_parts.append("(" + " AND ".join(conds) + ")")
|
|
491
|
+
continue
|
|
409
492
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
range_conditions.append(f"{prefix}.{key} {operator} {end_val}")
|
|
418
|
-
|
|
419
|
-
# Join the range conditions with AND
|
|
420
|
-
filter_parts.append("( " + " AND ".join(range_conditions) + " )")
|
|
421
|
-
continue
|
|
422
|
-
|
|
423
|
-
raise ValueError(f"Range operator requires two values for {key}: {value}")
|
|
424
|
-
|
|
425
|
-
# Check if value contains a known comparison operator at the start
|
|
426
|
-
matched_operator = None
|
|
427
|
-
for op in comparison_operators:
|
|
428
|
-
if val_str.startswith(op):
|
|
429
|
-
matched_operator = op
|
|
430
|
-
break
|
|
431
|
-
|
|
432
|
-
# Break down operator from value
|
|
433
|
-
# e.g. val_str = ">2022" --> operator = ">", rhs = "2022"
|
|
434
|
-
if matched_operator:
|
|
435
|
-
rhs = val_str[len(matched_operator) :].strip()
|
|
436
|
-
|
|
437
|
-
if matched_operator in numeric_only_ops:
|
|
438
|
-
# Must be numeric
|
|
439
|
-
if not (rhs.isdigit() or is_float(rhs)):
|
|
440
|
-
raise ValueError(
|
|
441
|
-
f"Operator {matched_operator} requires a numeric operand for {key}: {val_str}"
|
|
442
|
-
)
|
|
443
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
|
|
493
|
+
# 4) comparison operator
|
|
494
|
+
try:
|
|
495
|
+
op, rhs = _parse_comparison(val_str)
|
|
496
|
+
except ValueError:
|
|
497
|
+
# no operator → treat as membership or equality-on-string
|
|
498
|
+
if is_list:
|
|
499
|
+
filter_parts.append(f"('{val_str}' IN {prefix}.{key})")
|
|
444
500
|
else:
|
|
445
|
-
|
|
446
|
-
if rhs.isdigit() or is_float(rhs):
|
|
447
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
|
|
448
|
-
elif rhs.lower() in ["true", "false"]:
|
|
449
|
-
filter_parts.append(
|
|
450
|
-
f"{prefix}.{key}{matched_operator}{rhs.lower()}"
|
|
451
|
-
)
|
|
452
|
-
else:
|
|
453
|
-
# For string operands, wrap them in quotes
|
|
454
|
-
filter_parts.append(f"{prefix}.{key}{matched_operator}'{rhs}'")
|
|
501
|
+
filter_parts.append(f"{prefix}.{key}='{val_str}'")
|
|
455
502
|
else:
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
elif val_str.lower() in ["true", "false"]:
|
|
462
|
-
# This is to handle boolean values.
|
|
463
|
-
# This is not complete solution - the best solution would be to test if the field is boolean
|
|
464
|
-
# That can be done after we move to APIv2
|
|
465
|
-
if is_list:
|
|
466
|
-
filter_parts.append(f"({val_str.lower()} IN {prefix}.{key})")
|
|
467
|
-
else:
|
|
468
|
-
filter_parts.append(f"{prefix}.{key}={val_str.lower()}")
|
|
503
|
+
# normal comparison always binds to the field
|
|
504
|
+
if isinstance(rhs, bool):
|
|
505
|
+
rhs_sql = "true" if rhs else "false"
|
|
506
|
+
elif isinstance(rhs, (int, float)):
|
|
507
|
+
rhs_sql = str(rhs)
|
|
469
508
|
else:
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
return f"({fixed_filter}) AND ({filter_str})"
|
|
478
|
-
else:
|
|
479
|
-
return fixed_filter or filter_str
|
|
509
|
+
rhs_sql = f"'{rhs}'"
|
|
510
|
+
filter_parts.append(f"{prefix}.{key}{op}{rhs_sql}")
|
|
511
|
+
|
|
512
|
+
joined = " AND ".join(filter_parts)
|
|
513
|
+
if fixed_filter and joined:
|
|
514
|
+
return f"({fixed_filter}) AND ({joined})"
|
|
515
|
+
return fixed_filter or joined
|
|
480
516
|
|
|
481
517
|
|
|
482
518
|
class VectaraToolFactory:
|
|
@@ -488,25 +524,29 @@ class VectaraToolFactory:
|
|
|
488
524
|
self,
|
|
489
525
|
vectara_corpus_key: str = str(os.environ.get("VECTARA_CORPUS_KEY", "")),
|
|
490
526
|
vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
|
|
527
|
+
compact_docstring: bool = False,
|
|
491
528
|
) -> None:
|
|
492
529
|
"""
|
|
493
530
|
Initialize the VectaraToolFactory
|
|
494
531
|
Args:
|
|
495
532
|
vectara_corpus_key (str): The Vectara corpus key (or comma separated list of keys).
|
|
496
533
|
vectara_api_key (str): The Vectara API key.
|
|
534
|
+
compact_docstring (bool): Whether to use a compact docstring format for tools
|
|
535
|
+
This is useful if OpenAI complains on the 1024 token limit.
|
|
497
536
|
"""
|
|
498
537
|
self.vectara_corpus_key = vectara_corpus_key
|
|
499
538
|
self.vectara_api_key = vectara_api_key
|
|
500
539
|
self.num_corpora = len(vectara_corpus_key.split(","))
|
|
501
|
-
self.
|
|
502
|
-
self.max_cache_size = 128
|
|
540
|
+
self.compact_docstring = compact_docstring
|
|
503
541
|
|
|
504
542
|
def create_search_tool(
|
|
505
543
|
self,
|
|
506
544
|
tool_name: str,
|
|
507
545
|
tool_description: str,
|
|
508
|
-
tool_args_schema: type[BaseModel],
|
|
546
|
+
tool_args_schema: type[BaseModel] = None,
|
|
509
547
|
tool_args_type: Dict[str, str] = {},
|
|
548
|
+
summarize_docs: Optional[bool] = None,
|
|
549
|
+
summarize_llm_name: Optional[str] = None,
|
|
510
550
|
fixed_filter: str = "",
|
|
511
551
|
lambda_val: Union[List[float], float] = 0.005,
|
|
512
552
|
semantics: Union[List[str] | str] = "default",
|
|
@@ -532,7 +572,7 @@ class VectaraToolFactory:
|
|
|
532
572
|
Args:
|
|
533
573
|
tool_name (str): The name of the tool.
|
|
534
574
|
tool_description (str): The description of the tool.
|
|
535
|
-
tool_args_schema (BaseModel): The schema for the tool arguments.
|
|
575
|
+
tool_args_schema (BaseModel, optional): The schema for the tool arguments.
|
|
536
576
|
tool_args_type (Dict[str, str], optional): The type of each argument (doc or part).
|
|
537
577
|
fixed_filter (str, optional): A fixed Vectara filter condition to apply to all queries.
|
|
538
578
|
lambda_val (Union[List[float] | float], optional): Lambda value (or list of values for each corpora)
|
|
@@ -584,7 +624,11 @@ class VectaraToolFactory:
|
|
|
584
624
|
|
|
585
625
|
query = kwargs.pop("query")
|
|
586
626
|
top_k = kwargs.pop("top_k", 10)
|
|
587
|
-
summarize =
|
|
627
|
+
summarize = (
|
|
628
|
+
kwargs.pop("summarize", True)
|
|
629
|
+
if summarize_docs is None
|
|
630
|
+
else summarize_docs
|
|
631
|
+
)
|
|
588
632
|
try:
|
|
589
633
|
filter_string = _build_filter_string(
|
|
590
634
|
kwargs, tool_args_type, fixed_filter
|
|
@@ -643,7 +687,10 @@ class VectaraToolFactory:
|
|
|
643
687
|
if summarize:
|
|
644
688
|
summaries_dict = asyncio.run(
|
|
645
689
|
summarize_documents(
|
|
646
|
-
self.vectara_corpus_key,
|
|
690
|
+
corpus_key=self.vectara_corpus_key,
|
|
691
|
+
api_key=self.vectara_api_key,
|
|
692
|
+
llm_name=summarize_llm_name,
|
|
693
|
+
doc_ids=list(unique_ids),
|
|
647
694
|
)
|
|
648
695
|
)
|
|
649
696
|
for doc_id, metadata in docs:
|
|
@@ -665,30 +712,47 @@ class VectaraToolFactory:
|
|
|
665
712
|
|
|
666
713
|
class SearchToolBaseParams(BaseModel):
|
|
667
714
|
"""Model for the base parameters of the search tool."""
|
|
715
|
+
|
|
668
716
|
query: str = Field(
|
|
669
717
|
...,
|
|
670
|
-
description="The search query to perform,
|
|
718
|
+
description="The search query to perform, in the form of a question.",
|
|
671
719
|
)
|
|
672
720
|
top_k: int = Field(
|
|
673
721
|
10, description="The number of top documents to retrieve."
|
|
674
722
|
)
|
|
675
723
|
summarize: bool = Field(
|
|
676
724
|
True,
|
|
677
|
-
description="
|
|
725
|
+
description="Whether to summarize the retrieved documents.",
|
|
726
|
+
)
|
|
727
|
+
|
|
728
|
+
class SearchToolBaseParamsWithoutSummarize(BaseModel):
|
|
729
|
+
"""Model for the base parameters of the search tool."""
|
|
730
|
+
|
|
731
|
+
query: str = Field(
|
|
732
|
+
...,
|
|
733
|
+
description="The search query to perform, in the form of a question.",
|
|
734
|
+
)
|
|
735
|
+
top_k: int = Field(
|
|
736
|
+
10, description="The number of top documents to retrieve."
|
|
678
737
|
)
|
|
679
738
|
|
|
680
739
|
search_tool_extra_desc = (
|
|
681
740
|
tool_description
|
|
682
741
|
+ "\n"
|
|
683
|
-
+ "
|
|
742
|
+
+ "Use this tool to search for relevant documents, not to ask questions."
|
|
684
743
|
)
|
|
685
744
|
|
|
686
745
|
tool = _create_tool_from_dynamic_function(
|
|
687
746
|
search_function,
|
|
688
747
|
tool_name,
|
|
689
748
|
search_tool_extra_desc,
|
|
690
|
-
|
|
749
|
+
(
|
|
750
|
+
SearchToolBaseParams
|
|
751
|
+
if summarize_docs is None
|
|
752
|
+
else SearchToolBaseParamsWithoutSummarize
|
|
753
|
+
),
|
|
691
754
|
tool_args_schema,
|
|
755
|
+
compact_docstring=self.compact_docstring,
|
|
692
756
|
)
|
|
693
757
|
return tool
|
|
694
758
|
|
|
@@ -696,7 +760,7 @@ class VectaraToolFactory:
|
|
|
696
760
|
self,
|
|
697
761
|
tool_name: str,
|
|
698
762
|
tool_description: str,
|
|
699
|
-
tool_args_schema: type[BaseModel],
|
|
763
|
+
tool_args_schema: type[BaseModel] = None,
|
|
700
764
|
tool_args_type: Dict[str, dict] = {},
|
|
701
765
|
fixed_filter: str = "",
|
|
702
766
|
vectara_summarizer: str = "vectara-summary-ext-24-05-med-omni",
|
|
@@ -718,6 +782,7 @@ class VectaraToolFactory:
|
|
|
718
782
|
rerank_chain: List[Dict] = None,
|
|
719
783
|
max_response_chars: Optional[int] = None,
|
|
720
784
|
max_tokens: Optional[int] = None,
|
|
785
|
+
llm_name: Optional[str] = None,
|
|
721
786
|
temperature: Optional[float] = None,
|
|
722
787
|
frequency_penalty: Optional[float] = None,
|
|
723
788
|
presence_penalty: Optional[float] = None,
|
|
@@ -734,7 +799,7 @@ class VectaraToolFactory:
|
|
|
734
799
|
Args:
|
|
735
800
|
tool_name (str): The name of the tool.
|
|
736
801
|
tool_description (str): The description of the tool.
|
|
737
|
-
tool_args_schema (BaseModel): The schema for
|
|
802
|
+
tool_args_schema (BaseModel, optional): The schema for any tool arguments for filtering.
|
|
738
803
|
tool_args_type (Dict[str, dict], optional): attributes for each argument where they key is the field name
|
|
739
804
|
and the value is a dictionary with the following keys:
|
|
740
805
|
- 'type': the type of each filter attribute in Vectara (doc or part).
|
|
@@ -765,6 +830,7 @@ class VectaraToolFactory:
|
|
|
765
830
|
If using slingshot/multilingual_reranker_v1, it must be first in the list.
|
|
766
831
|
max_response_chars (int, optional): The desired maximum number of characters for the generated summary.
|
|
767
832
|
max_tokens (int, optional): The maximum number of tokens to be returned by the LLM.
|
|
833
|
+
llm_name (str, optional): The name of the LLM to use for generation.
|
|
768
834
|
temperature (float, optional): The sampling temperature; higher values lead to more randomness.
|
|
769
835
|
frequency_penalty (float, optional): How much to penalize repeating tokens in the response,
|
|
770
836
|
higher values reducing likelihood of repeating the same line.
|
|
@@ -842,6 +908,7 @@ class VectaraToolFactory:
|
|
|
842
908
|
filter=filter_string,
|
|
843
909
|
max_response_chars=max_response_chars,
|
|
844
910
|
max_tokens=max_tokens,
|
|
911
|
+
llm_name=llm_name,
|
|
845
912
|
temperature=temperature,
|
|
846
913
|
frequency_penalty=frequency_penalty,
|
|
847
914
|
presence_penalty=presence_penalty,
|
|
@@ -920,9 +987,10 @@ class VectaraToolFactory:
|
|
|
920
987
|
|
|
921
988
|
class RagToolBaseParams(BaseModel):
|
|
922
989
|
"""Model for the base parameters of the RAG tool."""
|
|
990
|
+
|
|
923
991
|
query: str = Field(
|
|
924
992
|
...,
|
|
925
|
-
description="The search query to perform,
|
|
993
|
+
description="The search query to perform, in the form of a question",
|
|
926
994
|
)
|
|
927
995
|
|
|
928
996
|
tool = _create_tool_from_dynamic_function(
|
|
@@ -931,6 +999,7 @@ class VectaraToolFactory:
|
|
|
931
999
|
tool_description,
|
|
932
1000
|
RagToolBaseParams,
|
|
933
1001
|
tool_args_schema,
|
|
1002
|
+
compact_docstring=self.compact_docstring,
|
|
934
1003
|
)
|
|
935
1004
|
return tool
|
|
936
1005
|
|