vectara-agentic 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- tests/test_agent.py +38 -71
- tests/test_agent_planning.py +47 -20
- tests/test_agent_type.py +84 -10
- tests/test_fallback.py +2 -2
- tests/test_private_llm.py +1 -1
- tests/test_serialization.py +110 -0
- tests/test_tools.py +1 -1
- vectara_agentic/_version.py +1 -1
- vectara_agentic/agent.py +34 -12
- vectara_agentic/db_tools.py +205 -39
- vectara_agentic/sub_query_workflow.py +20 -7
- vectara_agentic/tools.py +139 -110
- vectara_agentic/types.py +1 -0
- vectara_agentic/utils.py +76 -10
- {vectara_agentic-0.2.6.dist-info → vectara_agentic-0.2.8.dist-info}/METADATA +2 -2
- vectara_agentic-0.2.8.dist-info/RECORD +29 -0
- {vectara_agentic-0.2.6.dist-info → vectara_agentic-0.2.8.dist-info}/WHEEL +1 -1
- vectara_agentic-0.2.6.dist-info/RECORD +0 -28
- {vectara_agentic-0.2.6.dist-info → vectara_agentic-0.2.8.dist-info}/licenses/LICENSE +0 -0
- {vectara_agentic-0.2.6.dist-info → vectara_agentic-0.2.8.dist-info}/top_level.txt +0 -0
vectara_agentic/db_tools.py
CHANGED
|
@@ -1,38 +1,122 @@
|
|
|
1
1
|
"""
|
|
2
|
-
This module contains the code
|
|
3
|
-
|
|
4
|
-
|
|
2
|
+
This module contains the code adapted from DatabaseToolSpec
|
|
3
|
+
It makes the following adjustments:
|
|
4
|
+
* Adds load_sample_data and load_unique_values methods.
|
|
5
|
+
* Fixes serialization.
|
|
6
|
+
* Makes sure the load_data method returns a list of text values from the database (and not Document[] objects).
|
|
7
|
+
* Limits the returned rows to self.max_rows.
|
|
5
8
|
"""
|
|
6
|
-
from
|
|
7
|
-
|
|
9
|
+
from typing import Any, Optional, List, Awaitable, Callable
|
|
10
|
+
import asyncio
|
|
11
|
+
from inspect import signature
|
|
8
12
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
13
|
+
from sqlalchemy import MetaData, text
|
|
14
|
+
from sqlalchemy.engine import Engine
|
|
15
|
+
from sqlalchemy.exc import NoSuchTableError
|
|
16
|
+
from sqlalchemy.schema import CreateTable
|
|
17
|
+
|
|
18
|
+
from llama_index.core.readers.base import BaseReader
|
|
19
|
+
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
20
|
+
from llama_index.core.schema import Document
|
|
21
|
+
from llama_index.core.tools.function_tool import FunctionTool
|
|
22
|
+
from llama_index.core.tools.types import ToolMetadata
|
|
23
|
+
from llama_index.core.tools.utils import create_schema_from_function
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Type alias: any callable, with arbitrary arguments, whose call returns an awaitable.
AsyncCallable = Callable[..., Awaitable[Any]]
|
|
27
|
+
|
|
28
|
+
class DatabaseTools(BaseReader):
|
|
29
|
+
"""Database tools for vectara-agentic
|
|
30
|
+
This class provides a set of tools to interact with a database.
|
|
31
|
+
It allows you to load data, list tables, describe tables, and load unique values.
|
|
32
|
+
It also provides a method to load sample data from a specified table.
|
|
15
33
|
"""
|
|
16
|
-
|
|
17
|
-
|
|
34
|
+
# Names of the methods that to_tool_list() exposes as tools, in this order.
spec_functions = [
    "load_data",
    "load_sample_data",
    "list_tables",
    "describe_tables",
    "load_unique_values",
]
|
|
38
|
+
|
|
39
|
+
def __init__(
    self,
    *args: Any,
    max_rows: int = 1000,
    sql_database: Optional[SQLDatabase] = None,
    engine: Optional[Engine] = None,
    uri: Optional[str] = None,
    scheme: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    dbname: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Connect to a database and reflect its table metadata.

    Exactly one connection source is used, checked in this order:
    ``sql_database``, ``engine``, ``uri``, then the full set of
    ``scheme``/``host``/``port``/``user``/``password``/``dbname``.

    Args:
        *args: Extra positional arguments forwarded to ``SQLDatabase``
            (not used when ``sql_database`` is given).
        max_rows: Stored on the instance as ``self.max_rows``.
        sql_database: An existing ``SQLDatabase`` to use as-is.
        engine: A SQLAlchemy engine to wrap in a ``SQLDatabase``.
        uri: A connection URI passed to ``SQLDatabase.from_uri``.
        scheme, host, port, user, password, dbname: Credentials that are
            combined into ``scheme://user:password@host:port/dbname``.
        **kwargs: Extra keyword arguments forwarded to ``SQLDatabase``
            (not used when ``sql_database`` is given).

    Raises:
        ValueError: If none of the connection sources is provided.
    """
    self.max_rows = max_rows

    if sql_database is not None:
        self.sql_database = sql_database
    elif engine is not None:
        self.sql_database = SQLDatabase(engine, *args, **kwargs)
    elif uri:
        self.uri = uri
        self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
    elif (scheme and host and port and user and password and dbname):
        # NOTE(review): user/password are interpolated verbatim; values with
        # characters such as '@' or '/' must already be URL-escaped by the caller.
        uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
        self.uri = uri
        self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
    else:
        raise ValueError(
            "You must provide either a SQLDatabase, "
            "a SQL Alchemy Engine, a valid connection URI, or a valid "
            "set of credentials."
        )

    # _uri is what __getstate__ stores in order to reconnect after unpickling,
    # so it must contain the real password. str(url) masks the password on
    # SQLAlchemy 2.x, hence the explicit render_as_string(hide_password=False).
    explicit_uri = getattr(self, "uri", None)
    if explicit_uri:
        self._uri = explicit_uri
    else:
        engine_url = self.sql_database.engine.url
        render = getattr(engine_url, "render_as_string", None)
        self._uri = render(hide_password=False) if render else str(engine_url)

    self._metadata = MetaData()
    self._metadata.reflect(bind=self.sql_database.engine)
|
|
76
|
+
|
|
77
|
+
def _get_metadata_from_fn_name(
    self, fn_name: str,
) -> Optional[ToolMetadata]:
    """Build the ``ToolMetadata`` for the method named ``fn_name``.

    The description is the method name followed by its signature, a newline,
    and the method's docstring (empty string if it has none). The schema is
    produced by ``create_schema_from_function``.

    Args:
        fn_name: Name of an attribute on this instance (a method name).

    Returns:
        The metadata for that method, or ``None`` when this instance has no
        attribute called ``fn_name``.
    """
    try:
        func = getattr(self, fn_name)
    except AttributeError:
        return None
    docstring = func.__doc__ or ""
    description = f"{fn_name}{signature(func)}\n{docstring}"
    # Reuse the attribute already resolved above instead of looking it up again.
    fn_schema = create_schema_from_function(fn_name, func)
    return ToolMetadata(name=fn_name, description=description, fn_schema=fn_schema)
|
|
94
|
+
|
|
95
|
+
def _load_data(self, query: str) -> List[Document]:
    """Execute ``query`` and return one ``Document`` per result row.

    Each row is rendered as its column values converted with ``str`` and
    joined with ``", "``.

    Args:
        query: The SQL text to execute.

    Returns:
        A list of ``Document`` objects, one for each fetched row.

    Raises:
        ValueError: If ``query`` is ``None``.
    """
    # Validate before touching the database so that a missing query never
    # costs (or fails on) a connection attempt.
    if query is None:
        raise ValueError("A query parameter is necessary to filter the data")

    with self.sql_database.engine.connect() as connection:
        result = connection.execute(text(query))
        return [
            Document(text=", ".join([str(entry) for entry in row]))
            for row in result.fetchall()
        ]
|
|
26
105
|
|
|
106
|
+
def load_data(self, *args: Any, **load_kwargs: Any) -> List[str]:
|
|
107
|
+
"""Query and load data from the Database, returning a list of Documents.
|
|
27
108
|
Args:
|
|
28
109
|
query (str): an SQL query to filter tables and rows.
|
|
29
|
-
|
|
30
110
|
Returns:
|
|
31
|
-
List[
|
|
111
|
+
List[Document]: a list of Document objects from the database.
|
|
32
112
|
"""
|
|
113
|
+
query = args[0] if args else load_kwargs.get("args",{}).get("query")
|
|
114
|
+
if query is None:
|
|
115
|
+
raise ValueError("A query parameter is necessary to filter the data")
|
|
116
|
+
|
|
33
117
|
count_query = f"SELECT COUNT(*) FROM ({query})"
|
|
34
118
|
try:
|
|
35
|
-
count_rows = self.
|
|
119
|
+
count_rows = self._load_data(count_query)
|
|
36
120
|
except Exception as e:
|
|
37
121
|
return [f"Error ({str(e)}) occurred while counting number of rows"]
|
|
38
122
|
num_rows = int(count_rows[0].text)
|
|
@@ -42,19 +126,12 @@ class DBLoadData(DBTool):
|
|
|
42
126
|
"Please refactor your query to make it return less rows. "
|
|
43
127
|
]
|
|
44
128
|
try:
|
|
45
|
-
res = self.
|
|
129
|
+
res = self._load_data(query)
|
|
46
130
|
except Exception as e:
|
|
47
131
|
return [f"Error ({str(e)}) occurred while executing the query {query}"]
|
|
48
132
|
return [d.text for d in res]
|
|
49
133
|
|
|
50
|
-
|
|
51
|
-
"""
|
|
52
|
-
A tool to load a sample of data from the specified database table.
|
|
53
|
-
|
|
54
|
-
This tool fetches the first num_rows (default 25) rows from the given table
|
|
55
|
-
using a provided database query function.
|
|
56
|
-
"""
|
|
57
|
-
def __call__(self, table_name: str, num_rows: int = 25) -> Any:
|
|
134
|
+
def load_sample_data(self, table_name: str, num_rows: int = 25) -> Any:
|
|
58
135
|
"""
|
|
59
136
|
Fetches the first num_rows rows from the specified database table.
|
|
60
137
|
|
|
@@ -65,16 +142,39 @@ class DBLoadSampleData(DBTool):
|
|
|
65
142
|
Any: The result of the database query.
|
|
66
143
|
"""
|
|
67
144
|
try:
|
|
68
|
-
res = self.
|
|
145
|
+
res = self._load_data(f"SELECT * FROM {table_name} LIMIT {num_rows}")
|
|
69
146
|
except Exception as e:
|
|
70
147
|
return [f"Error ({str(e)}) occurred while loading sample data for table {table_name}"]
|
|
71
|
-
return res
|
|
148
|
+
return [d.text for d in res]
|
|
72
149
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
150
|
+
def list_tables(self) -> List[str]:
    """Return the names of the tables held in the reflected metadata.

    Returns:
        List[str]: The ``name`` of every entry in
        ``self._metadata.sorted_tables``, in that order.
    """
    names: List[str] = []
    for table in self._metadata.sorted_tables:
        names.append(table.name)
    return names
|
|
156
|
+
|
|
157
|
+
def describe_tables(self, tables: Optional[List[str]] = None) -> str:
    """Return the CREATE TABLE statements for the requested tables.

    Args:
        tables (Optional[List[str]]): Table names to describe. If ``None``
            or empty, every table in the reflected metadata is described.

    Returns:
        str: One compiled ``CreateTable`` statement per table, each followed
        by a newline, joined together with newlines.

    Raises:
        NoSuchTableError: If a requested name is not in the reflected metadata.
    """
    # sorted_tables is evaluated on every access, so read it once and index it
    # by name. setdefault keeps the first table for a name, matching a linear
    # first-match search.
    sorted_tables = self._metadata.sorted_tables
    tables_by_name = {}
    for table in sorted_tables:
        tables_by_name.setdefault(table.name, table)

    table_names = tables or [table.name for table in sorted_tables]
    table_schemas = []
    for table_name in table_names:
        table = tables_by_name.get(table_name)
        if table is None:
            raise NoSuchTableError(f"Table '{table_name}' does not exist.")
        schema = str(CreateTable(table).compile(self.sql_database.engine))
        table_schemas.append(f"{schema}\n")
    return "\n".join(table_schemas)
|
|
176
|
+
|
|
177
|
+
def load_unique_values(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
|
|
78
178
|
"""
|
|
79
179
|
Fetches the first num_vals unique values from the specified columns of the database table.
|
|
80
180
|
|
|
@@ -89,8 +189,74 @@ class DBLoadUniqueValues(DBTool):
|
|
|
89
189
|
res = {}
|
|
90
190
|
try:
|
|
91
191
|
for column in columns:
|
|
92
|
-
unique_vals = self.
|
|
192
|
+
unique_vals = self._load_data(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
|
|
93
193
|
res[column] = [d.text for d in unique_vals]
|
|
94
194
|
except Exception as e:
|
|
95
195
|
return {f"Error ({str(e)}) occurred while loading unique values for table {table_name}"}
|
|
96
196
|
return res
|
|
197
|
+
|
|
198
|
+
def to_tool_list(self) -> List[FunctionTool]:
    """
    Wrap every method named in ``self.spec_functions`` as a ``FunctionTool``.

    Coroutine functions are passed as ``async_fn`` and also get a blocking
    wrapper from ``patch_sync`` as ``fn``; regular functions are passed as
    ``fn`` only.
    """
    tools = []
    for spec_name in self.spec_functions:
        target = getattr(self, spec_name)
        is_coroutine = asyncio.iscoroutinefunction(target)
        async_impl = target if is_coroutine else None
        sync_impl = None if is_coroutine else target

        tool_metadata = self._get_metadata_from_fn_name(spec_name)

        if sync_impl is None:
            if async_impl is None:
                raise ValueError(
                    f"Could not retrieve a function for spec: {spec_name}"
                )
            # Only an async implementation exists: derive the sync entry point.
            sync_impl = patch_sync(async_impl)

        tools.append(
            FunctionTool.from_defaults(
                fn=sync_impl,
                async_fn=async_impl,
                tool_metadata=tool_metadata,
            )
        )
    return tools
|
|
229
|
+
|
|
230
|
+
# Custom pickling: exclude unpickleable objects
|
|
231
|
+
def __getstate__(self):
    """Return the instance state with the database handle and metadata removed.

    ``sql_database`` is replaced by a ``sql_database_state`` entry holding
    ``self._uri``; ``_metadata`` is dropped. The instance itself is not modified.
    """
    state = {key: value for key, value in self.__dict__.items() if key != "_metadata"}
    if "sql_database" in state:
        del state["sql_database"]
        state["sql_database_state"] = {"uri": self._uri}
    return state
|
|
239
|
+
|
|
240
|
+
def __setstate__(self, state):
    """Restore the instance from ``state`` and reconnect to the database.

    When ``state`` carries a ``sql_database_state`` entry, its ``uri`` is used
    to build a new ``SQLDatabase``. The table metadata is then reflected again
    from the resulting engine.

    Raises:
        ValueError: If ``sql_database_state`` is present but has no usable URI.
    """
    self.__dict__.update(state)

    if "sql_database_state" in state:
        saved_uri = state["sql_database_state"].get("uri")
        if not saved_uri:
            raise ValueError("Cannot reconstruct SQLDatabase without URI")
        self.sql_database = SQLDatabase.from_uri(saved_uri)
        self._uri = saved_uri

    # Metadata is never pickled, so it is always rebuilt here.
    self._metadata = MetaData()
    self._metadata.reflect(bind=self.sql_database.engine)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def patch_sync(func_async: "AsyncCallable") -> Callable:
    """Return a blocking wrapper around the coroutine function ``func_async``.

    The wrapper runs the coroutine to completion on the current thread's event
    loop. If the thread has no event loop (for example a worker thread), a
    temporary loop is created for the call and closed afterwards.

    Args:
        func_async: The coroutine function to wrap.

    Returns:
        A regular function with the same name and docstring as ``func_async``
        that returns the coroutine's result.

    Note:
        The wrapper cannot be called while an event loop is already running on
        the same thread; ``run_until_complete`` raises ``RuntimeError`` then.
    """
    from functools import wraps

    @wraps(func_async)
    def patched_sync(*args: Any, **kwargs: Any) -> Any:
        try:
            loop = asyncio.get_event_loop()
            owns_loop = False
        except RuntimeError:
            # No loop is set for this thread: create one just for this call.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            owns_loop = True
        try:
            return loop.run_until_complete(func_async(*args, **kwargs))
        finally:
            if owns_loop:
                # Leave the thread as it was found: no loop set, nothing left open.
                asyncio.set_event_loop(None)
                loop.close()

    return patched_sync
|
|
@@ -102,7 +102,7 @@ class SubQuestionQueryWorkflow(Workflow):
|
|
|
102
102
|
- What is the name of the mayor of San Jose?
|
|
103
103
|
Here is the user question: {await ctx.get('original_query')}.
|
|
104
104
|
Here are previous chat messages: {chat_history}.
|
|
105
|
-
And here is the list of tools: {
|
|
105
|
+
And here is the list of tools: {ev.tools}
|
|
106
106
|
""",
|
|
107
107
|
)
|
|
108
108
|
|
|
@@ -236,6 +236,7 @@ class SequentialSubQuestionsWorkflow(Workflow):
|
|
|
236
236
|
print(f"Query is {await ctx.get('original_query')}")
|
|
237
237
|
|
|
238
238
|
llm = await ctx.get("llm")
|
|
239
|
+
orig_query = await ctx.get("original_query")
|
|
239
240
|
response = llm.complete(
|
|
240
241
|
f"""
|
|
241
242
|
Given a user question, and a list of tools, output a list of
|
|
@@ -252,11 +253,13 @@ class SequentialSubQuestionsWorkflow(Workflow):
|
|
|
252
253
|
As an example, for the question
|
|
253
254
|
"what is the name of the mayor of the largest city within 50 miles of San Francisco?",
|
|
254
255
|
the sub-questions could be:
|
|
255
|
-
- What is the largest city within 50 miles of San Francisco?
|
|
256
|
-
-
|
|
257
|
-
|
|
256
|
+
- What is the largest city within 50 miles of San Francisco?
|
|
257
|
+
- Who is the mayor of this city?
|
|
258
|
+
The answer to the first question is San Jose, which is given as context to the second question.
|
|
259
|
+
The answer to the second question is Matt Mahan.
|
|
260
|
+
Here is the user question: {orig_query}.
|
|
258
261
|
Here are previous chat messages: {chat_history}.
|
|
259
|
-
And here is the list of tools: {
|
|
262
|
+
And here is the list of tools: {ev.tools}
|
|
260
263
|
""",
|
|
261
264
|
)
|
|
262
265
|
|
|
@@ -277,7 +280,16 @@ class SequentialSubQuestionsWorkflow(Workflow):
|
|
|
277
280
|
if await ctx.get("verbose"):
|
|
278
281
|
print(f"Sub-question is {ev.question}")
|
|
279
282
|
agent = await ctx.get("agent")
|
|
280
|
-
|
|
283
|
+
sub_questions = await ctx.get("sub_questions")
|
|
284
|
+
if ev.prev_answer:
|
|
285
|
+
prev_question = sub_questions[ev.num - 1]
|
|
286
|
+
prompt = f"""
|
|
287
|
+
The answer to the question '{prev_question}' is: '{ev.prev_answer}'
|
|
288
|
+
Now answer the following question: '{ev.question}'
|
|
289
|
+
"""
|
|
290
|
+
response = await agent.achat(prompt)
|
|
291
|
+
else:
|
|
292
|
+
response = await agent.achat(ev.question)
|
|
281
293
|
if await ctx.get("verbose"):
|
|
282
294
|
print(f"Answer is {response}")
|
|
283
295
|
|
|
@@ -286,7 +298,8 @@ class SequentialSubQuestionsWorkflow(Workflow):
|
|
|
286
298
|
return self.QueryEvent(
|
|
287
299
|
question=sub_questions[ev.num + 1],
|
|
288
300
|
prev_answer = response.response,
|
|
289
|
-
num=ev.num + 1
|
|
301
|
+
num=ev.num + 1
|
|
302
|
+
)
|
|
290
303
|
|
|
291
304
|
output = self.OutputsModel(response=response.response)
|
|
292
305
|
return StopEvent(result=output)
|