vectara_agentic-0.2.7-py3-none-any.whl → vectara_agentic-0.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. Click here for more details.

tests/test_agent.py CHANGED
@@ -106,43 +106,6 @@ class TestAgentPackage(unittest.TestCase):
106
106
  self.assertIsInstance(agent, Agent)
107
107
  self.assertEqual(agent._topic, "question answering")
108
108
 
109
- def test_serialization(self):
110
- with ARIZE_LOCK:
111
- config = AgentConfig(
112
- agent_type=AgentType.REACT,
113
- main_llm_provider=ModelProvider.ANTHROPIC,
114
- tool_llm_provider=ModelProvider.TOGETHER,
115
- observer=ObserverType.ARIZE_PHOENIX
116
- )
117
-
118
- agent = Agent.from_corpus(
119
- tool_name="RAG Tool",
120
- agent_config=config,
121
- vectara_corpus_key="corpus_key",
122
- vectara_api_key="api_key",
123
- data_description="information",
124
- assistant_specialty="question answering",
125
- )
126
-
127
- agent_reloaded = agent.loads(agent.dumps())
128
- agent_reloaded_again = agent_reloaded.loads(agent_reloaded.dumps())
129
-
130
- self.assertIsInstance(agent_reloaded, Agent)
131
- self.assertEqual(agent, agent_reloaded)
132
- self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
133
-
134
- self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
135
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
136
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
137
-
138
- self.assertIsInstance(agent_reloaded, Agent)
139
- self.assertEqual(agent, agent_reloaded_again)
140
- self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
141
-
142
- self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
143
- self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
144
- self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
145
-
146
109
  def test_chat_history(self):
147
110
  tools = [ToolsFactory().create_tool(mult)]
148
111
  topic = "AI topic"
tests/test_serialization.py ADDED
@@ -0,0 +1,110 @@
1
+ import unittest
2
+ import threading
3
+ import os
4
+
5
+ from vectara_agentic.agent import Agent, AgentType
6
+ from vectara_agentic.agent_config import AgentConfig
7
+ from vectara_agentic.types import ModelProvider, ObserverType
8
+ from vectara_agentic.tools import ToolsFactory
9
+
10
+ from llama_index.core.utilities.sql_wrapper import SQLDatabase
11
+ from sqlalchemy import create_engine
12
+
13
+ def mult(x: float, y: float) -> float:
14
+ return x * y
15
+
16
+
17
+ ARIZE_LOCK = threading.Lock()
18
+
19
+ class TestAgentSerialization(unittest.TestCase):
20
+
21
+ @classmethod
22
+ def tearDown(cls):
23
+ try:
24
+ os.remove('ev_database.db')
25
+ except FileNotFoundError:
26
+ pass
27
+
28
+ def test_serialization(self):
29
+ with ARIZE_LOCK:
30
+ config = AgentConfig(
31
+ agent_type=AgentType.REACT,
32
+ main_llm_provider=ModelProvider.ANTHROPIC,
33
+ tool_llm_provider=ModelProvider.TOGETHER,
34
+ observer=ObserverType.ARIZE_PHOENIX
35
+ )
36
+ db_tools = ToolsFactory().database_tools(
37
+ tool_name_prefix = "ev",
38
+ content_description = 'Electric Vehicles in the state of Washington and other population information',
39
+ sql_database = SQLDatabase(create_engine('sqlite:///ev_database.db')),
40
+ )
41
+
42
+ tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
43
+ topic = "AI topic"
44
+ instructions = "Always do as your father tells you, if your mother agrees!"
45
+ agent = Agent(
46
+ tools=tools,
47
+ topic=topic,
48
+ custom_instructions=instructions,
49
+ agent_config=config
50
+ )
51
+
52
+ agent_reloaded = agent.loads(agent.dumps())
53
+ agent_reloaded_again = agent_reloaded.loads(agent_reloaded.dumps())
54
+
55
+ self.assertIsInstance(agent_reloaded, Agent)
56
+ self.assertEqual(agent, agent_reloaded)
57
+ self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
58
+
59
+ self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
60
+ self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
61
+ self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
62
+
63
+ self.assertIsInstance(agent_reloaded, Agent)
64
+ self.assertEqual(agent, agent_reloaded_again)
65
+ self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
66
+
67
+ self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
68
+ self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
69
+ self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
70
+
71
+ def test_serialization_from_corpus(self):
72
+ with ARIZE_LOCK:
73
+ config = AgentConfig(
74
+ agent_type=AgentType.REACT,
75
+ main_llm_provider=ModelProvider.ANTHROPIC,
76
+ tool_llm_provider=ModelProvider.TOGETHER,
77
+ observer=ObserverType.ARIZE_PHOENIX
78
+ )
79
+
80
+ agent = Agent.from_corpus(
81
+ tool_name="RAG Tool",
82
+ agent_config=config,
83
+ vectara_corpus_key="corpus_key",
84
+ vectara_api_key="api_key",
85
+ data_description="information",
86
+ assistant_specialty="question answering",
87
+ )
88
+
89
+ agent_reloaded = agent.loads(agent.dumps())
90
+ agent_reloaded_again = agent_reloaded.loads(agent_reloaded.dumps())
91
+
92
+ self.assertIsInstance(agent_reloaded, Agent)
93
+ self.assertEqual(agent, agent_reloaded)
94
+ self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
95
+
96
+ self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
97
+ self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
98
+ self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
99
+
100
+ self.assertIsInstance(agent_reloaded, Agent)
101
+ self.assertEqual(agent, agent_reloaded_again)
102
+ self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
103
+
104
+ self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
105
+ self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
106
+ self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ unittest.main()
vectara_agentic/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  """
2
2
  Define the version of the package.
3
3
  """
4
- __version__ = "0.2.7"
4
+ __version__ = "0.2.8"
vectara_agentic/agent.py CHANGED
@@ -768,6 +768,7 @@ class Agent:
768
768
  """
769
769
  max_attempts = 4 if self.fallback_agent_config else 2
770
770
  attempt = 0
771
+ orig_llm = self.llm.metadata.model_name
771
772
  while attempt < max_attempts:
772
773
  try:
773
774
  current_agent = self._get_current_agent()
@@ -788,16 +789,20 @@ class Agent:
788
789
  agent_response.async_response_gen = _stream_response_wrapper # Override the generator
789
790
  return agent_response
790
791
 
791
- except Exception:
792
+ except Exception as e:
793
+ last_error = e
792
794
  if attempt >= 2:
793
795
  if self.verbose:
794
- print("LLM call failed. Switching agent configuration.")
796
+ print(f"LLM call failed on attempt {attempt}. Switching agent configuration.")
795
797
  self._switch_agent_config()
796
798
  time.sleep(1)
797
799
  attempt += 1
798
800
 
799
801
  return AgentResponse(
800
- response=f"LLM failure can't be resolved after {max_attempts} attempts."
802
+ response=(
803
+ f"For {orig_llm} LLM - failure can't be resolved after "
804
+ f"{max_attempts} attempts ({last_error})."
805
+ )
801
806
  )
802
807
 
803
808
  #
@@ -861,7 +866,6 @@ class Agent:
861
866
  def to_dict(self) -> Dict[str, Any]:
862
867
  """Serialize the Agent instance to a dictionary."""
863
868
  tool_info = []
864
-
865
869
  for tool in self.tools:
866
870
  if hasattr(tool.metadata, "fn_schema"):
867
871
  fn_schema_cls = tool.metadata.fn_schema
vectara_agentic/db_tools.py CHANGED
@@ -1,38 +1,122 @@
1
1
  """
2
- This module contains the code to extend and improve DatabaseToolSpec
3
- Specifically adding load_sample_data and load_unique_values methods, as well as
4
- making sure the load_data method returns a list of text values from the database, not Document[] objects.
2
+ This module contains the code adapted from DatabaseToolSpec
3
+ It makes the following adjustments:
4
+ * Adds load_sample_data and load_unique_values methods.
5
+ * Fixes serialization.
6
+ * Makes sure the load_data method returns a list of text values from the database (and not Document[] objects).
7
+ * Limits the returned rows to self.max_rows.
5
8
  """
6
- from abc import ABC
7
- from typing import Callable, Any
9
+ from typing import Any, Optional, List, Awaitable, Callable
10
+ import asyncio
11
+ from inspect import signature
8
12
 
9
- #
10
- # Additional database tool
11
- #
12
- class DBTool(ABC):
13
- """
14
- A base class for vectara-agentic database tools extensions
13
+ from sqlalchemy import MetaData, text
14
+ from sqlalchemy.engine import Engine
15
+ from sqlalchemy.exc import NoSuchTableError
16
+ from sqlalchemy.schema import CreateTable
17
+
18
+ from llama_index.core.readers.base import BaseReader
19
+ from llama_index.core.utilities.sql_wrapper import SQLDatabase
20
+ from llama_index.core.schema import Document
21
+ from llama_index.core.tools.function_tool import FunctionTool
22
+ from llama_index.core.tools.types import ToolMetadata
23
+ from llama_index.core.tools.utils import create_schema_from_function
24
+
25
+
26
+ AsyncCallable = Callable[..., Awaitable[Any]]
27
+
28
+ class DatabaseTools(BaseReader):
29
+ """Database tools for vectara-agentic
30
+ This class provides a set of tools to interact with a database.
31
+ It allows you to load data, list tables, describe tables, and load unique values.
32
+ It also provides a method to load sample data from a specified table.
15
33
  """
16
- def __init__(self, load_data_fn: Callable, max_rows: int = 1000):
17
- self.load_data_fn = load_data_fn
34
+ spec_functions = [
35
+ "load_data", "load_sample_data", "list_tables",
36
+ "describe_tables", "load_unique_values",
37
+ ]
38
+
39
+ def __init__(
40
+ self,
41
+ *args: Any,
42
+ max_rows: int = 1000,
43
+ sql_database: Optional[SQLDatabase] = None,
44
+ engine: Optional[Engine] = None,
45
+ uri: Optional[str] = None,
46
+ scheme: Optional[str] = None,
47
+ host: Optional[str] = None,
48
+ port: Optional[str] = None,
49
+ user: Optional[str] = None,
50
+ password: Optional[str] = None,
51
+ dbname: Optional[str] = None,
52
+ **kwargs: Any,
53
+ ) -> None:
18
54
  self.max_rows = max_rows
19
55
 
20
- class DBLoadData(DBTool):
21
- """
22
- A tool to Run SQL query on the database and return the result.
23
- """
24
- def __call__(self, query: str) -> Any:
25
- """Query and load data from the Database, returning a list of Documents.
56
+ if sql_database:
57
+ self.sql_database = sql_database
58
+ elif engine:
59
+ self.sql_database = SQLDatabase(engine, *args, **kwargs)
60
+ elif uri:
61
+ self.uri = uri
62
+ self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
63
+ elif (scheme and host and port and user and password and dbname):
64
+ uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
65
+ self.uri = uri
66
+ self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
67
+ else:
68
+ raise ValueError(
69
+ "You must provide either a SQLDatabase, "
70
+ "a SQL Alchemy Engine, a valid connection URI, or a valid "
71
+ "set of credentials."
72
+ )
73
+ self._uri = getattr(self, "uri", None) or str(self.sql_database.engine.url)
74
+ self._metadata = MetaData()
75
+ self._metadata.reflect(bind=self.sql_database.engine)
76
+
77
+ def _get_metadata_from_fn_name(
78
+ self, fn_name: Callable,
79
+ ) -> Optional[ToolMetadata]:
80
+ """Return map from function name.
81
+
82
+ Return type is Optional, meaning that the schema can be None.
83
+ In this case, it's up to the downstream tool implementation to infer the schema.
84
+ """
85
+ try:
86
+ func = getattr(self, fn_name)
87
+ except AttributeError:
88
+ return None
89
+ name = fn_name
90
+ docstring = func.__doc__ or ""
91
+ description = f"{name}{signature(func)}\n{docstring}"
92
+ fn_schema = create_schema_from_function(fn_name, getattr(self, fn_name))
93
+ return ToolMetadata(name=name, description=description, fn_schema=fn_schema)
94
+
95
+ def _load_data(self, query: str) -> List[Document]:
96
+ documents = []
97
+ with self.sql_database.engine.connect() as connection:
98
+ if query is None:
99
+ raise ValueError("A query parameter is necessary to filter the data")
100
+ result = connection.execute(text(query))
101
+ for item in result.fetchall():
102
+ doc_str = ", ".join([str(entry) for entry in item])
103
+ documents.append(Document(text=doc_str))
104
+ return documents
26
105
 
106
+ def load_data(self, *args: Any, **load_kwargs: Any) -> List[str]:
107
+ """Query and load data from the Database, returning a list of Documents.
27
108
  Args:
28
109
  query (str): an SQL query to filter tables and rows.
29
-
30
110
  Returns:
31
- List[text]: a list of text values from the database.
111
+ List[Document]: a list of Document objects from the database.
32
112
  """
113
+ query = args[0] if args else load_kwargs.get("args",{}).get("query")
114
+ if query is None:
115
+ raise ValueError("A query parameter is necessary to filter the data")
116
+
33
117
  count_query = f"SELECT COUNT(*) FROM ({query})"
34
118
  try:
35
- count_rows = self.load_data_fn(count_query)
119
+ count_rows = self._load_data(count_query)
36
120
  except Exception as e:
37
121
  return [f"Error ({str(e)}) occurred while counting number of rows"]
38
122
  num_rows = int(count_rows[0].text)
@@ -42,19 +126,12 @@ class DBLoadData(DBTool):
42
126
  "Please refactor your query to make it return less rows. "
43
127
  ]
44
128
  try:
45
- res = self.load_data_fn(query)
129
+ res = self._load_data(query)
46
130
  except Exception as e:
47
131
  return [f"Error ({str(e)}) occurred while executing the query {query}"]
48
132
  return [d.text for d in res]
49
133
 
50
- class DBLoadSampleData(DBTool):
51
- """
52
- A tool to load a sample of data from the specified database table.
53
-
54
- This tool fetches the first num_rows (default 25) rows from the given table
55
- using a provided database query function.
56
- """
57
- def __call__(self, table_name: str, num_rows: int = 25) -> Any:
134
+ def load_sample_data(self, table_name: str, num_rows: int = 25) -> Any:
58
135
  """
59
136
  Fetches the first num_rows rows from the specified database table.
60
137
 
@@ -65,16 +142,39 @@ class DBLoadSampleData(DBTool):
65
142
  Any: The result of the database query.
66
143
  """
67
144
  try:
68
- res = self.load_data_fn(f"SELECT * FROM {table_name} LIMIT {num_rows}")
145
+ res = self._load_data(f"SELECT * FROM {table_name} LIMIT {num_rows}")
69
146
  except Exception as e:
70
147
  return [f"Error ({str(e)}) occurred while loading sample data for table {table_name}"]
71
- return res
148
+ return [d.text for d in res]
72
149
 
73
- class DBLoadUniqueValues(DBTool):
74
- """
75
- A tool to list all unique values for each column in a set of columns of a database table.
76
- """
77
- def __call__(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
150
+ def list_tables(self) -> List[str]:
151
+ """List all tables in the database.
152
+ Returns:
153
+ List[str]: A list of table names in the database.
154
+ """
155
+ return [x.name for x in self._metadata.sorted_tables]
156
+
157
+ def describe_tables(self, tables: Optional[List[str]] = None) -> str:
158
+ """Describe the tables in the database.
159
+ Args:
160
+ tables (Optional[List[str]]): A list of table names to describe. If None, all tables are described.
161
+ Returns:
162
+ str: A string representation of the table schemas.
163
+ """
164
+ table_names = tables or [table.name for table in self._metadata.sorted_tables]
165
+ table_schemas = []
166
+ for table_name in table_names:
167
+ table = next(
168
+ (table for table in self._metadata.sorted_tables if table.name == table_name),
169
+ None,
170
+ )
171
+ if table is None:
172
+ raise NoSuchTableError(f"Table '{table_name}' does not exist.")
173
+ schema = str(CreateTable(table).compile(self.sql_database.engine))
174
+ table_schemas.append(f"{schema}\n")
175
+ return "\n".join(table_schemas)
176
+
177
+ def load_unique_values(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
78
178
  """
79
179
  Fetches the first num_vals unique values from the specified columns of the database table.
80
180
 
@@ -89,8 +189,74 @@ class DBLoadUniqueValues(DBTool):
89
189
  res = {}
90
190
  try:
91
191
  for column in columns:
92
- unique_vals = self.load_data_fn(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
192
+ unique_vals = self._load_data(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
93
193
  res[column] = [d.text for d in unique_vals]
94
194
  except Exception as e:
95
195
  return {f"Error ({str(e)}) occurred while loading unique values for table {table_name}"}
96
196
  return res
197
+
198
+ def to_tool_list(self) -> List[FunctionTool]:
199
+ """
200
+ Returns a list of tools available.
201
+ """
202
+
203
+ tool_list = []
204
+ for tool_name in self.spec_functions:
205
+ func_sync = None
206
+ func_async = None
207
+ func = getattr(self, tool_name)
208
+ if asyncio.iscoroutinefunction(func):
209
+ func_async = func
210
+ else:
211
+ func_sync = func
212
+ metadata = self._get_metadata_from_fn_name(tool_name)
213
+
214
+ if func_sync is None:
215
+ if func_async is not None:
216
+ func_sync = patch_sync(func_async)
217
+ else:
218
+ raise ValueError(
219
+ f"Could not retrieve a function for spec: {tool_name}"
220
+ )
221
+
222
+ tool = FunctionTool.from_defaults(
223
+ fn=func_sync,
224
+ async_fn=func_async,
225
+ tool_metadata=metadata,
226
+ )
227
+ tool_list.append(tool)
228
+ return tool_list
229
+
230
+ # Custom pickling: exclude unpickleable objects
231
+ def __getstate__(self):
232
+ state = self.__dict__.copy()
233
+ if "sql_database" in state:
234
+ state["sql_database_state"] = {"uri": self._uri}
235
+ del state["sql_database"]
236
+ if "_metadata" in state:
237
+ del state["_metadata"]
238
+ return state
239
+
240
+ def __setstate__(self, state):
241
+ self.__dict__.update(state)
242
+ # Reconstruct the sql_database if it was removed
243
+ if "sql_database_state" in state:
244
+ uri = state["sql_database_state"].get("uri")
245
+ if uri:
246
+ self.sql_database = SQLDatabase.from_uri(uri)
247
+ self._uri = uri
248
+ else:
249
+ raise ValueError("Cannot reconstruct SQLDatabase without URI")
250
+ # Rebuild metadata after restoring the engine
251
+ self._metadata = MetaData()
252
+ self._metadata.reflect(bind=self.sql_database.engine)
253
+
254
+
255
+ def patch_sync(func_async: AsyncCallable) -> Callable:
256
+ """Patch sync function from async function."""
257
+
258
+ def patched_sync(*args: Any, **kwargs: Any) -> Any:
259
+ loop = asyncio.get_event_loop()
260
+ return loop.run_until_complete(func_async(*args, **kwargs))
261
+
262
+ return patched_sync
vectara_agentic/sub_query_workflow.py CHANGED
@@ -102,7 +102,7 @@ class SubQuestionQueryWorkflow(Workflow):
102
102
  - What is the name of the mayor of San Jose?
103
103
  Here is the user question: {await ctx.get('original_query')}.
104
104
  Here are previous chat messages: {chat_history}.
105
- And here is the list of tools: {await ctx.get('tools')}
105
+ And here is the list of tools: {ev.tools}
106
106
  """,
107
107
  )
108
108
 
@@ -236,6 +236,7 @@ class SequentialSubQuestionsWorkflow(Workflow):
236
236
  print(f"Query is {await ctx.get('original_query')}")
237
237
 
238
238
  llm = await ctx.get("llm")
239
+ orig_query = await ctx.get("original_query")
239
240
  response = llm.complete(
240
241
  f"""
241
242
  Given a user question, and a list of tools, output a list of
@@ -256,9 +257,9 @@ class SequentialSubQuestionsWorkflow(Workflow):
256
257
  - Who is the mayor of this city?
257
258
  The answer to the first question is San Jose, which is given as context to the second question.
258
259
  The answer to the second question is Matt Mahan.
259
- Here is the user question: {await ctx.get('original_query')}.
260
+ Here is the user question: {orig_query}.
260
261
  Here are previous chat messages: {chat_history}.
261
- And here is the list of tools: {await ctx.get('tools')}
262
+ And here is the list of tools: {ev.tools}
262
263
  """,
263
264
  )
264
265
 
vectara_agentic/tools.py CHANGED
@@ -6,6 +6,7 @@ import inspect
6
6
  import re
7
7
  import importlib
8
8
  import os
9
+ import asyncio
9
10
 
10
11
  from typing import Callable, List, Dict, Any, Optional, Union, Type
11
12
  from pydantic import BaseModel, Field, create_model
@@ -20,8 +21,8 @@ from llama_index.core.workflow.context import Context
20
21
 
21
22
  from .types import ToolType
22
23
  from .tools_catalog import ToolsCatalog, get_bad_topics
23
- from .db_tools import DBLoadSampleData, DBLoadUniqueValues, DBLoadData
24
- from .utils import is_float, summarize_vectara_document
24
+ from .db_tools import DatabaseTools
25
+ from .utils import is_float, summarize_documents
25
26
  from .agent_config import AgentConfig
26
27
 
27
28
  LI_packages = {
@@ -31,7 +32,6 @@ LI_packages = {
31
32
  "exa": ToolType.QUERY,
32
33
  "neo4j": ToolType.QUERY,
33
34
  "kuzu": ToolType.QUERY,
34
- "database": ToolType.QUERY,
35
35
  "google": {
36
36
  "GmailToolSpec": {
37
37
  "load_data": ToolType.QUERY,
@@ -109,9 +109,20 @@ class VectaraTool(FunctionTool):
109
109
  fn, name, description, return_direct, fn_schema, async_fn, tool_metadata,
110
110
  callback, async_callback
111
111
  )
112
- vectara_tool = cls(tool_type=tool_type, fn=tool.fn, metadata=tool.metadata, async_fn=tool.async_fn)
112
+ vectara_tool = cls(
113
+ tool_type=tool_type, fn=tool.fn, metadata=tool.metadata, async_fn=tool.async_fn,
114
+ )
113
115
  return vectara_tool
114
116
 
117
+ def __str__(self) -> str:
118
+ return (
119
+ f"Tool(name={self.metadata.name}, "
120
+ f"Tool metadata={self.metadata})"
121
+ )
122
+
123
+ def __repr__(self) -> str:
124
+ return str(self)
125
+
115
126
  def __eq__(self, other):
116
127
  if not isinstance(other, VectaraTool):
117
128
  return False
@@ -508,17 +519,29 @@ class VectaraToolFactory:
508
519
  raw_input={"args": args, "kwargs": kwargs},
509
520
  raw_output={"response": msg},
510
521
  )
511
- tool_output = "Matching documents:\n"
512
522
  unique_ids = set()
523
+ docs = []
513
524
  for doc in response:
514
525
  if doc.id_ in unique_ids:
515
526
  continue
516
527
  unique_ids.add(doc.id_)
517
- if summarize:
518
- summary = summarize_vectara_document(self.vectara_corpus_key, self.vectara_api_key, doc.id_)
519
- tool_output += f"document_id: '{doc.id_}'\nmetadata: '{doc.metadata}'\nsummary: '{summary}'\n\n"
520
- else:
528
+ docs.append((doc.id_, doc.metadata))
529
+ tool_output = "Matching documents:\n"
530
+ if summarize:
531
+ summaries_dict = asyncio.run(
532
+ summarize_documents(
533
+ self.vectara_corpus_key,
534
+ self.vectara_api_key,
535
+ list(unique_ids)
536
+ )
537
+ )
538
+ for doc_id, metadata in docs:
539
+ summary = summaries_dict.get(doc_id, "")
540
+ tool_output += f"document_id: '{doc_id}'\nmetadata: '{metadata}'\nsummary: '{summary}'\n\n"
541
+ else:
542
+ for doc in docs:
521
543
  tool_output += f"document_id: '{doc.id_}'\nmetadata: '{doc.metadata}'\n\n"
544
+
522
545
  out = ToolOutput(
523
546
  tool_name=search_function.__name__,
524
547
  content=tool_output,
@@ -529,12 +552,14 @@ class VectaraToolFactory:
529
552
 
530
553
  base_params = [
531
554
  inspect.Parameter("query", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
532
- inspect.Parameter("top_k", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=10, annotation=int),
555
+ inspect.Parameter("top_k", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int),
533
556
  inspect.Parameter("summarize", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=True, annotation=bool),
534
557
  ]
535
558
  search_tool_extra_desc = tool_description + "\n" + """
559
+ This tool is meant to perform a search for relevant documents, it is not meant for asking questions.
536
560
  The response includes metadata about each relevant document.
537
- If summarize=True, it also includes a summary of each document.
561
+ If summarize=True, it also includes a summary of each document, but takes a lot longer to respond,
562
+ so avoid using it unless necessary.
538
563
  """
539
564
 
540
565
  tool = _create_tool_from_dynamic_function(
@@ -905,7 +930,7 @@ class ToolsFactory:
905
930
  user: str = "postgres",
906
931
  password: str = "Password",
907
932
  dbname: str = "postgres",
908
- max_rows: int = 500,
933
+ max_rows: int = 1000,
909
934
  ) -> List[VectaraTool]:
910
935
  """
911
936
  Returns a list of database tools.
@@ -923,24 +948,16 @@ class ToolsFactory:
923
948
  dbname (str, optional): The database name. Defaults to "postgres".
924
949
  You must specify either the sql_database object or the scheme, host, port, user, password, and dbname.
925
950
  max_rows (int, optional): if specified, instructs the load_data tool to never return more than max_rows
926
- rows. Defaults to 500.
951
+ rows. Defaults to 1000.
927
952
 
928
953
  Returns:
929
954
  List[VectaraTool]: A list of VectaraTool objects.
930
955
  """
931
956
  if sql_database:
932
- tools = self.get_llama_index_tools(
933
- tool_package_name="database",
934
- tool_spec_name="DatabaseToolSpec",
935
- tool_name_prefix=tool_name_prefix,
936
- sql_database=sql_database,
937
- )
957
+ dbt = DatabaseTools(sql_database=sql_database)
938
958
  else:
939
959
  if scheme in ["postgresql", "mysql", "sqlite", "mssql", "oracle"]:
940
- tools = self.get_llama_index_tools(
941
- tool_package_name="database",
942
- tool_spec_name="DatabaseToolSpec",
943
- tool_name_prefix=tool_name_prefix,
960
+ dbt = DatabaseTools(
944
961
  scheme=scheme,
945
962
  host=host,
946
963
  port=port,
@@ -955,28 +972,19 @@ class ToolsFactory:
955
972
  )
956
973
 
957
974
  # Update tools with description
975
+ tools = dbt.to_tool_list()
976
+ vtools = []
958
977
  for tool in tools:
959
978
  if content_description:
960
979
  tool.metadata.description = (
961
980
  tool.metadata.description + f"The database tables include data about {content_description}."
962
981
  )
963
-
964
- # Add two new tools: load_sample_data and load_unique_values
965
- load_data_tool_index = next(i for i, t in enumerate(tools) if t.metadata.name.endswith("load_data"))
966
- load_data_fn_original = tools[load_data_tool_index].fn
967
-
968
- load_data_fn = DBLoadData(load_data_fn_original, max_rows=max_rows)
969
- load_data_fn.__name__ = f"{tool_name_prefix}_load_data"
970
- load_data_tool = self.create_tool(load_data_fn, ToolType.QUERY)
971
-
972
- sample_data_fn = DBLoadSampleData(load_data_fn_original)
973
- sample_data_fn.__name__ = f"{tool_name_prefix}_load_sample_data"
974
- sample_data_tool = self.create_tool(sample_data_fn, ToolType.QUERY)
975
-
976
- load_unique_values_fn = DBLoadUniqueValues(load_data_fn_original)
977
- load_unique_values_fn.__name__ = f"{tool_name_prefix}_load_unique_values"
978
- load_unique_values_tool = self.create_tool(load_unique_values_fn, ToolType.QUERY)
979
-
980
- tools[load_data_tool_index] = load_data_tool
981
- tools.extend([sample_data_tool, load_unique_values_tool])
982
- return tools
982
+ if len(tool_name_prefix) > 0:
983
+ tool.metadata.name = tool_name_prefix + "_" + tool.metadata.name
984
+ vtool = VectaraTool(
985
+ tool_type=ToolType.QUERY,
986
+ fn=tool.fn, async_fn=tool.async_fn,
987
+ metadata=tool.metadata
988
+ )
989
+ vtools.append(vtool)
990
+ return vtools
vectara_agentic/utils.py CHANGED
@@ -6,9 +6,9 @@ from typing import Tuple, Callable, Optional
6
6
  from functools import lru_cache
7
7
  from inspect import signature
8
8
  import json
9
- import requests
10
-
9
+ import asyncio
11
10
  import tiktoken
11
+ import aiohttp
12
12
 
13
13
  from llama_index.core.llms import LLM
14
14
  from llama_index.llms.openai import OpenAI
@@ -101,13 +101,16 @@ def get_llm(
101
101
  max_tokens=max_tokens
102
102
  )
103
103
  elif model_provider == ModelProvider.ANTHROPIC:
104
- llm = Anthropic(model=model_name, temperature=0, max_tokens=max_tokens)
104
+ llm = Anthropic(
105
+ model=model_name, temperature=0,
106
+ max_tokens=max_tokens, cache_idx=2,
107
+ )
105
108
  elif model_provider == ModelProvider.GEMINI:
106
109
  from llama_index.llms.gemini import Gemini
107
110
  llm = Gemini(
108
111
  model=model_name, temperature=0,
109
112
  is_function_calling_model=True,
110
- max_tokens=max_tokens
113
+ max_tokens=max_tokens,
111
114
  )
112
115
  elif model_provider == ModelProvider.TOGETHER:
113
116
  from llama_index.llms.together import TogetherLLM
@@ -159,7 +162,7 @@ def remove_self_from_signature(func):
159
162
  func.__signature__ = new_sig
160
163
  return func
161
164
 
162
- def summarize_vectara_document(corpus_key: str, vectara_api_key, doc_id: str) -> str:
165
+ async def summarize_vectara_document(corpus_key: str, vectara_api_key: str, doc_id: str) -> str:
163
166
  """
164
167
  Summarize a document in a Vectara corpus using the Vectara API.
165
168
  """
@@ -175,8 +178,32 @@ def summarize_vectara_document(corpus_key: str, vectara_api_key, doc_id: str) ->
175
178
  'Accept': 'application/json',
176
179
  'x-api-key': vectara_api_key
177
180
  }
178
-
179
- response = requests.request("POST", url, headers=headers, data=payload, timeout=60)
180
- if response.status_code != 200:
181
- return f"Vectara Summarization failed with error code {response.status_code}, error={response.text}"
181
+ timeout = aiohttp.ClientTimeout(total=60)
182
+ async with aiohttp.ClientSession(timeout=timeout) as session:
183
+ async with session.post(url, headers=headers, data=payload) as response:
184
+ if response.status != 200:
185
+ error_json = await response.json()
186
+ return (
187
+ f"Vectara Summarization failed with error code {response.status}, "
188
+ f"error={error_json['messages'][0]}"
189
+ )
190
+ data = await response.json()
191
+ return data["summary"]
182
192
  return json.loads(response.text)["summary"]
193
+
194
+ async def summarize_documents(
195
+ vectara_corpus_key: str,
196
+ vectara_api_key: str,
197
+ doc_ids: list[str]
198
+ ) -> dict[str, str]:
199
+ """
200
+ Summarize multiple documents in a Vectara corpus using the Vectara API.
201
+ """
202
+ if not doc_ids:
203
+ return {}
204
+ tasks = [
205
+ summarize_vectara_document(vectara_corpus_key, vectara_api_key, doc_id)
206
+ for doc_id in doc_ids
207
+ ]
208
+ summaries = await asyncio.gather(*tasks, return_exceptions=True)
209
+ return dict(zip(doc_ids, summaries))
vectara_agentic-0.2.7.dist-info/METADATA → vectara_agentic-0.2.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: vectara_agentic
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: A Python package for creating AI Assistants and AI Agents with Vectara
5
5
  Home-page: https://github.com/vectara/py-vectara-agentic
6
6
  Author: Ofer Mendelevitch
@@ -16,7 +16,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
16
16
  Requires-Python: >=3.10
17
17
  Description-Content-Type: text/markdown
18
18
  License-File: LICENSE
19
- Requires-Dist: llama-index==0.12.25
19
+ Requires-Dist: llama-index==0.12.26
20
20
  Requires-Dist: llama-index-indices-managed-vectara==0.4.2
21
21
  Requires-Dist: llama-index-agent-llm-compiler==0.3.0
22
22
  Requires-Dist: llama-index-agent-lats==0.3.0
vectara_agentic-0.2.7.dist-info/RECORD → vectara_agentic-0.2.8.dist-info/RECORD CHANGED
@@ -1,28 +1,29 @@
1
1
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  tests/endpoint.py,sha256=frnpdZQpnuQNNKNYgAn2rFTarNG8MCJaNA77Bw_W22A,1420
3
- tests/test_agent.py,sha256=CU7Zdb1J5kFjrkkIeDr70W1wsikYuOvoYKGYpg4EPxM,6678
3
+ tests/test_agent.py,sha256=t4omKBg9207hpT8b05v9TwuXJCM4knYSTdsXe740eho,4845
4
4
  tests/test_agent_planning.py,sha256=_mj73TNP9yUjkUJ-X31r-cQYreJ4qatXOtMrRvVpF4Y,2411
5
5
  tests/test_agent_type.py,sha256=JM0Q2GBGHSADoBacz_DW551zWSfbpf7qa8xXqtyWsc4,5671
6
6
  tests/test_fallback.py,sha256=M5YD7NHZ0joVU1frYIr9_OiRAIje5mrXrYVcekzlyGs,2829
7
7
  tests/test_private_llm.py,sha256=CY-_rCpxGUuxnZ3ypkodw5Jj-sJCNdh6rLbCvULwuJI,2247
8
+ tests/test_serialization.py,sha256=Ed23GN2zhSJNdPFrVK4aqLkOhJKviczR_o0t-r9TuRI,4762
8
9
  tests/test_tools.py,sha256=IVKn0HoS2erTCr1mOEGzTkktiY0PCfKNvqnD_pizjOg,3977
9
10
  tests/test_workflow.py,sha256=lVyrVHdRO5leYNbYtHTmKqMX0c8_xehCpUA7cXQKVsc,2175
10
11
  vectara_agentic/__init__.py,sha256=2GLDS3U6KckK-dBRl9v_x1kSV507gEhjOfuMmmu0Qxg,850
11
12
  vectara_agentic/_callback.py,sha256=5PfqjLmuaZIR6dnqmhniTD_zwCgfi7kOu-nexb6Kss4,9688
12
13
  vectara_agentic/_observability.py,sha256=fTL3KW0jQU-_JSpFgjO6-XzgDut_oiq9kt4QR-FkSqU,3804
13
14
  vectara_agentic/_prompts.py,sha256=LYyiOAiC8imz3U7MSJiuCYAP39afsp7ycXY7-9biyJI,9314
14
- vectara_agentic/_version.py,sha256=FGUM5lA5uZpmWWB52dt2AMCqWcU0M9b-2BB-raX-EN4,65
15
- vectara_agentic/agent.py,sha256=nbBl66n56kjEZX4Zconb9IZjESzpjBZIEQdL4uLfurI,43333
15
+ vectara_agentic/_version.py,sha256=HOBvs3gmojKxd7sNMHt6Q-0_rlFpgzlI1gXNZOS_Fqc,65
16
+ vectara_agentic/agent.py,sha256=ioC6EN86_d7SS1jEZ6CUe6OtetuGmLdWftj5bklPfMs,43522
16
17
  vectara_agentic/agent_config.py,sha256=y1hSvU5ns0cE2R7BqF65LFstixF1ytJcoVgicGXo7w0,3691
17
18
  vectara_agentic/agent_endpoint.py,sha256=QIMejCLlpW2qzXxeDAxv3anF46XMDdVMdKGWhJh3azY,1996
18
- vectara_agentic/db_tools.py,sha256=VUdcjDFPwauFd2A92mXNYZnCjeMiTzcTka7S5At_3oQ,3595
19
- vectara_agentic/sub_query_workflow.py,sha256=eS1S7l5PdyLPLZqxUJSR0oM2kvHb4raPGHk8t8td9sc,10939
20
- vectara_agentic/tools.py,sha256=RpPGWiPHe-9ZiOxNz389W-gNWxegg7m4RlEx_pH9_W0,42881
19
+ vectara_agentic/db_tools.py,sha256=Go03bzma9m-qDH0CPP8hWhf1nu_4S6s7ke0jGqz58Pk,10296
20
+ vectara_agentic/sub_query_workflow.py,sha256=3WoVnryR2NXyYXbLDM1XVLd7DtbCG0jgrVqeDUN4YNQ,10943
21
+ vectara_agentic/tools.py,sha256=Mm2qfJZWnbNa9G-ycYMP7NPLSo4uUJ9_y45YmXxtlSc,42571
21
22
  vectara_agentic/tools_catalog.py,sha256=oiw3wAfbpFhh0_6rMvZsyPqWV6QIzHqhZCNzqRxuyV8,4818
22
23
  vectara_agentic/types.py,sha256=HcS7vR8P2v2xQTlOc6ZFV2vvlr3OpzSNWhtcLMxqUZc,1792
23
- vectara_agentic/utils.py,sha256=U4VWCyrvpXfPb9SJpd4Xj7rJCN-cZCNReNm9_uQjnlk,6759
24
- vectara_agentic-0.2.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
25
- vectara_agentic-0.2.7.dist-info/METADATA,sha256=z-IFDKlGmNh9QSIUV4xOOCYKF-PrTUnDukp8M5BNMe4,25046
26
- vectara_agentic-0.2.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
27
- vectara_agentic-0.2.7.dist-info/top_level.txt,sha256=Y7TQTFdOYGYodQRltUGRieZKIYuzeZj2kHqAUpfCUfg,22
28
- vectara_agentic-0.2.7.dist-info/RECORD,,
24
+ vectara_agentic/utils.py,sha256=nBQqVb4_UNummqVz28DHm3VaKzy8OAq-xSjhU23uxWU,7646
25
+ vectara_agentic-0.2.8.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
26
+ vectara_agentic-0.2.8.dist-info/METADATA,sha256=IV5fm77XOPOvqfcpCZUKRxq9QgnoF3mPu-om_sTKEK8,25046
27
+ vectara_agentic-0.2.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
28
+ vectara_agentic-0.2.8.dist-info/top_level.txt,sha256=Y7TQTFdOYGYodQRltUGRieZKIYuzeZj2kHqAUpfCUfg,22
29
+ vectara_agentic-0.2.8.dist-info/RECORD,,