vectara-agentic 0.2.11-py3-none-any.whl → 0.2.13-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic; see the registry's page for this release for more details.

vectara_agentic/tools.py CHANGED
@@ -8,7 +8,7 @@ import importlib
8
8
  import os
9
9
  import asyncio
10
10
 
11
- from typing import Callable, List, Dict, Any, Optional, Union, Type
11
+ from typing import Callable, List, Dict, Any, Optional, Union, Type, Tuple
12
12
  from pydantic import BaseModel, Field, create_model
13
13
  from pydantic_core import PydanticUndefined
14
14
 
@@ -22,7 +22,7 @@ from llama_index.core.workflow.context import Context
22
22
  from .types import ToolType
23
23
  from .tools_catalog import ToolsCatalog, get_bad_topics
24
24
  from .db_tools import DatabaseTools
25
- from .utils import is_float, summarize_documents
25
+ from .utils import summarize_documents, is_float
26
26
  from .agent_config import AgentConfig
27
27
 
28
28
  LI_packages = {
@@ -30,9 +30,11 @@ LI_packages = {
30
30
  "arxiv": ToolType.QUERY,
31
31
  "tavily_research": ToolType.QUERY,
32
32
  "exa": ToolType.QUERY,
33
- "brave": ToolType.QUERY,
33
+ "brave_search": ToolType.QUERY,
34
+ "bing_search": ToolType.QUERY,
34
35
  "neo4j": ToolType.QUERY,
35
36
  "kuzu": ToolType.QUERY,
37
+ "wikipedia": ToolType.QUERY,
36
38
  "google": {
37
39
  "GmailToolSpec": {
38
40
  "load_data": ToolType.QUERY,
@@ -173,6 +175,21 @@ class VectaraTool(FunctionTool):
173
175
  ) -> ToolOutput:
174
176
  try:
175
177
  return super().call(*args, ctx=ctx, **kwargs)
178
+ except TypeError as e:
179
+ sig = inspect.signature(self.metadata.fn_schema)
180
+ valid_parameters = list(sig.parameters.keys())
181
+ params_str = ", ".join(valid_parameters)
182
+
183
+ err_output = ToolOutput(
184
+ tool_name=self.metadata.name,
185
+ content=(
186
+ f"Wrong argument used when calling {self.metadata.name}: {str(e)}. "
187
+ f"Valid arguments: {params_str}. please call the tool again with the correct arguments."
188
+ ),
189
+ raw_input={"args": args, "kwargs": kwargs},
190
+ raw_output={"response": str(e)},
191
+ )
192
+ return err_output
176
193
  except Exception as e:
177
194
  err_output = ToolOutput(
178
195
  tool_name=self.metadata.name,
@@ -187,6 +204,21 @@ class VectaraTool(FunctionTool):
187
204
  ) -> ToolOutput:
188
205
  try:
189
206
  return await super().acall(*args, ctx=ctx, **kwargs)
207
+ except TypeError as e:
208
+ sig = inspect.signature(self.metadata.fn_schema)
209
+ valid_parameters = list(sig.parameters.keys())
210
+ params_str = ", ".join(valid_parameters)
211
+
212
+ err_output = ToolOutput(
213
+ tool_name=self.metadata.name,
214
+ content=(
215
+ f"Wrong argument used when calling {self.metadata.name}: {str(e)}. "
216
+ f"Valid arguments: {params_str}. please call the tool again with the correct arguments."
217
+ ),
218
+ raw_input={"args": args, "kwargs": kwargs},
219
+ raw_output={"response": str(e)},
220
+ )
221
+ return err_output
190
222
  except Exception as e:
191
223
  err_output = ToolOutput(
192
224
  tool_name=self.metadata.name,
@@ -201,80 +233,201 @@ def _create_tool_from_dynamic_function(
201
233
  function: Callable[..., ToolOutput],
202
234
  tool_name: str,
203
235
  tool_description: str,
204
- base_params: list[inspect.Parameter],
205
- tool_args_schema: type[BaseModel],
236
+ base_params_model: Type[BaseModel], # Now a Pydantic BaseModel
237
+ tool_args_schema: Type[BaseModel],
238
+ compact_docstring: bool = False,
206
239
  ) -> VectaraTool:
207
- """
208
- Create a VectaraTool from a dynamic function, including
209
- setting the function signature and creating the tool schema.
210
- """
211
240
  fields = {}
212
- for param in base_params:
241
+ base_params = []
242
+
243
+ if tool_args_schema is None:
244
+
245
+ class EmptyBaseModel(BaseModel):
246
+ """empty base model"""
247
+
248
+ tool_args_schema = EmptyBaseModel
249
+
250
+ # Create inspect.Parameter objects for base_params_model fields.
251
+ for param_name, model_field in base_params_model.model_fields.items():
252
+ field_type = base_params_model.__annotations__.get(
253
+ param_name, str
254
+ ) # default to str if not found
213
255
  default_value = (
214
- param.default if param.default != inspect.Parameter.empty else ...
256
+ model_field.default
257
+ if model_field.default is not None
258
+ else inspect.Parameter.empty
215
259
  )
216
- fields[param.name] = (param.annotation, default_value)
260
+ base_params.append(
261
+ inspect.Parameter(
262
+ param_name,
263
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
264
+ default=default_value,
265
+ annotation=field_type,
266
+ )
267
+ )
268
+ fields[param_name] = (
269
+ field_type,
270
+ model_field.default if model_field.default is not None else ...,
271
+ )
272
+
273
+ # Add tool_args_schema fields to the fields dict if not already included.
274
+ # Also add them to the function signature by creating new inspect.Parameter objects.
217
275
  for field_name, field_info in tool_args_schema.model_fields.items():
218
276
  if field_name not in fields:
219
277
  default_value = (
220
- field_info.default
221
- if field_info.default is not PydanticUndefined
222
- else ...
278
+ field_info.default if field_info.default is not None else ...
223
279
  )
224
- fields[field_name] = (field_info.annotation, default_value)
225
- fn_schema = create_model(f"{tool_name}", **fields)
226
-
227
- schema_params = [
228
- inspect.Parameter(
229
- name=field_name,
230
- kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
231
- default=(
232
- field_info.default
233
- if field_info.default is not PydanticUndefined
234
- else inspect.Parameter.empty
235
- ),
236
- annotation=(
237
- field_info.annotation
238
- if hasattr(field_info, "annotation")
239
- else field_info
240
- ),
241
- )
242
- for field_name, field_info in tool_args_schema.model_fields.items()
243
- if field_name not in [p.name for p in base_params]
244
- ]
245
- all_params = base_params + schema_params
280
+ field_type = tool_args_schema.__annotations__.get(field_name, None)
281
+ fields[field_name] = (field_type, default_value)
282
+ # Append these fields to the signature.
283
+ base_params.append(
284
+ inspect.Parameter(
285
+ field_name,
286
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
287
+ default=(
288
+ default_value
289
+ if default_value is not ...
290
+ else inspect.Parameter.empty
291
+ ),
292
+ annotation=field_type,
293
+ )
294
+ )
295
+
296
+ # Create the dynamic schema with both base_params_model and tool_args_schema fields.
297
+ fn_schema = create_model(f"{tool_name}_schema", **fields)
246
298
 
299
+ # Combine parameters into a function signature.
300
+ all_params = base_params[:] # Now all_params contains parameters from both models.
247
301
  required_params = [p for p in all_params if p.default is inspect.Parameter.empty]
248
302
  optional_params = [
249
303
  p for p in all_params if p.default is not inspect.Parameter.empty
250
304
  ]
251
- sig = inspect.Signature(required_params + optional_params)
252
- function.__signature__ = sig
305
+ function.__signature__ = inspect.Signature(required_params + optional_params)
253
306
  function.__annotations__["return"] = dict[str, Any]
254
307
  function.__name__ = re.sub(r"[^A-Za-z0-9_]", "_", tool_name)
255
308
 
256
- # Create the tool function signature string
257
- param_strs = []
309
+ # Build a docstring using parameter descriptions from the BaseModels.
310
+ params_str = ", ".join(
311
+ f"{p.name}: {p.annotation.__name__ if hasattr(p.annotation, '__name__') else p.annotation}"
312
+ for p in all_params
313
+ )
314
+ signature_line = f"{tool_name}({params_str}) -> dict[str, Any]"
315
+ if compact_docstring:
316
+ doc_lines = [
317
+ tool_description.strip(),
318
+ ]
319
+ else:
320
+ doc_lines = [
321
+ signature_line,
322
+ "",
323
+ tool_description.strip(),
324
+ ]
325
+ doc_lines += [
326
+ "",
327
+ "Args:",
328
+ ]
258
329
  for param in all_params:
259
- annotation = param.annotation
330
+ description = ""
331
+ if param.name in base_params_model.model_fields:
332
+ description = base_params_model.model_fields[param.name].description
333
+ elif param.name in tool_args_schema.model_fields:
334
+ description = tool_args_schema.model_fields[param.name].description
335
+ if not description:
336
+ description = ""
260
337
  type_name = (
261
- annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
338
+ param.annotation.__name__
339
+ if hasattr(param.annotation, "__name__")
340
+ else str(param.annotation)
262
341
  )
263
- param_strs.append(f"{param.name}: {type_name}")
264
- args_str = ", ".join(param_strs)
265
- function_str = f"{tool_name}({args_str}) -> str"
342
+ if (
343
+ param.default is not inspect.Parameter.empty
344
+ and param.default is not PydanticUndefined
345
+ ):
346
+ default_text = f", default={param.default!r}"
347
+ else:
348
+ default_text = ""
349
+ doc_lines.append(f" - {param.name} ({type_name}){default_text}: {description}")
350
+ doc_lines.append("")
351
+ doc_lines.append("Returns:")
352
+ return_desc = getattr(
353
+ function, "__return_description__", "A dictionary containing the result data."
354
+ )
355
+ doc_lines.append(f" dict[str, Any]: {return_desc}")
356
+
357
+ initial_docstring = "\n".join(doc_lines)
358
+ collapsed_spaces = re.sub(r' {2,}', ' ', initial_docstring)
359
+ final_docstring = re.sub(r'\n{2,}', '\n', collapsed_spaces).strip()
360
+ function.__doc__ = final_docstring
266
361
 
267
- # Create the tool
268
362
  tool = VectaraTool.from_defaults(
269
363
  fn=function,
270
364
  name=tool_name,
271
- description=function_str + "\n" + tool_description,
365
+ description=function.__doc__,
272
366
  fn_schema=fn_schema,
273
367
  tool_type=ToolType.QUERY,
274
368
  )
275
369
  return tool
276
370
 
277
371
 
372
+ Range = Tuple[float, float, bool, bool] # (min, max, min_inclusive, max_inclusive)
373
+
374
+
375
+ def _parse_range(val_str: str) -> Range:
376
+ """
377
+ Parses '[1,10)' or '(0.5, 5]' etc.
378
+ Returns (start, end, start_incl, end_incl) or raises ValueError.
379
+ """
380
+ m = re.match(
381
+ r"""
382
+ ^([\[\(])\s* # opening bracket
383
+ ([+-]?\d+(\.\d*)?)\s*, # first number
384
+ \s*([+-]?\d+(\.\d*)?) # second number
385
+ \s*([\]\)])$ # closing bracket
386
+ """,
387
+ val_str,
388
+ re.VERBOSE,
389
+ )
390
+ if not m:
391
+ raise ValueError(f"Invalid range syntax: {val_str!r}")
392
+ start_inc = m.group(1) == "["
393
+ end_inc = m.group(7) == "]"
394
+ start = float(m.group(2))
395
+ end = float(m.group(4))
396
+ if start > end:
397
+ raise ValueError(f"Range lower bound greater than upper bound: {val_str!r}")
398
+ return start, end, start_inc, end_inc
399
+
400
+
401
+ def _parse_comparison(val_str: str) -> Tuple[str, Union[float, str, bool]]:
402
+ """
403
+ Parses '>10', '<=3.14', '!=foo', \"='bar'\" etc.
404
+ Returns (operator, rhs) or raises ValueError.
405
+ """
406
+ # pick off the operator
407
+ comparison_operators = [">=", "<=", "!=", ">", "<", "="]
408
+ numeric_only_operators = {">", "<", ">=", "<="}
409
+ for op in comparison_operators:
410
+ if val_str.startswith(op):
411
+ rhs = val_str[len(op) :].strip()
412
+ if op in numeric_only_operators:
413
+ try:
414
+ rhs_val = float(rhs)
415
+ except ValueError as e:
416
+ raise ValueError(
417
+ f"Numeric comparison {op!r} must have a number, got {rhs!r}"
418
+ ) from e
419
+ return op, rhs_val
420
+ # = and != can be bool, numeric, or string
421
+ low = rhs.lower()
422
+ if low in ("true", "false"):
423
+ return op, (low == "true")
424
+ try:
425
+ return op, float(rhs)
426
+ except ValueError:
427
+ return op, rhs
428
+ raise ValueError(f"No valid comparison operator at start of {val_str!r}")
429
+
430
+
278
431
  def _build_filter_string(
279
432
  kwargs: Dict[str, Any], tool_args_type: Dict[str, dict], fixed_filter: str
280
433
  ) -> str:
@@ -282,130 +435,84 @@ def _build_filter_string(
282
435
  Build filter string for Vectara from kwargs
283
436
  """
284
437
  filter_parts = []
285
- comparison_operators = [">=", "<=", "!=", ">", "<", "="]
286
- numeric_only_ops = {">", "<", ">=", "<="}
287
-
288
- for key, value in kwargs.items():
289
- if value is None or value == "":
438
+ for key, raw in kwargs.items():
439
+ if raw is None or raw == "":
290
440
  continue
291
441
 
292
- # Determine the prefix for the key. Valid values are "doc" or "part"
293
- # default to 'doc' if not specified
294
- tool_args_dict = tool_args_type.get(key, {"type": "doc", "is_list": False})
295
- prefix = tool_args_dict.get(key, "doc")
296
- is_list = tool_args_dict.get("is_list", False)
297
-
298
- if prefix not in ["doc", "part"]:
299
- raise ValueError(
300
- f'Unrecognized prefix {prefix}. Please make sure to use either "doc" or "part" for the prefix.'
301
- )
302
-
303
- if value is PydanticUndefined:
442
+ if raw is PydanticUndefined:
304
443
  raise ValueError(
305
- f"Value of argument {key} is undefined, and this is invalid."
444
+ f"Value of argument {key!r} is undefined, and this is invalid. "
306
445
  "Please form proper arguments and try again."
307
446
  )
308
447
 
309
- # value of the argument
310
- val_str = str(value).strip()
448
+ tool_args_dict = tool_args_type.get(key, {"type": "doc", "is_list": False})
449
+ prefix = tool_args_dict.get("type", "doc")
450
+ is_list = tool_args_dict.get("is_list", False)
311
451
 
312
- # Special handling for range operator
313
- if val_str.startswith(("[", "(")) and val_str.endswith(("]", ")")):
314
- # Extract the boundary types
315
- start_inclusive = val_str.startswith("[")
316
- end_inclusive = val_str.endswith("]")
452
+ if prefix not in ("doc", "part"):
453
+ raise ValueError(
454
+ f'Unrecognized prefix {prefix!r}. Please make sure to use either "doc" or "part" for the prefix.'
455
+ )
317
456
 
318
- # Remove the boundaries and strip whitespace
319
- val_str = val_str[1:-1].strip()
457
+ # 1) native numeric
458
+ if isinstance(raw, (int, float)) or is_float(str(raw)):
459
+ val = str(raw)
460
+ if is_list:
461
+ filter_parts.append(f"({val} IN {prefix}.{key})")
462
+ else:
463
+ filter_parts.append(f"{prefix}.{key}={val}")
464
+ continue
320
465
 
321
- if "," in val_str:
322
- val_str = val_str.split(",")
323
- if len(val_str) != 2:
324
- raise ValueError(
325
- f"Range operator requires two values for {key}: {value}"
326
- )
466
+ # 2) native boolean
467
+ if isinstance(raw, bool):
468
+ val = "true" if raw else "false"
469
+ if is_list:
470
+ filter_parts.append(f"({val} IN {prefix}.{key})")
471
+ else:
472
+ filter_parts.append(f"{prefix}.{key}={val}")
473
+ continue
327
474
 
328
- # Validate both bounds as numeric or empty (for unbounded ranges)
329
- start_val, end_val = val_str[0].strip(), val_str[1].strip()
330
- if start_val and not (start_val.isdigit() or is_float(start_val)):
331
- raise ValueError(
332
- f"Range operator requires numeric operands for {key}: {value}"
333
- )
334
- if end_val and not (end_val.isdigit() or is_float(end_val)):
335
- raise ValueError(
336
- f"Range operator requires numeric operands for {key}: {value}"
337
- )
475
+ if not isinstance(raw, str):
476
+ raise ValueError(f"Unsupported type for {key!r}: {type(raw).__name__}")
477
+
478
+ val_str = raw.strip()
479
+
480
+ # 3) Range operator
481
+ if (val_str.startswith("[") or val_str.startswith("(")) and (
482
+ val_str.endswith("]") or val_str.endswith(")")
483
+ ):
484
+ start, end, start_incl, end_incl = _parse_range(val_str)
485
+ conds = []
486
+ op1 = ">=" if start_incl else ">"
487
+ op2 = "<=" if end_incl else "<"
488
+ conds.append(f"{prefix}.{key} {op1} {start}")
489
+ conds.append(f"{prefix}.{key} {op2} {end}")
490
+ filter_parts.append("(" + " AND ".join(conds) + ")")
491
+ continue
338
492
 
339
- # Build the SQL condition
340
- range_conditions = []
341
- if start_val:
342
- operator = ">=" if start_inclusive else ">"
343
- range_conditions.append(f"{prefix}.{key} {operator} {start_val}")
344
- if end_val:
345
- operator = "<=" if end_inclusive else "<"
346
- range_conditions.append(f"{prefix}.{key} {operator} {end_val}")
347
-
348
- # Join the range conditions with AND
349
- filter_parts.append("( " + " AND ".join(range_conditions) + " )")
350
- continue
351
-
352
- raise ValueError(f"Range operator requires two values for {key}: {value}")
353
-
354
- # Check if value contains a known comparison operator at the start
355
- matched_operator = None
356
- for op in comparison_operators:
357
- if val_str.startswith(op):
358
- matched_operator = op
359
- break
360
-
361
- # Break down operator from value
362
- # e.g. val_str = ">2022" --> operator = ">", rhs = "2022"
363
- if matched_operator:
364
- rhs = val_str[len(matched_operator) :].strip()
365
-
366
- if matched_operator in numeric_only_ops:
367
- # Must be numeric
368
- if not (rhs.isdigit() or is_float(rhs)):
369
- raise ValueError(
370
- f"Operator {matched_operator} requires a numeric operand for {key}: {val_str}"
371
- )
372
- filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
493
+ # 4) comparison operator
494
+ try:
495
+ op, rhs = _parse_comparison(val_str)
496
+ except ValueError:
497
+ # no operator → treat as membership or equality-on-string
498
+ if is_list:
499
+ filter_parts.append(f"('{val_str}' IN {prefix}.{key})")
373
500
  else:
374
- # = and != operators can be numeric or string
375
- if rhs.isdigit() or is_float(rhs):
376
- filter_parts.append(f"{prefix}.{key}{matched_operator}{rhs}")
377
- elif rhs.lower() in ["true", "false"]:
378
- filter_parts.append(
379
- f"{prefix}.{key}{matched_operator}{rhs.lower()}"
380
- )
381
- else:
382
- # For string operands, wrap them in quotes
383
- filter_parts.append(f"{prefix}.{key}{matched_operator}'{rhs}'")
501
+ filter_parts.append(f"{prefix}.{key}='{val_str}'")
384
502
  else:
385
- if val_str.isdigit() or is_float(val_str):
386
- if is_list:
387
- filter_parts.append(f"({val_str} IN {prefix}.{key})")
388
- else:
389
- filter_parts.append(f"{prefix}.{key}={val_str}")
390
- elif val_str.lower() in ["true", "false"]:
391
- # This is to handle boolean values.
392
- # This is not complete solution - the best solution would be to test if the field is boolean
393
- # That can be done after we move to APIv2
394
- if is_list:
395
- filter_parts.append(f"({val_str.lower()} IN {prefix}.{key})")
396
- else:
397
- filter_parts.append(f"{prefix}.{key}={val_str.lower()}")
503
+ # normal comparison always binds to the field
504
+ if isinstance(rhs, bool):
505
+ rhs_sql = "true" if rhs else "false"
506
+ elif isinstance(rhs, (int, float)):
507
+ rhs_sql = str(rhs)
398
508
  else:
399
- if is_list:
400
- filter_parts.append(f"('{val_str}' IN {prefix}.{key})")
401
- else:
402
- filter_parts.append(f"{prefix}.{key}='{val_str}'")
403
-
404
- filter_str = " AND ".join(filter_parts)
405
- if fixed_filter and filter_str:
406
- return f"({fixed_filter}) AND ({filter_str})"
407
- else:
408
- return fixed_filter or filter_str
509
+ rhs_sql = f"'{rhs}'"
510
+ filter_parts.append(f"{prefix}.{key}{op}{rhs_sql}")
511
+
512
+ joined = " AND ".join(filter_parts)
513
+ if fixed_filter and joined:
514
+ return f"({fixed_filter}) AND ({joined})"
515
+ return fixed_filter or joined
409
516
 
410
517
 
411
518
  class VectaraToolFactory:
@@ -417,25 +524,29 @@ class VectaraToolFactory:
417
524
  self,
418
525
  vectara_corpus_key: str = str(os.environ.get("VECTARA_CORPUS_KEY", "")),
419
526
  vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
527
+ compact_docstring: bool = False,
420
528
  ) -> None:
421
529
  """
422
530
  Initialize the VectaraToolFactory
423
531
  Args:
424
532
  vectara_corpus_key (str): The Vectara corpus key (or comma separated list of keys).
425
533
  vectara_api_key (str): The Vectara API key.
534
+ compact_docstring (bool): Whether to use a compact docstring format for tools
535
+ This is useful if OpenAI complains on the 1024 token limit.
426
536
  """
427
537
  self.vectara_corpus_key = vectara_corpus_key
428
538
  self.vectara_api_key = vectara_api_key
429
539
  self.num_corpora = len(vectara_corpus_key.split(","))
430
- self.cache_expiry = 60 * 60 # 1 hour
431
- self.max_cache_size = 128
540
+ self.compact_docstring = compact_docstring
432
541
 
433
542
  def create_search_tool(
434
543
  self,
435
544
  tool_name: str,
436
545
  tool_description: str,
437
- tool_args_schema: type[BaseModel],
546
+ tool_args_schema: type[BaseModel] = None,
438
547
  tool_args_type: Dict[str, str] = {},
548
+ summarize_docs: Optional[bool] = None,
549
+ summarize_llm_name: Optional[str] = None,
439
550
  fixed_filter: str = "",
440
551
  lambda_val: Union[List[float], float] = 0.005,
441
552
  semantics: Union[List[str] | str] = "default",
@@ -461,7 +572,7 @@ class VectaraToolFactory:
461
572
  Args:
462
573
  tool_name (str): The name of the tool.
463
574
  tool_description (str): The description of the tool.
464
- tool_args_schema (BaseModel): The schema for the tool arguments.
575
+ tool_args_schema (BaseModel, optional): The schema for the tool arguments.
465
576
  tool_args_type (Dict[str, str], optional): The type of each argument (doc or part).
466
577
  fixed_filter (str, optional): A fixed Vectara filter condition to apply to all queries.
467
578
  lambda_val (Union[List[float] | float], optional): Lambda value (or list of values for each corpora)
@@ -513,7 +624,11 @@ class VectaraToolFactory:
513
624
 
514
625
  query = kwargs.pop("query")
515
626
  top_k = kwargs.pop("top_k", 10)
516
- summarize = kwargs.pop("summarize", True)
627
+ summarize = (
628
+ kwargs.pop("summarize", True)
629
+ if summarize_docs is None
630
+ else summarize_docs
631
+ )
517
632
  try:
518
633
  filter_string = _build_filter_string(
519
634
  kwargs, tool_args_type, fixed_filter
@@ -572,7 +687,10 @@ class VectaraToolFactory:
572
687
  if summarize:
573
688
  summaries_dict = asyncio.run(
574
689
  summarize_documents(
575
- self.vectara_corpus_key, self.vectara_api_key, list(unique_ids)
690
+ corpus_key=self.vectara_corpus_key,
691
+ api_key=self.vectara_api_key,
692
+ llm_name=summarize_llm_name,
693
+ doc_ids=list(unique_ids),
576
694
  )
577
695
  )
578
696
  for doc_id, metadata in docs:
@@ -592,37 +710,49 @@ class VectaraToolFactory:
592
710
  )
593
711
  return out
594
712
 
595
- base_params = [
596
- inspect.Parameter(
597
- "query", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
598
- ),
599
- inspect.Parameter(
600
- "top_k", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int
601
- ),
602
- inspect.Parameter(
603
- "summarize",
604
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
605
- default=True,
606
- annotation=bool,
607
- ),
608
- ]
713
+ class SearchToolBaseParams(BaseModel):
714
+ """Model for the base parameters of the search tool."""
715
+
716
+ query: str = Field(
717
+ ...,
718
+ description="The search query to perform, in the form of a question.",
719
+ )
720
+ top_k: int = Field(
721
+ 10, description="The number of top documents to retrieve."
722
+ )
723
+ summarize: bool = Field(
724
+ True,
725
+ description="Whether to summarize the retrieved documents.",
726
+ )
727
+
728
+ class SearchToolBaseParamsWithoutSummarize(BaseModel):
729
+ """Model for the base parameters of the search tool."""
730
+
731
+ query: str = Field(
732
+ ...,
733
+ description="The search query to perform, in the form of a question.",
734
+ )
735
+ top_k: int = Field(
736
+ 10, description="The number of top documents to retrieve."
737
+ )
738
+
609
739
  search_tool_extra_desc = (
610
740
  tool_description
611
741
  + "\n"
612
- + """
613
- This tool is meant to perform a search for relevant documents, it is not meant for asking questions.
614
- The response includes metadata about each relevant document.
615
- If summarize=True, it also includes a summary of each document, but takes a lot longer to respond,
616
- so avoid using it unless necessary.
617
- """
742
+ + "Use this tool to search for relevant documents, not to ask questions."
618
743
  )
619
744
 
620
745
  tool = _create_tool_from_dynamic_function(
621
746
  search_function,
622
747
  tool_name,
623
748
  search_tool_extra_desc,
624
- base_params,
749
+ (
750
+ SearchToolBaseParams
751
+ if summarize_docs is None
752
+ else SearchToolBaseParamsWithoutSummarize
753
+ ),
625
754
  tool_args_schema,
755
+ compact_docstring=self.compact_docstring,
626
756
  )
627
757
  return tool
628
758
 
@@ -630,7 +760,7 @@ class VectaraToolFactory:
630
760
  self,
631
761
  tool_name: str,
632
762
  tool_description: str,
633
- tool_args_schema: type[BaseModel],
763
+ tool_args_schema: type[BaseModel] = None,
634
764
  tool_args_type: Dict[str, dict] = {},
635
765
  fixed_filter: str = "",
636
766
  vectara_summarizer: str = "vectara-summary-ext-24-05-med-omni",
@@ -652,6 +782,7 @@ class VectaraToolFactory:
652
782
  rerank_chain: List[Dict] = None,
653
783
  max_response_chars: Optional[int] = None,
654
784
  max_tokens: Optional[int] = None,
785
+ llm_name: Optional[str] = None,
655
786
  temperature: Optional[float] = None,
656
787
  frequency_penalty: Optional[float] = None,
657
788
  presence_penalty: Optional[float] = None,
@@ -668,7 +799,7 @@ class VectaraToolFactory:
668
799
  Args:
669
800
  tool_name (str): The name of the tool.
670
801
  tool_description (str): The description of the tool.
671
- tool_args_schema (BaseModel): The schema for the tool arguments.
802
+ tool_args_schema (BaseModel, optional): The schema for any tool arguments for filtering.
672
803
  tool_args_type (Dict[str, dict], optional): attributes for each argument where they key is the field name
673
804
  and the value is a dictionary with the following keys:
674
805
  - 'type': the type of each filter attribute in Vectara (doc or part).
@@ -699,6 +830,7 @@ class VectaraToolFactory:
699
830
  If using slingshot/multilingual_reranker_v1, it must be first in the list.
700
831
  max_response_chars (int, optional): The desired maximum number of characters for the generated summary.
701
832
  max_tokens (int, optional): The maximum number of tokens to be returned by the LLM.
833
+ llm_name (str, optional): The name of the LLM to use for generation.
702
834
  temperature (float, optional): The sampling temperature; higher values lead to more randomness.
703
835
  frequency_penalty (float, optional): How much to penalize repeating tokens in the response,
704
836
  higher values reducing likelihood of repeating the same line.
@@ -776,6 +908,7 @@ class VectaraToolFactory:
776
908
  filter=filter_string,
777
909
  max_response_chars=max_response_chars,
778
910
  max_tokens=max_tokens,
911
+ llm_name=llm_name,
779
912
  temperature=temperature,
780
913
  frequency_penalty=frequency_penalty,
781
914
  presence_penalty=presence_penalty,
@@ -852,17 +985,21 @@ class VectaraToolFactory:
852
985
  )
853
986
  return out
854
987
 
855
- base_params = [
856
- inspect.Parameter(
857
- "query", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
858
- ),
859
- ]
988
+ class RagToolBaseParams(BaseModel):
989
+ """Model for the base parameters of the RAG tool."""
990
+
991
+ query: str = Field(
992
+ ...,
993
+ description="The search query to perform, in the form of a question",
994
+ )
995
+
860
996
  tool = _create_tool_from_dynamic_function(
861
997
  rag_function,
862
998
  tool_name,
863
999
  tool_description,
864
- base_params,
1000
+ RagToolBaseParams,
865
1001
  tool_args_schema,
1002
+ compact_docstring=self.compact_docstring,
866
1003
  )
867
1004
  return tool
868
1005