vectara-agentic 0.1.28 (py3-none-any.whl) → 0.2.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic has been flagged as potentially problematic; the original registry page links to further details.

vectara_agentic/agent.py CHANGED
@@ -1,8 +1,9 @@
1
1
  """
2
2
  This module contains the Agent class for handling different types of agents and their interactions.
3
3
  """
4
- from typing import List, Callable, Optional, Dict, Any
4
+ from typing import List, Callable, Optional, Dict, Any, Union, Tuple
5
5
  import os
6
+ import re
6
7
  from datetime import date
7
8
  import time
8
9
  import json
@@ -10,12 +11,15 @@ import logging
10
11
  import traceback
11
12
  import asyncio
12
13
 
13
- import dill
14
+ import cloudpickle as pickle
15
+
14
16
  from dotenv import load_dotenv
15
17
 
16
18
  from retrying import retry
17
19
  from pydantic import Field, create_model
18
20
 
21
+ from llama_index.core.memory import ChatMemoryBuffer
22
+ from llama_index.core.llms import ChatMessage, MessageRole
19
23
  from llama_index.core.tools import FunctionTool
20
24
  from llama_index.core.agent import ReActAgent
21
25
  from llama_index.core.agent.react.formatter import ReActChatFormatter
@@ -24,7 +28,7 @@ from llama_index.agent.lats import LATSAgentWorker
24
28
  from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
25
29
  from llama_index.core.callbacks.base_handler import BaseCallbackHandler
26
30
  from llama_index.agent.openai import OpenAIAgent
27
- from llama_index.core.memory import ChatMemoryBuffer
31
+
28
32
 
29
33
  from .types import AgentType, AgentStatusType, LLMRole, ToolType, AgentResponse, AgentStreamingResponse
30
34
  from .utils import get_llm, get_tokenizer_for_model
@@ -35,6 +39,21 @@ from .tools import VectaraToolFactory, VectaraTool, ToolsFactory
35
39
  from .tools_catalog import get_current_date
36
40
  from .agent_config import AgentConfig
37
41
 
42
class IgnoreUnpickleableAttributeFilter(logging.Filter):
    '''
    Logging filter that drops "Removing unpickleable private attribute ..."
    messages for a fixed set of attribute names.

    A record is suppressed when its formatted message contains any of the
    substrings in ``_MSGS_TO_IGNORE``; all other records pass through.
    '''

    # Substrings of the noisy messages to drop. Kept at class level so the
    # tuple is built once instead of on every log record.
    _MSGS_TO_IGNORE = (
        "Removing unpickleable private attribute _chunking_tokenizer_fn",
        "Removing unpickleable private attribute _split_fns",
        "Removing unpickleable private attribute _sub_sentence_split_fns",
    )

    def filter(self, record):
        """
        Decide whether ``record`` should be logged.

        Args:
            record (logging.LogRecord): The record being emitted.

        Returns:
            bool: False if the message contains an ignored substring
            (drop it), True otherwise (keep it).
        """
        # getMessage() formats msg % args; do it once rather than once per pattern.
        message = record.getMessage()
        return not any(msg in message for msg in self._MSGS_TO_IGNORE)


# NOTE(review): a filter attached to a *logger* is only consulted for records
# created on that logger, not for records propagated up from child loggers.
# Attaching it to the root logger therefore suppresses the messages above only
# if they are logged on the root logger itself — confirm against the code that
# emits them.
logging.getLogger().addFilter(IgnoreUnpickleableAttributeFilter())

# Raise the OpenTelemetry OTLP HTTP trace exporter logger to CRITICAL so its
# lower-severity messages are dropped.
logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
logger.setLevel(logging.CRITICAL)
40
59
 
@@ -81,6 +100,34 @@ def _retry_if_exception(exception):
81
100
  return isinstance(exception, (TimeoutError))
82
101
 
83
102
 
103
def get_field_type(field_schema: dict) -> Any:
    """
    Convert a JSON schema field definition to a Python type.

    Handles the ``"type"`` keyword (a single type name or, as JSON Schema
    allows, a list of names) and the ``"anyOf"`` keyword. The JSON ``"null"``
    type maps to ``NoneType`` so ``Optional[...]`` fields round-trip. Unknown
    or missing types, and ``anyOf`` options without a ``"type"`` (e.g. a
    ``"$ref"``), map to ``Any``.

    Args:
        field_schema (dict): One property entry from a JSON schema.

    Returns:
        Any: A Python type, possibly a ``Union``, usable as a field type
        with ``pydantic.create_model``.
    """
    json_type_to_python = {
        "string": str,
        "integer": int,
        "boolean": bool,
        "array": list,
        "object": dict,
        "number": float,
        "null": type(None),
    }

    def union_of(types: list) -> Any:
        # Union[()] raises TypeError, so no alternatives means "anything".
        # A single alternative collapses to itself (Union[(str,)] is str).
        return Union[tuple(types)] if types else Any

    def convert(type_name: Any) -> Any:
        # A list of type names (e.g. ["string", "null"]) is unhashable and
        # cannot be looked up in the dict directly; convert each member.
        if isinstance(type_name, list):
            return union_of([convert(name) for name in type_name])
        return json_type_to_python.get(type_name, Any)

    if "anyOf" in field_schema:
        # Return a Union of the alternatives, e.g. Union[str, int].
        return union_of([
            convert(option["type"]) if "type" in option else Any
            for option in field_schema["anyOf"]
        ])
    if "type" in field_schema:
        return convert(field_schema["type"])
    return Any
130
+
84
131
  class Agent:
85
132
  """
86
133
  Agent class for handling different types of agents and their interactions.
@@ -96,6 +143,7 @@ class Agent:
96
143
  agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
97
144
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
98
145
  agent_config: Optional[AgentConfig] = None,
146
+ chat_history: Optional[list[Tuple[str, str]]] = None,
99
147
  ) -> None:
100
148
  """
101
149
  Initialize the agent with the specified type, tools, topic, and system message.
@@ -111,10 +159,13 @@ class Agent:
111
159
  query_logging_callback (Callable): A callback function the code calls upon completion of a query
112
160
  agent_config (AgentConfig, optional): The configuration of the agent.
113
161
  Defaults to AgentConfig(), which reads from environment variables.
162
+ chat_history (Tuple[str, str], optional): A list of user/agent chat pairs to initialize the agent memory.
114
163
  """
115
164
  self.agent_config = agent_config or AgentConfig()
116
165
  self.agent_type = self.agent_config.agent_type
117
- self.tools = tools + [ToolsFactory().create_tool(get_current_date)]
166
+ self.tools = tools
167
+ if not any(tool.metadata.name == 'get_current_date' for tool in self.tools):
168
+ self.tools += [ToolsFactory().create_tool(get_current_date)]
118
169
  self.llm = get_llm(LLMRole.MAIN, config=self.agent_config)
119
170
  self._custom_instructions = custom_instructions
120
171
  self._topic = topic
@@ -135,7 +186,14 @@ class Agent:
135
186
  self.llm.callback_manager = callback_manager
136
187
  self.verbose = verbose
137
188
 
138
- self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
189
+ if chat_history:
190
+ msg_history = []
191
+ for inx, text in enumerate(chat_history):
192
+ role = MessageRole.USER if inx % 2 == 0 else MessageRole.ASSISTANT
193
+ msg_history.append(ChatMessage.from_str(content=text, role=role))
194
+ self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000, chat_history=msg_history)
195
+ else:
196
+ self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
139
197
  if self.agent_type == AgentType.REACT:
140
198
  prompt = _get_prompt(REACT_PROMPT_TEMPLATE, topic, custom_instructions)
141
199
  self.agent = ReActAgent.from_tools(
@@ -219,7 +277,10 @@ class Agent:
219
277
 
220
278
  # Compare tools
221
279
  if self.tools != other.tools:
222
- print(f"Comparison failed: tools differ. (self.tools: {self.tools}, other.tools: {other.tools})")
280
+ print(
281
+ "Comparison failed: tools differ."
282
+ f"(self.tools: {[t.metadata.name for t in self.tools]}, "
283
+ f"other.tools: {[t.metadata.name for t in other.tools]})")
223
284
  return False
224
285
 
225
286
  # Compare topic
@@ -263,6 +324,7 @@ class Agent:
263
324
  agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
264
325
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
265
326
  agent_config: AgentConfig = AgentConfig(),
327
+ chat_history: Optional[list[Tuple[str, str]]] = None,
266
328
  ) -> "Agent":
267
329
  """
268
330
  Create an agent from tools, agent type, and language model.
@@ -277,6 +339,7 @@ class Agent:
277
339
  update_func (Callable): old name for agent_progress_callback. Will be deprecated in future.
278
340
  query_logging_callback (Callable): A callback function the code calls upon completion of a query
279
341
  agent_config (AgentConfig, optional): The configuration of the agent.
342
+ chat_history (Tuple[str, str], optional): A list of user/agent chat pairs to initialize the agent memory.
280
343
 
281
344
  Returns:
282
345
  Agent: An instance of the Agent class.
@@ -285,7 +348,8 @@ class Agent:
285
348
  tools=tools, topic=topic, custom_instructions=custom_instructions,
286
349
  verbose=verbose, agent_progress_callback=agent_progress_callback,
287
350
  query_logging_callback=query_logging_callback,
288
- update_func=update_func, agent_config=agent_config
351
+ update_func=update_func, agent_config=agent_config,
352
+ chat_history=chat_history,
289
353
  )
290
354
 
291
355
  @classmethod
@@ -294,28 +358,42 @@ class Agent:
294
358
  tool_name: str,
295
359
  data_description: str,
296
360
  assistant_specialty: str,
297
- vectara_customer_id: str = str(os.environ.get("VECTARA_CUSTOMER_ID", "")),
298
- vectara_corpus_id: str = str(os.environ.get("VECTARA_CORPUS_ID", "")),
361
+ vectara_corpus_key: str = str(os.environ.get("VECTARA_CORPUS_KEY", "")),
299
362
  vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
300
363
  agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
301
364
  query_logging_callback: Optional[Callable[[str, str], None]] = None,
302
365
  verbose: bool = False,
303
366
  vectara_filter_fields: list[dict] = [],
367
+ vectara_offset: int = 0,
304
368
  vectara_lambda_val: float = 0.005,
305
- vectara_reranker: str = "mmr",
369
+ vectara_semantics: str = "default",
370
+ vectara_custom_dimensions: Dict = {},
371
+ vectara_reranker: str = "slingshot",
306
372
  vectara_rerank_k: int = 50,
373
+ vectara_rerank_limit: Optional[int] = None,
374
+ vectara_rerank_cutoff: Optional[float] = None,
375
+ vectara_diversity_bias: float = 0.2,
376
+ vectara_udf_expression: str = None,
377
+ vectara_rerank_chain: List[Dict] = None,
307
378
  vectara_n_sentences_before: int = 2,
308
379
  vectara_n_sentences_after: int = 2,
309
380
  vectara_summary_num_results: int = 10,
310
- vectara_summarizer: str = "vectara-summary-ext-24-05-sml",
381
+ vectara_summarizer: str = "vectara-summary-ext-24-05-med-omni",
382
+ vectara_summary_response_language: str = "eng",
383
+ vectara_summary_prompt_text: Optional[str] = None,
384
+ vectara_max_response_chars: Optional[int] = None,
385
+ vectara_max_tokens: Optional[int] = None,
386
+ vectara_temperature: Optional[float] = None,
387
+ vectara_frequency_penalty: Optional[float] = None,
388
+ vectara_presence_penalty: Optional[float] = None,
389
+ vectara_save_history: bool = True,
311
390
  ) -> "Agent":
312
391
  """
313
392
  Create an agent from a single Vectara corpus
314
393
 
315
394
  Args:
316
395
  tool_name (str): The name of Vectara tool used by the agent
317
- vectara_customer_id (str): The Vectara customer ID.
318
- vectara_corpus_id (str): The Vectara corpus ID (or comma separated list of IDs).
396
+ vectara_corpus_key (str): The Vectara corpus key (or comma separated list of keys).
319
397
  vectara_api_key (str): The Vectara API key.
320
398
  agent_progress_callback (Callable): A callback function the code calls on any agent updates.
321
399
  query_logging_callback (Callable): A callback function the code calls upon completion of a query
@@ -324,21 +402,41 @@ class Agent:
324
402
  verbose (bool, optional): Whether to print verbose output.
325
403
  vectara_filter_fields (List[dict], optional): The filterable attributes
326
404
  (each dict maps field name to Tuple[type, description]).
327
- vectara_lambda_val (float, optional): The lambda value for Vectara hybrid search.
328
- vectara_reranker (str, optional): The Vectara reranker name (default "mmr")
405
+ vectara_offset (int, optional): Number of results to skip.
406
+ vectara_lambda_val (float, optional): Lambda value for Vectara hybrid search.
407
+ vectara_semantics: (str, optional): Indicates whether the query is intended as a query or response.
408
+ vectara_custom_dimensions: (Dict, optional): Custom dimensions for the query.
409
+ vectara_reranker (str, optional): The Vectara reranker name (default "slingshot")
329
410
  vectara_rerank_k (int, optional): The number of results to use with reranking.
411
+ vetara_rerank_limit: (int, optional): The maximum number of results to return after reranking.
412
+ vectara_rerank_cutoff: (float, optional): The minimum score threshold for results to include after
413
+ reranking.
414
+ vectara_diversity_bias (float, optional): The MMR diversity bias.
415
+ vectara_udf_expression (str, optional): The user defined expression for reranking results.
416
+ vectara_rerank_chain (List[Dict], optional): A list of Vectara rerankers to be applied sequentially.
330
417
  vectara_n_sentences_before (int, optional): The number of sentences before the matching text
331
418
  vectara_n_sentences_after (int, optional): The number of sentences after the matching text.
332
419
  vectara_summary_num_results (int, optional): The number of results to use in summarization.
333
420
  vectara_summarizer (str, optional): The Vectara summarizer name.
421
+ vectara_summary_response_language (str, optional): The response language for the Vectara summary.
422
+ vectara_summary_prompt_text (str, optional): The custom prompt, using appropriate prompt variables and
423
+ functions.
424
+ vectara_max_response_chars (int, optional): The desired maximum number of characters for the generated
425
+ summary.
426
+ vectara_max_tokens (int, optional): The maximum number of tokens to be returned by the LLM.
427
+ vectara_temperature (float, optional): The sampling temperature; higher values lead to more randomness.
428
+ vectara_frequency_penalty (float, optional): How much to penalize repeating tokens in the response,
429
+ higher values reducing likelihood of repeating the same line.
430
+ vectara_presence_penalty (float, optional): How much to penalize repeating tokens in the response,
431
+ higher values increasing the diversity of topics.
432
+ vectara_save_history (bool, optional): Whether to save the query in history.
334
433
 
335
434
  Returns:
336
435
  Agent: An instance of the Agent class.
337
436
  """
338
437
  vec_factory = VectaraToolFactory(
339
438
  vectara_api_key=vectara_api_key,
340
- vectara_customer_id=vectara_customer_id,
341
- vectara_corpus_id=vectara_corpus_id,
439
+ vectara_corpus_key=vectara_corpus_key,
342
440
  )
343
441
  field_definitions = {}
344
442
  field_definitions["query"] = (str, Field(description="The user query")) # type: ignore
@@ -349,8 +447,12 @@ class Agent:
349
447
  ) # type: ignore
350
448
  query_args = create_model("QueryArgs", **field_definitions) # type: ignore
351
449
 
450
+ # tool name must be valid Python function name
451
+ if tool_name:
452
+ tool_name = re.sub(r"[^A-Za-z0-9_]", "_", tool_name)
453
+
352
454
  vectara_tool = vec_factory.create_rag_tool(
353
- tool_name=tool_name or f"vectara_{vectara_corpus_id}",
455
+ tool_name=tool_name or f"vectara_{vectara_corpus_key}",
354
456
  tool_description=f"""
355
457
  Given a user query,
356
458
  returns a response (str) to a user question about {data_description}.
@@ -358,11 +460,27 @@ class Agent:
358
460
  tool_args_schema=query_args,
359
461
  reranker=vectara_reranker,
360
462
  rerank_k=vectara_rerank_k,
463
+ rerank_limit=vectara_rerank_limit,
464
+ rerank_cutoff=vectara_rerank_cutoff,
465
+ mmr_diversity_bias=vectara_diversity_bias,
466
+ udf_expression=vectara_udf_expression,
467
+ rerank_chain=vectara_rerank_chain,
361
468
  n_sentences_before=vectara_n_sentences_before,
362
469
  n_sentences_after=vectara_n_sentences_after,
470
+ offset=vectara_offset,
363
471
  lambda_val=vectara_lambda_val,
472
+ semantics=vectara_semantics,
473
+ custom_dimensions=vectara_custom_dimensions,
364
474
  summary_num_results=vectara_summary_num_results,
365
475
  vectara_summarizer=vectara_summarizer,
476
+ summary_response_lang=vectara_summary_response_language,
477
+ vectara_prompt_text=vectara_summary_prompt_text,
478
+ max_response_chars=vectara_max_response_chars,
479
+ max_tokens=vectara_max_tokens,
480
+ temperature=vectara_temperature,
481
+ frequency_penalty=vectara_frequency_penalty,
482
+ presence_penalty=vectara_presence_penalty,
483
+ save_history=vectara_save_history,
366
484
  include_citations=True,
367
485
  verbose=verbose,
368
486
  )
@@ -534,12 +652,13 @@ class Agent:
534
652
 
535
653
  for tool in self.tools:
536
654
  # Serialize each tool's metadata, function, and dynamic model schema (QueryArgs)
655
+ # TODO: deal with tools that have weakref (e.g. db_tools); for now those cannot be serialized.
537
656
  tool_dict = {
538
657
  "tool_type": tool.metadata.tool_type.value,
539
658
  "name": tool.metadata.name,
540
659
  "description": tool.metadata.description,
541
- "fn": dill.dumps(tool.fn).decode("latin-1") if tool.fn else None, # Serialize fn
542
- "async_fn": dill.dumps(tool.async_fn).decode("latin-1")
660
+ "fn": pickle.dumps(tool.fn).decode("latin-1") if tool.fn else None, # Serialize fn
661
+ "async_fn": pickle.dumps(tool.async_fn).decode("latin-1")
543
662
  if tool.async_fn
544
663
  else None, # Serialize async_fn
545
664
  "fn_schema": tool.metadata.fn_schema.model_json_schema()
@@ -550,7 +669,7 @@ class Agent:
550
669
 
551
670
  return {
552
671
  "agent_type": self.agent_type.value,
553
- "memory": dill.dumps(self.agent.memory).decode("latin-1"),
672
+ "memory": pickle.dumps(self.agent.memory).decode("latin-1"),
554
673
  "tools": tool_info,
555
674
  "topic": self._topic,
556
675
  "custom_instructions": self._custom_instructions,
@@ -564,39 +683,30 @@ class Agent:
564
683
  agent_config = AgentConfig.from_dict(data["agent_config"])
565
684
  tools = []
566
685
 
567
- json_type_to_python = {
568
- "string": str,
569
- "integer": int,
570
- "boolean": bool,
571
- "array": list,
572
- "object": dict,
573
- "number": float,
574
- }
575
-
576
686
  for tool_data in data["tools"]:
577
687
  # Recreate the dynamic model using the schema info
578
688
  if tool_data.get("fn_schema"):
579
689
  field_definitions = {}
580
690
  for field, values in tool_data["fn_schema"]["properties"].items():
691
+ # Instead of checking for 'type', use the helper:
692
+ field_type = get_field_type(values)
693
+ # If there's a default value, include it.
581
694
  if "default" in values:
582
695
  field_definitions[field] = (
583
- json_type_to_python.get(values["type"], values["type"]),
584
- Field(
585
- description=values["description"],
586
- default=values["default"],
587
- ),
588
- ) # type: ignore
696
+ field_type,
697
+ Field(description=values.get("description", ""), default=values["default"]),
698
+ )
589
699
  else:
590
700
  field_definitions[field] = (
591
- json_type_to_python.get(values["type"], values["type"]),
592
- Field(description=values["description"]),
593
- ) # type: ignore
701
+ field_type,
702
+ Field(description=values.get("description", "")),
703
+ )
594
704
  query_args_model = create_model("QueryArgs", **field_definitions) # type: ignore
595
705
  else:
596
706
  query_args_model = create_model("QueryArgs")
597
707
 
598
- fn = dill.loads(tool_data["fn"].encode("latin-1")) if tool_data["fn"] else None
599
- async_fn = dill.loads(tool_data["async_fn"].encode("latin-1")) if tool_data["async_fn"] else None
708
+ fn = pickle.loads(tool_data["fn"].encode("latin-1")) if tool_data["fn"] else None
709
+ async_fn = pickle.loads(tool_data["async_fn"].encode("latin-1")) if tool_data["async_fn"] else None
600
710
 
601
711
  tool = VectaraTool.from_defaults(
602
712
  name=tool_data["name"],
@@ -615,7 +725,7 @@ class Agent:
615
725
  custom_instructions=data["custom_instructions"],
616
726
  verbose=data["verbose"],
617
727
  )
618
- memory = dill.loads(data["memory"].encode("latin-1")) if data.get("memory") else None
728
+ memory = pickle.loads(data["memory"].encode("latin-1")) if data.get("memory") else None
619
729
  if memory:
620
730
  agent.agent.memory = memory
621
731
  return agent
@@ -44,6 +44,15 @@ class AgentConfig:
44
44
  default_factory=lambda: os.getenv("VECTARA_AGENTIC_TOOL_MODEL_NAME", "")
45
45
  )
46
46
 
47
+ # Params for Private LLM endpoint if used
48
+ private_llm_api_base: str = field(
49
+ default_factory=lambda: os.getenv("VECTARA_AGENTIC_PRIVATE_LLM_API_BASE",
50
+ "http://private-endpoint.company.com:5000/v1")
51
+ )
52
+ private_llm_api_key: str = field(
53
+ default_factory=lambda: os.getenv("VECTARA_AGENTIC_PRIVATE_LLM_API_KEY", "<private-api-key>")
54
+ )
55
+
47
56
  # Observer
48
57
  observer: ObserverType = field(
49
58
  default_factory=lambda: ObserverType(
@@ -74,7 +74,7 @@ class DBLoadUniqueValues(DBTool):
74
74
  """
75
75
  A tool to list all unique values for each column in a set of columns of a database table.
76
76
  """
77
- def __call__(self, table_name: str, columns: list[str], num_vals: int = 200) -> dict:
77
+ def __call__(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
78
78
  """
79
79
  Fetches the first num_vals unique values from the specified columns of the database table.
80
80
 
@@ -84,7 +84,7 @@ class DBLoadUniqueValues(DBTool):
84
84
  num_vals (int): The number of unique values to fetch for each column. Default is 200.
85
85
 
86
86
  Returns:
87
- dict: A dictionary containing the unique values for each column.
87
+ Any: the result of the database query
88
88
  """
89
89
  res = {}
90
90
  try: