vectara-agentic 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vectara-agentic might be problematic. Click here for more details.

@@ -1,38 +1,122 @@
1
1
  """
2
- This module contains the code to extend and improve DatabaseToolSpec
3
- Specifically adding load_sample_data and load_unique_values methods, as well as
4
- making sure the load_data method returns a list of text values from the database, not Document[] objects.
2
+ This module contains the code adapted from DatabaseToolSpec
3
+ It makes the following adjustments:
4
+ * Adds load_sample_data and load_unique_values methods.
5
+ * Fixes serialization.
6
+ * Makes sure the load_data method returns a list of text values from the database (and not Document[] objects).
7
+ * Limits the returned rows to self.max_rows.
5
8
  """
6
- from abc import ABC
7
- from typing import Callable, Any
9
+ from typing import Any, Optional, List, Awaitable, Callable
10
+ import asyncio
11
+ from inspect import signature
8
12
 
9
- #
10
- # Additional database tool
11
- #
12
- class DBTool(ABC):
13
- """
14
- A base class for vectara-agentic database tools extensions
13
+ from sqlalchemy import MetaData, text
14
+ from sqlalchemy.engine import Engine
15
+ from sqlalchemy.exc import NoSuchTableError
16
+ from sqlalchemy.schema import CreateTable
17
+
18
+ from llama_index.core.readers.base import BaseReader
19
+ from llama_index.core.utilities.sql_wrapper import SQLDatabase
20
+ from llama_index.core.schema import Document
21
+ from llama_index.core.tools.function_tool import FunctionTool
22
+ from llama_index.core.tools.types import ToolMetadata
23
+ from llama_index.core.tools.utils import create_schema_from_function
24
+
25
+
26
+ AsyncCallable = Callable[..., Awaitable[Any]]
27
+
28
+ class DatabaseTools(BaseReader):
29
+ """Database tools for vectara-agentic
30
+ This class provides a set of tools to interact with a database.
31
+ It allows you to load data, list tables, describe tables, and load unique values.
32
+ It also provides a method to load sample data from a specified table.
15
33
  """
16
- def __init__(self, load_data_fn: Callable, max_rows: int = 1000):
17
- self.load_data_fn = load_data_fn
34
+ spec_functions = [
35
+ "load_data", "load_sample_data", "list_tables",
36
+ "describe_tables", "load_unique_values",
37
+ ]
38
+
39
+ def __init__(
40
+ self,
41
+ *args: Any,
42
+ max_rows: int = 1000,
43
+ sql_database: Optional[SQLDatabase] = None,
44
+ engine: Optional[Engine] = None,
45
+ uri: Optional[str] = None,
46
+ scheme: Optional[str] = None,
47
+ host: Optional[str] = None,
48
+ port: Optional[str] = None,
49
+ user: Optional[str] = None,
50
+ password: Optional[str] = None,
51
+ dbname: Optional[str] = None,
52
+ **kwargs: Any,
53
+ ) -> None:
18
54
  self.max_rows = max_rows
19
55
 
20
- class DBLoadData(DBTool):
21
- """
22
- A tool to Run SQL query on the database and return the result.
23
- """
24
- def __call__(self, query: str) -> Any:
25
- """Query and load data from the Database, returning a list of Documents.
56
+ if sql_database:
57
+ self.sql_database = sql_database
58
+ elif engine:
59
+ self.sql_database = SQLDatabase(engine, *args, **kwargs)
60
+ elif uri:
61
+ self.uri = uri
62
+ self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
63
+ elif (scheme and host and port and user and password and dbname):
64
+ uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
65
+ self.uri = uri
66
+ self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
67
+ else:
68
+ raise ValueError(
69
+ "You must provide either a SQLDatabase, "
70
+ "a SQL Alchemy Engine, a valid connection URI, or a valid "
71
+ "set of credentials."
72
+ )
73
+ self._uri = getattr(self, "uri", None) or str(self.sql_database.engine.url)
74
+ self._metadata = MetaData()
75
+ self._metadata.reflect(bind=self.sql_database.engine)
76
+
77
+ def _get_metadata_from_fn_name(
78
+ self, fn_name: Callable,
79
+ ) -> Optional[ToolMetadata]:
80
+ """Return the ToolMetadata for the method with the given name, or None if no such method exists.
81
+
82
+ Return type is Optional, meaning that the schema can be None.
83
+ In this case, it's up to the downstream tool implementation to infer the schema.
84
+ """
85
+ try:
86
+ func = getattr(self, fn_name)
87
+ except AttributeError:
88
+ return None
89
+ name = fn_name
90
+ docstring = func.__doc__ or ""
91
+ description = f"{name}{signature(func)}\n{docstring}"
92
+ fn_schema = create_schema_from_function(fn_name, getattr(self, fn_name))
93
+ return ToolMetadata(name=name, description=description, fn_schema=fn_schema)
94
+
95
+ def _load_data(self, query: str) -> List[Document]:
96
+ documents = []
97
+ with self.sql_database.engine.connect() as connection:
98
+ if query is None:
99
+ raise ValueError("A query parameter is necessary to filter the data")
100
+ result = connection.execute(text(query))
101
+ for item in result.fetchall():
102
+ doc_str = ", ".join([str(entry) for entry in item])
103
+ documents.append(Document(text=doc_str))
104
+ return documents
26
105
 
106
+ def load_data(self, *args: Any, **load_kwargs: Any) -> List[str]:
107
+ """Query and load data from the Database, returning a list of text values.
27
108
  Args:
28
109
  query (str): an SQL query to filter tables and rows.
29
-
30
110
  Returns:
31
- List[text]: a list of text values from the database.
111
+ List[str]: a list of text values from the database.
32
112
  """
113
+ query = args[0] if args else load_kwargs.get("args",{}).get("query")
114
+ if query is None:
115
+ raise ValueError("A query parameter is necessary to filter the data")
116
+
33
117
  count_query = f"SELECT COUNT(*) FROM ({query})"
34
118
  try:
35
- count_rows = self.load_data_fn(count_query)
119
+ count_rows = self._load_data(count_query)
36
120
  except Exception as e:
37
121
  return [f"Error ({str(e)}) occurred while counting number of rows"]
38
122
  num_rows = int(count_rows[0].text)
@@ -42,19 +126,12 @@ class DBLoadData(DBTool):
42
126
  "Please refactor your query to make it return less rows. "
43
127
  ]
44
128
  try:
45
- res = self.load_data_fn(query)
129
+ res = self._load_data(query)
46
130
  except Exception as e:
47
131
  return [f"Error ({str(e)}) occurred while executing the query {query}"]
48
132
  return [d.text for d in res]
49
133
 
50
- class DBLoadSampleData(DBTool):
51
- """
52
- A tool to load a sample of data from the specified database table.
53
-
54
- This tool fetches the first num_rows (default 25) rows from the given table
55
- using a provided database query function.
56
- """
57
- def __call__(self, table_name: str, num_rows: int = 25) -> Any:
134
+ def load_sample_data(self, table_name: str, num_rows: int = 25) -> Any:
58
135
  """
59
136
  Fetches the first num_rows rows from the specified database table.
60
137
 
@@ -65,16 +142,39 @@ class DBLoadSampleData(DBTool):
65
142
  Any: The result of the database query.
66
143
  """
67
144
  try:
68
- res = self.load_data_fn(f"SELECT * FROM {table_name} LIMIT {num_rows}")
145
+ res = self._load_data(f"SELECT * FROM {table_name} LIMIT {num_rows}")
69
146
  except Exception as e:
70
147
  return [f"Error ({str(e)}) occurred while loading sample data for table {table_name}"]
71
- return res
148
+ return [d.text for d in res]
72
149
 
73
- class DBLoadUniqueValues(DBTool):
74
- """
75
- A tool to list all unique values for each column in a set of columns of a database table.
76
- """
77
- def __call__(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
150
+ def list_tables(self) -> List[str]:
151
+ """List all tables in the database.
152
+ Returns:
153
+ List[str]: A list of table names in the database.
154
+ """
155
+ return [x.name for x in self._metadata.sorted_tables]
156
+
157
+ def describe_tables(self, tables: Optional[List[str]] = None) -> str:
158
+ """Describe the tables in the database.
159
+ Args:
160
+ tables (Optional[List[str]]): A list of table names to describe. If None, all tables are described.
161
+ Returns:
162
+ str: A string representation of the table schemas.
163
+ """
164
+ table_names = tables or [table.name for table in self._metadata.sorted_tables]
165
+ table_schemas = []
166
+ for table_name in table_names:
167
+ table = next(
168
+ (table for table in self._metadata.sorted_tables if table.name == table_name),
169
+ None,
170
+ )
171
+ if table is None:
172
+ raise NoSuchTableError(f"Table '{table_name}' does not exist.")
173
+ schema = str(CreateTable(table).compile(self.sql_database.engine))
174
+ table_schemas.append(f"{schema}\n")
175
+ return "\n".join(table_schemas)
176
+
177
+ def load_unique_values(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
78
178
  """
79
179
  Fetches the first num_vals unique values from the specified columns of the database table.
80
180
 
@@ -89,8 +189,74 @@ class DBLoadUniqueValues(DBTool):
89
189
  res = {}
90
190
  try:
91
191
  for column in columns:
92
- unique_vals = self.load_data_fn(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
192
+ unique_vals = self._load_data(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
93
193
  res[column] = [d.text for d in unique_vals]
94
194
  except Exception as e:
95
195
  return {f"Error ({str(e)}) occurred while loading unique values for table {table_name}"}
96
196
  return res
197
+
198
+ def to_tool_list(self) -> List[FunctionTool]:
199
+ """
200
+ Returns a list of tools available.
201
+ """
202
+
203
+ tool_list = []
204
+ for tool_name in self.spec_functions:
205
+ func_sync = None
206
+ func_async = None
207
+ func = getattr(self, tool_name)
208
+ if asyncio.iscoroutinefunction(func):
209
+ func_async = func
210
+ else:
211
+ func_sync = func
212
+ metadata = self._get_metadata_from_fn_name(tool_name)
213
+
214
+ if func_sync is None:
215
+ if func_async is not None:
216
+ func_sync = patch_sync(func_async)
217
+ else:
218
+ raise ValueError(
219
+ f"Could not retrieve a function for spec: {tool_name}"
220
+ )
221
+
222
+ tool = FunctionTool.from_defaults(
223
+ fn=func_sync,
224
+ async_fn=func_async,
225
+ tool_metadata=metadata,
226
+ )
227
+ tool_list.append(tool)
228
+ return tool_list
229
+
230
+ # Custom pickling: exclude unpickleable objects
231
+ def __getstate__(self):
232
+ state = self.__dict__.copy()
233
+ if "sql_database" in state:
234
+ state["sql_database_state"] = {"uri": self._uri}
235
+ del state["sql_database"]
236
+ if "_metadata" in state:
237
+ del state["_metadata"]
238
+ return state
239
+
240
+ def __setstate__(self, state):
241
+ self.__dict__.update(state)
242
+ # Reconstruct the sql_database if it was removed
243
+ if "sql_database_state" in state:
244
+ uri = state["sql_database_state"].get("uri")
245
+ if uri:
246
+ self.sql_database = SQLDatabase.from_uri(uri)
247
+ self._uri = uri
248
+ else:
249
+ raise ValueError("Cannot reconstruct SQLDatabase without URI")
250
+ # Rebuild metadata after restoring the engine
251
+ self._metadata = MetaData()
252
+ self._metadata.reflect(bind=self.sql_database.engine)
253
+
254
+
255
+ def patch_sync(func_async: AsyncCallable) -> Callable:
256
+ """Patch sync function from async function."""
257
+
258
+ def patched_sync(*args: Any, **kwargs: Any) -> Any:
259
+ loop = asyncio.get_event_loop()
260
+ return loop.run_until_complete(func_async(*args, **kwargs))
261
+
262
+ return patched_sync
@@ -102,7 +102,7 @@ class SubQuestionQueryWorkflow(Workflow):
102
102
  - What is the name of the mayor of San Jose?
103
103
  Here is the user question: {await ctx.get('original_query')}.
104
104
  Here are previous chat messages: {chat_history}.
105
- And here is the list of tools: {await ctx.get('tools')}
105
+ And here is the list of tools: {ev.tools}
106
106
  """,
107
107
  )
108
108
 
@@ -236,6 +236,7 @@ class SequentialSubQuestionsWorkflow(Workflow):
236
236
  print(f"Query is {await ctx.get('original_query')}")
237
237
 
238
238
  llm = await ctx.get("llm")
239
+ orig_query = await ctx.get("original_query")
239
240
  response = llm.complete(
240
241
  f"""
241
242
  Given a user question, and a list of tools, output a list of
@@ -252,11 +253,13 @@ class SequentialSubQuestionsWorkflow(Workflow):
252
253
  As an example, for the question
253
254
  "what is the name of the mayor of the largest city within 50 miles of San Francisco?",
254
255
  the sub-questions could be:
255
- - What is the largest city within 50 miles of San Francisco? (answer is San Jose)
256
- - What is the name of the mayor of San Jose?
257
- Here is the user question: {await ctx.get('original_query')}.
256
+ - What is the largest city within 50 miles of San Francisco?
257
+ - Who is the mayor of this city?
258
+ The answer to the first question is San Jose, which is given as context to the second question.
259
+ The answer to the second question is Matt Mahan.
260
+ Here is the user question: {orig_query}.
258
261
  Here are previous chat messages: {chat_history}.
259
- And here is the list of tools: {await ctx.get('tools')}
262
+ And here is the list of tools: {ev.tools}
260
263
  """,
261
264
  )
262
265
 
@@ -277,7 +280,16 @@ class SequentialSubQuestionsWorkflow(Workflow):
277
280
  if await ctx.get("verbose"):
278
281
  print(f"Sub-question is {ev.question}")
279
282
  agent = await ctx.get("agent")
280
- response = await agent.achat(ev.question)
283
+ sub_questions = await ctx.get("sub_questions")
284
+ if ev.prev_answer:
285
+ prev_question = sub_questions[ev.num - 1]
286
+ prompt = f"""
287
+ The answer to the question '{prev_question}' is: '{ev.prev_answer}'
288
+ Now answer the following question: '{ev.question}'
289
+ """
290
+ response = await agent.achat(prompt)
291
+ else:
292
+ response = await agent.achat(ev.question)
281
293
  if await ctx.get("verbose"):
282
294
  print(f"Answer is {response}")
283
295
 
@@ -286,7 +298,8 @@ class SequentialSubQuestionsWorkflow(Workflow):
286
298
  return self.QueryEvent(
287
299
  question=sub_questions[ev.num + 1],
288
300
  prev_answer = response.response,
289
- num=ev.num + 1)
301
+ num=ev.num + 1
302
+ )
290
303
 
291
304
  output = self.OutputsModel(response=response.response)
292
305
  return StopEvent(result=output)