vectara-agentic 0.2.7__tar.gz → 0.2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vectara-agentic might be problematic. Click here for more details.
- {vectara_agentic-0.2.7/vectara_agentic.egg-info → vectara_agentic-0.2.8}/PKG-INFO +2 -2
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/requirements.txt +1 -1
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_agent.py +0 -37
- vectara_agentic-0.2.8/tests/test_serialization.py +110 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/_version.py +1 -1
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/agent.py +8 -4
- vectara_agentic-0.2.8/vectara_agentic/db_tools.py +262 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/sub_query_workflow.py +4 -3
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/tools.py +51 -43
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/utils.py +36 -9
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8/vectara_agentic.egg-info}/PKG-INFO +2 -2
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic.egg-info/SOURCES.txt +1 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic.egg-info/requires.txt +1 -1
- vectara_agentic-0.2.7/vectara_agentic/db_tools.py +0 -96
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/LICENSE +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/MANIFEST.in +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/README.md +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/setup.cfg +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/setup.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/__init__.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/endpoint.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_agent_planning.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_agent_type.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_fallback.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_private_llm.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_tools.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/tests/test_workflow.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/__init__.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/_callback.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/_observability.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/_prompts.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/agent_config.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/agent_endpoint.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/tools_catalog.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic/types.py +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic.egg-info/dependency_links.txt +0 -0
- {vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: vectara_agentic
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.8
|
|
4
4
|
Summary: A Python package for creating AI Assistants and AI Agents with Vectara
|
|
5
5
|
Home-page: https://github.com/vectara/py-vectara-agentic
|
|
6
6
|
Author: Ofer Mendelevitch
|
|
@@ -16,7 +16,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
16
16
|
Requires-Python: >=3.10
|
|
17
17
|
Description-Content-Type: text/markdown
|
|
18
18
|
License-File: LICENSE
|
|
19
|
-
Requires-Dist: llama-index==0.12.
|
|
19
|
+
Requires-Dist: llama-index==0.12.26
|
|
20
20
|
Requires-Dist: llama-index-indices-managed-vectara==0.4.2
|
|
21
21
|
Requires-Dist: llama-index-agent-llm-compiler==0.3.0
|
|
22
22
|
Requires-Dist: llama-index-agent-lats==0.3.0
|
|
@@ -106,43 +106,6 @@ class TestAgentPackage(unittest.TestCase):
|
|
|
106
106
|
self.assertIsInstance(agent, Agent)
|
|
107
107
|
self.assertEqual(agent._topic, "question answering")
|
|
108
108
|
|
|
109
|
-
def test_serialization(self):
|
|
110
|
-
with ARIZE_LOCK:
|
|
111
|
-
config = AgentConfig(
|
|
112
|
-
agent_type=AgentType.REACT,
|
|
113
|
-
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
114
|
-
tool_llm_provider=ModelProvider.TOGETHER,
|
|
115
|
-
observer=ObserverType.ARIZE_PHOENIX
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
agent = Agent.from_corpus(
|
|
119
|
-
tool_name="RAG Tool",
|
|
120
|
-
agent_config=config,
|
|
121
|
-
vectara_corpus_key="corpus_key",
|
|
122
|
-
vectara_api_key="api_key",
|
|
123
|
-
data_description="information",
|
|
124
|
-
assistant_specialty="question answering",
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
agent_reloaded = agent.loads(agent.dumps())
|
|
128
|
-
agent_reloaded_again = agent_reloaded.loads(agent_reloaded.dumps())
|
|
129
|
-
|
|
130
|
-
self.assertIsInstance(agent_reloaded, Agent)
|
|
131
|
-
self.assertEqual(agent, agent_reloaded)
|
|
132
|
-
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
133
|
-
|
|
134
|
-
self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
|
|
135
|
-
self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
|
|
136
|
-
self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
|
|
137
|
-
|
|
138
|
-
self.assertIsInstance(agent_reloaded, Agent)
|
|
139
|
-
self.assertEqual(agent, agent_reloaded_again)
|
|
140
|
-
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
141
|
-
|
|
142
|
-
self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
|
|
143
|
-
self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
|
|
144
|
-
self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
|
|
145
|
-
|
|
146
109
|
def test_chat_history(self):
|
|
147
110
|
tools = [ToolsFactory().create_tool(mult)]
|
|
148
111
|
topic = "AI topic"
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
import threading
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from vectara_agentic.agent import Agent, AgentType
|
|
6
|
+
from vectara_agentic.agent_config import AgentConfig
|
|
7
|
+
from vectara_agentic.types import ModelProvider, ObserverType
|
|
8
|
+
from vectara_agentic.tools import ToolsFactory
|
|
9
|
+
|
|
10
|
+
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
11
|
+
from sqlalchemy import create_engine
|
|
12
|
+
|
|
13
|
+
def mult(x: float, y: float) -> float:
|
|
14
|
+
return x * y
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
ARIZE_LOCK = threading.Lock()
|
|
18
|
+
|
|
19
|
+
class TestAgentSerialization(unittest.TestCase):
|
|
20
|
+
|
|
21
|
+
@classmethod
|
|
22
|
+
def tearDown(cls):
|
|
23
|
+
try:
|
|
24
|
+
os.remove('ev_database.db')
|
|
25
|
+
except FileNotFoundError:
|
|
26
|
+
pass
|
|
27
|
+
|
|
28
|
+
def test_serialization(self):
|
|
29
|
+
with ARIZE_LOCK:
|
|
30
|
+
config = AgentConfig(
|
|
31
|
+
agent_type=AgentType.REACT,
|
|
32
|
+
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
33
|
+
tool_llm_provider=ModelProvider.TOGETHER,
|
|
34
|
+
observer=ObserverType.ARIZE_PHOENIX
|
|
35
|
+
)
|
|
36
|
+
db_tools = ToolsFactory().database_tools(
|
|
37
|
+
tool_name_prefix = "ev",
|
|
38
|
+
content_description = 'Electric Vehicles in the state of Washington and other population information',
|
|
39
|
+
sql_database = SQLDatabase(create_engine('sqlite:///ev_database.db')),
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
tools = [ToolsFactory().create_tool(mult)] + ToolsFactory().standard_tools() + db_tools
|
|
43
|
+
topic = "AI topic"
|
|
44
|
+
instructions = "Always do as your father tells you, if your mother agrees!"
|
|
45
|
+
agent = Agent(
|
|
46
|
+
tools=tools,
|
|
47
|
+
topic=topic,
|
|
48
|
+
custom_instructions=instructions,
|
|
49
|
+
agent_config=config
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
agent_reloaded = agent.loads(agent.dumps())
|
|
53
|
+
agent_reloaded_again = agent_reloaded.loads(agent_reloaded.dumps())
|
|
54
|
+
|
|
55
|
+
self.assertIsInstance(agent_reloaded, Agent)
|
|
56
|
+
self.assertEqual(agent, agent_reloaded)
|
|
57
|
+
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
58
|
+
|
|
59
|
+
self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
|
|
60
|
+
self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
|
|
61
|
+
self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
|
|
62
|
+
|
|
63
|
+
self.assertIsInstance(agent_reloaded, Agent)
|
|
64
|
+
self.assertEqual(agent, agent_reloaded_again)
|
|
65
|
+
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
66
|
+
|
|
67
|
+
self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
|
|
68
|
+
self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
|
|
69
|
+
self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
|
|
70
|
+
|
|
71
|
+
def test_serialization_from_corpus(self):
|
|
72
|
+
with ARIZE_LOCK:
|
|
73
|
+
config = AgentConfig(
|
|
74
|
+
agent_type=AgentType.REACT,
|
|
75
|
+
main_llm_provider=ModelProvider.ANTHROPIC,
|
|
76
|
+
tool_llm_provider=ModelProvider.TOGETHER,
|
|
77
|
+
observer=ObserverType.ARIZE_PHOENIX
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
agent = Agent.from_corpus(
|
|
81
|
+
tool_name="RAG Tool",
|
|
82
|
+
agent_config=config,
|
|
83
|
+
vectara_corpus_key="corpus_key",
|
|
84
|
+
vectara_api_key="api_key",
|
|
85
|
+
data_description="information",
|
|
86
|
+
assistant_specialty="question answering",
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
agent_reloaded = agent.loads(agent.dumps())
|
|
90
|
+
agent_reloaded_again = agent_reloaded.loads(agent_reloaded.dumps())
|
|
91
|
+
|
|
92
|
+
self.assertIsInstance(agent_reloaded, Agent)
|
|
93
|
+
self.assertEqual(agent, agent_reloaded)
|
|
94
|
+
self.assertEqual(agent.agent_type, agent_reloaded.agent_type)
|
|
95
|
+
|
|
96
|
+
self.assertEqual(agent.agent_config.observer, agent_reloaded.agent_config.observer)
|
|
97
|
+
self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded.agent_config.main_llm_provider)
|
|
98
|
+
self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded.agent_config.tool_llm_provider)
|
|
99
|
+
|
|
100
|
+
self.assertIsInstance(agent_reloaded, Agent)
|
|
101
|
+
self.assertEqual(agent, agent_reloaded_again)
|
|
102
|
+
self.assertEqual(agent.agent_type, agent_reloaded_again.agent_type)
|
|
103
|
+
|
|
104
|
+
self.assertEqual(agent.agent_config.observer, agent_reloaded_again.agent_config.observer)
|
|
105
|
+
self.assertEqual(agent.agent_config.main_llm_provider, agent_reloaded_again.agent_config.main_llm_provider)
|
|
106
|
+
self.assertEqual(agent.agent_config.tool_llm_provider, agent_reloaded_again.agent_config.tool_llm_provider)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
if __name__ == "__main__":
|
|
110
|
+
unittest.main()
|
|
@@ -768,6 +768,7 @@ class Agent:
|
|
|
768
768
|
"""
|
|
769
769
|
max_attempts = 4 if self.fallback_agent_config else 2
|
|
770
770
|
attempt = 0
|
|
771
|
+
orig_llm = self.llm.metadata.model_name
|
|
771
772
|
while attempt < max_attempts:
|
|
772
773
|
try:
|
|
773
774
|
current_agent = self._get_current_agent()
|
|
@@ -788,16 +789,20 @@ class Agent:
|
|
|
788
789
|
agent_response.async_response_gen = _stream_response_wrapper # Override the generator
|
|
789
790
|
return agent_response
|
|
790
791
|
|
|
791
|
-
except Exception:
|
|
792
|
+
except Exception as e:
|
|
793
|
+
last_error = e
|
|
792
794
|
if attempt >= 2:
|
|
793
795
|
if self.verbose:
|
|
794
|
-
print("LLM call failed. Switching agent configuration.")
|
|
796
|
+
print(f"LLM call failed on attempt {attempt}. Switching agent configuration.")
|
|
795
797
|
self._switch_agent_config()
|
|
796
798
|
time.sleep(1)
|
|
797
799
|
attempt += 1
|
|
798
800
|
|
|
799
801
|
return AgentResponse(
|
|
800
|
-
response=
|
|
802
|
+
response=(
|
|
803
|
+
f"For {orig_llm} LLM - failure can't be resolved after "
|
|
804
|
+
f"{max_attempts} attempts ({last_error})."
|
|
805
|
+
)
|
|
801
806
|
)
|
|
802
807
|
|
|
803
808
|
#
|
|
@@ -861,7 +866,6 @@ class Agent:
|
|
|
861
866
|
def to_dict(self) -> Dict[str, Any]:
|
|
862
867
|
"""Serialize the Agent instance to a dictionary."""
|
|
863
868
|
tool_info = []
|
|
864
|
-
|
|
865
869
|
for tool in self.tools:
|
|
866
870
|
if hasattr(tool.metadata, "fn_schema"):
|
|
867
871
|
fn_schema_cls = tool.metadata.fn_schema
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains the code adapted from DatabaseToolSpec
|
|
3
|
+
It makes the following adjustments:
|
|
4
|
+
* Adds load_sample_data and load_unique_values methods.
|
|
5
|
+
* Fixes serialization.
|
|
6
|
+
* Makes sure the load_data method returns a list of text values from the database (and not Document[] objects).
|
|
7
|
+
* Limits the returned rows to self.max_rows.
|
|
8
|
+
"""
|
|
9
|
+
from typing import Any, Optional, List, Awaitable, Callable
|
|
10
|
+
import asyncio
|
|
11
|
+
from inspect import signature
|
|
12
|
+
|
|
13
|
+
from sqlalchemy import MetaData, text
|
|
14
|
+
from sqlalchemy.engine import Engine
|
|
15
|
+
from sqlalchemy.exc import NoSuchTableError
|
|
16
|
+
from sqlalchemy.schema import CreateTable
|
|
17
|
+
|
|
18
|
+
from llama_index.core.readers.base import BaseReader
|
|
19
|
+
from llama_index.core.utilities.sql_wrapper import SQLDatabase
|
|
20
|
+
from llama_index.core.schema import Document
|
|
21
|
+
from llama_index.core.tools.function_tool import FunctionTool
|
|
22
|
+
from llama_index.core.tools.types import ToolMetadata
|
|
23
|
+
from llama_index.core.tools.utils import create_schema_from_function
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
AsyncCallable = Callable[..., Awaitable[Any]]
|
|
27
|
+
|
|
28
|
+
class DatabaseTools(BaseReader):
|
|
29
|
+
"""Database tools for vectara-agentic
|
|
30
|
+
This class provides a set of tools to interact with a database.
|
|
31
|
+
It allows you to load data, list tables, describe tables, and load unique values.
|
|
32
|
+
It also provides a method to load sample data from a specified table.
|
|
33
|
+
"""
|
|
34
|
+
spec_functions = [
|
|
35
|
+
"load_data", "load_sample_data", "list_tables",
|
|
36
|
+
"describe_tables", "load_unique_values",
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
*args: Any,
|
|
42
|
+
max_rows: int = 1000,
|
|
43
|
+
sql_database: Optional[SQLDatabase] = None,
|
|
44
|
+
engine: Optional[Engine] = None,
|
|
45
|
+
uri: Optional[str] = None,
|
|
46
|
+
scheme: Optional[str] = None,
|
|
47
|
+
host: Optional[str] = None,
|
|
48
|
+
port: Optional[str] = None,
|
|
49
|
+
user: Optional[str] = None,
|
|
50
|
+
password: Optional[str] = None,
|
|
51
|
+
dbname: Optional[str] = None,
|
|
52
|
+
**kwargs: Any,
|
|
53
|
+
) -> None:
|
|
54
|
+
self.max_rows = max_rows
|
|
55
|
+
|
|
56
|
+
if sql_database:
|
|
57
|
+
self.sql_database = sql_database
|
|
58
|
+
elif engine:
|
|
59
|
+
self.sql_database = SQLDatabase(engine, *args, **kwargs)
|
|
60
|
+
elif uri:
|
|
61
|
+
self.uri = uri
|
|
62
|
+
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
|
|
63
|
+
elif (scheme and host and port and user and password and dbname):
|
|
64
|
+
uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
|
|
65
|
+
self.uri = uri
|
|
66
|
+
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
|
|
67
|
+
else:
|
|
68
|
+
raise ValueError(
|
|
69
|
+
"You must provide either a SQLDatabase, "
|
|
70
|
+
"a SQL Alchemy Engine, a valid connection URI, or a valid "
|
|
71
|
+
"set of credentials."
|
|
72
|
+
)
|
|
73
|
+
self._uri = getattr(self, "uri", None) or str(self.sql_database.engine.url)
|
|
74
|
+
self._metadata = MetaData()
|
|
75
|
+
self._metadata.reflect(bind=self.sql_database.engine)
|
|
76
|
+
|
|
77
|
+
def _get_metadata_from_fn_name(
|
|
78
|
+
self, fn_name: Callable,
|
|
79
|
+
) -> Optional[ToolMetadata]:
|
|
80
|
+
"""Return map from function name.
|
|
81
|
+
|
|
82
|
+
Return type is Optional, meaning that the schema can be None.
|
|
83
|
+
In this case, it's up to the downstream tool implementation to infer the schema.
|
|
84
|
+
"""
|
|
85
|
+
try:
|
|
86
|
+
func = getattr(self, fn_name)
|
|
87
|
+
except AttributeError:
|
|
88
|
+
return None
|
|
89
|
+
name = fn_name
|
|
90
|
+
docstring = func.__doc__ or ""
|
|
91
|
+
description = f"{name}{signature(func)}\n{docstring}"
|
|
92
|
+
fn_schema = create_schema_from_function(fn_name, getattr(self, fn_name))
|
|
93
|
+
return ToolMetadata(name=name, description=description, fn_schema=fn_schema)
|
|
94
|
+
|
|
95
|
+
def _load_data(self, query: str) -> List[Document]:
|
|
96
|
+
documents = []
|
|
97
|
+
with self.sql_database.engine.connect() as connection:
|
|
98
|
+
if query is None:
|
|
99
|
+
raise ValueError("A query parameter is necessary to filter the data")
|
|
100
|
+
result = connection.execute(text(query))
|
|
101
|
+
for item in result.fetchall():
|
|
102
|
+
doc_str = ", ".join([str(entry) for entry in item])
|
|
103
|
+
documents.append(Document(text=doc_str))
|
|
104
|
+
return documents
|
|
105
|
+
|
|
106
|
+
def load_data(self, *args: Any, **load_kwargs: Any) -> List[str]:
|
|
107
|
+
"""Query and load data from the Database, returning a list of Documents.
|
|
108
|
+
Args:
|
|
109
|
+
query (str): an SQL query to filter tables and rows.
|
|
110
|
+
Returns:
|
|
111
|
+
List[Document]: a list of Document objects from the database.
|
|
112
|
+
"""
|
|
113
|
+
query = args[0] if args else load_kwargs.get("args",{}).get("query")
|
|
114
|
+
if query is None:
|
|
115
|
+
raise ValueError("A query parameter is necessary to filter the data")
|
|
116
|
+
|
|
117
|
+
count_query = f"SELECT COUNT(*) FROM ({query})"
|
|
118
|
+
try:
|
|
119
|
+
count_rows = self._load_data(count_query)
|
|
120
|
+
except Exception as e:
|
|
121
|
+
return [f"Error ({str(e)}) occurred while counting number of rows"]
|
|
122
|
+
num_rows = int(count_rows[0].text)
|
|
123
|
+
if num_rows > self.max_rows:
|
|
124
|
+
return [
|
|
125
|
+
f"The query is expected to return more than {self.max_rows} rows. "
|
|
126
|
+
"Please refactor your query to make it return less rows. "
|
|
127
|
+
]
|
|
128
|
+
try:
|
|
129
|
+
res = self._load_data(query)
|
|
130
|
+
except Exception as e:
|
|
131
|
+
return [f"Error ({str(e)}) occurred while executing the query {query}"]
|
|
132
|
+
return [d.text for d in res]
|
|
133
|
+
|
|
134
|
+
def load_sample_data(self, table_name: str, num_rows: int = 25) -> Any:
|
|
135
|
+
"""
|
|
136
|
+
Fetches the first num_rows rows from the specified database table.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
table_name (str): The name of the database table.
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
Any: The result of the database query.
|
|
143
|
+
"""
|
|
144
|
+
try:
|
|
145
|
+
res = self._load_data(f"SELECT * FROM {table_name} LIMIT {num_rows}")
|
|
146
|
+
except Exception as e:
|
|
147
|
+
return [f"Error ({str(e)}) occurred while loading sample data for table {table_name}"]
|
|
148
|
+
return [d.text for d in res]
|
|
149
|
+
|
|
150
|
+
def list_tables(self) -> List[str]:
|
|
151
|
+
"""List all tables in the database.
|
|
152
|
+
Returns:
|
|
153
|
+
List[str]: A list of table names in the database.
|
|
154
|
+
"""
|
|
155
|
+
return [x.name for x in self._metadata.sorted_tables]
|
|
156
|
+
|
|
157
|
+
def describe_tables(self, tables: Optional[List[str]] = None) -> str:
|
|
158
|
+
"""Describe the tables in the database.
|
|
159
|
+
Args:
|
|
160
|
+
tables (Optional[List[str]]): A list of table names to describe. If None, all tables are described.
|
|
161
|
+
Returns:
|
|
162
|
+
str: A string representation of the table schemas.
|
|
163
|
+
"""
|
|
164
|
+
table_names = tables or [table.name for table in self._metadata.sorted_tables]
|
|
165
|
+
table_schemas = []
|
|
166
|
+
for table_name in table_names:
|
|
167
|
+
table = next(
|
|
168
|
+
(table for table in self._metadata.sorted_tables if table.name == table_name),
|
|
169
|
+
None,
|
|
170
|
+
)
|
|
171
|
+
if table is None:
|
|
172
|
+
raise NoSuchTableError(f"Table '{table_name}' does not exist.")
|
|
173
|
+
schema = str(CreateTable(table).compile(self.sql_database.engine))
|
|
174
|
+
table_schemas.append(f"{schema}\n")
|
|
175
|
+
return "\n".join(table_schemas)
|
|
176
|
+
|
|
177
|
+
def load_unique_values(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
|
|
178
|
+
"""
|
|
179
|
+
Fetches the first num_vals unique values from the specified columns of the database table.
|
|
180
|
+
|
|
181
|
+
Args:
|
|
182
|
+
table_name (str): The name of the database table.
|
|
183
|
+
columns (list[str]): The list of columns to fetch unique values from.
|
|
184
|
+
num_vals (int): The number of unique values to fetch for each column. Default is 200.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
Any: the result of the database query
|
|
188
|
+
"""
|
|
189
|
+
res = {}
|
|
190
|
+
try:
|
|
191
|
+
for column in columns:
|
|
192
|
+
unique_vals = self._load_data(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
|
|
193
|
+
res[column] = [d.text for d in unique_vals]
|
|
194
|
+
except Exception as e:
|
|
195
|
+
return {f"Error ({str(e)}) occurred while loading unique values for table {table_name}"}
|
|
196
|
+
return res
|
|
197
|
+
|
|
198
|
+
def to_tool_list(self) -> List[FunctionTool]:
|
|
199
|
+
"""
|
|
200
|
+
Returns a list of tools available.
|
|
201
|
+
"""
|
|
202
|
+
|
|
203
|
+
tool_list = []
|
|
204
|
+
for tool_name in self.spec_functions:
|
|
205
|
+
func_sync = None
|
|
206
|
+
func_async = None
|
|
207
|
+
func = getattr(self, tool_name)
|
|
208
|
+
if asyncio.iscoroutinefunction(func):
|
|
209
|
+
func_async = func
|
|
210
|
+
else:
|
|
211
|
+
func_sync = func
|
|
212
|
+
metadata = self._get_metadata_from_fn_name(tool_name)
|
|
213
|
+
|
|
214
|
+
if func_sync is None:
|
|
215
|
+
if func_async is not None:
|
|
216
|
+
func_sync = patch_sync(func_async)
|
|
217
|
+
else:
|
|
218
|
+
raise ValueError(
|
|
219
|
+
f"Could not retrieve a function for spec: {tool_name}"
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
tool = FunctionTool.from_defaults(
|
|
223
|
+
fn=func_sync,
|
|
224
|
+
async_fn=func_async,
|
|
225
|
+
tool_metadata=metadata,
|
|
226
|
+
)
|
|
227
|
+
tool_list.append(tool)
|
|
228
|
+
return tool_list
|
|
229
|
+
|
|
230
|
+
# Custom pickling: exclude unpickleable objects
|
|
231
|
+
def __getstate__(self):
|
|
232
|
+
state = self.__dict__.copy()
|
|
233
|
+
if "sql_database" in state:
|
|
234
|
+
state["sql_database_state"] = {"uri": self._uri}
|
|
235
|
+
del state["sql_database"]
|
|
236
|
+
if "_metadata" in state:
|
|
237
|
+
del state["_metadata"]
|
|
238
|
+
return state
|
|
239
|
+
|
|
240
|
+
def __setstate__(self, state):
|
|
241
|
+
self.__dict__.update(state)
|
|
242
|
+
# Reconstruct the sql_database if it was removed
|
|
243
|
+
if "sql_database_state" in state:
|
|
244
|
+
uri = state["sql_database_state"].get("uri")
|
|
245
|
+
if uri:
|
|
246
|
+
self.sql_database = SQLDatabase.from_uri(uri)
|
|
247
|
+
self._uri = uri
|
|
248
|
+
else:
|
|
249
|
+
raise ValueError("Cannot reconstruct SQLDatabase without URI")
|
|
250
|
+
# Rebuild metadata after restoring the engine
|
|
251
|
+
self._metadata = MetaData()
|
|
252
|
+
self._metadata.reflect(bind=self.sql_database.engine)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def patch_sync(func_async: AsyncCallable) -> Callable:
|
|
256
|
+
"""Patch sync function from async function."""
|
|
257
|
+
|
|
258
|
+
def patched_sync(*args: Any, **kwargs: Any) -> Any:
|
|
259
|
+
loop = asyncio.get_event_loop()
|
|
260
|
+
return loop.run_until_complete(func_async(*args, **kwargs))
|
|
261
|
+
|
|
262
|
+
return patched_sync
|
|
@@ -102,7 +102,7 @@ class SubQuestionQueryWorkflow(Workflow):
|
|
|
102
102
|
- What is the name of the mayor of San Jose?
|
|
103
103
|
Here is the user question: {await ctx.get('original_query')}.
|
|
104
104
|
Here are previous chat messages: {chat_history}.
|
|
105
|
-
And here is the list of tools: {
|
|
105
|
+
And here is the list of tools: {ev.tools}
|
|
106
106
|
""",
|
|
107
107
|
)
|
|
108
108
|
|
|
@@ -236,6 +236,7 @@ class SequentialSubQuestionsWorkflow(Workflow):
|
|
|
236
236
|
print(f"Query is {await ctx.get('original_query')}")
|
|
237
237
|
|
|
238
238
|
llm = await ctx.get("llm")
|
|
239
|
+
orig_query = await ctx.get("original_query")
|
|
239
240
|
response = llm.complete(
|
|
240
241
|
f"""
|
|
241
242
|
Given a user question, and a list of tools, output a list of
|
|
@@ -256,9 +257,9 @@ class SequentialSubQuestionsWorkflow(Workflow):
|
|
|
256
257
|
- Who is the mayor of this city?
|
|
257
258
|
The answer to the first question is San Jose, which is given as context to the second question.
|
|
258
259
|
The answer to the second question is Matt Mahan.
|
|
259
|
-
Here is the user question: {
|
|
260
|
+
Here is the user question: {orig_query}.
|
|
260
261
|
Here are previous chat messages: {chat_history}.
|
|
261
|
-
And here is the list of tools: {
|
|
262
|
+
And here is the list of tools: {ev.tools}
|
|
262
263
|
""",
|
|
263
264
|
)
|
|
264
265
|
|
|
@@ -6,6 +6,7 @@ import inspect
|
|
|
6
6
|
import re
|
|
7
7
|
import importlib
|
|
8
8
|
import os
|
|
9
|
+
import asyncio
|
|
9
10
|
|
|
10
11
|
from typing import Callable, List, Dict, Any, Optional, Union, Type
|
|
11
12
|
from pydantic import BaseModel, Field, create_model
|
|
@@ -20,8 +21,8 @@ from llama_index.core.workflow.context import Context
|
|
|
20
21
|
|
|
21
22
|
from .types import ToolType
|
|
22
23
|
from .tools_catalog import ToolsCatalog, get_bad_topics
|
|
23
|
-
from .db_tools import
|
|
24
|
-
from .utils import is_float,
|
|
24
|
+
from .db_tools import DatabaseTools
|
|
25
|
+
from .utils import is_float, summarize_documents
|
|
25
26
|
from .agent_config import AgentConfig
|
|
26
27
|
|
|
27
28
|
LI_packages = {
|
|
@@ -31,7 +32,6 @@ LI_packages = {
|
|
|
31
32
|
"exa": ToolType.QUERY,
|
|
32
33
|
"neo4j": ToolType.QUERY,
|
|
33
34
|
"kuzu": ToolType.QUERY,
|
|
34
|
-
"database": ToolType.QUERY,
|
|
35
35
|
"google": {
|
|
36
36
|
"GmailToolSpec": {
|
|
37
37
|
"load_data": ToolType.QUERY,
|
|
@@ -109,9 +109,20 @@ class VectaraTool(FunctionTool):
|
|
|
109
109
|
fn, name, description, return_direct, fn_schema, async_fn, tool_metadata,
|
|
110
110
|
callback, async_callback
|
|
111
111
|
)
|
|
112
|
-
vectara_tool = cls(
|
|
112
|
+
vectara_tool = cls(
|
|
113
|
+
tool_type=tool_type, fn=tool.fn, metadata=tool.metadata, async_fn=tool.async_fn,
|
|
114
|
+
)
|
|
113
115
|
return vectara_tool
|
|
114
116
|
|
|
117
|
+
def __str__(self) -> str:
|
|
118
|
+
return (
|
|
119
|
+
f"Tool(name={self.metadata.name}, "
|
|
120
|
+
f"Tool metadata={self.metadata})"
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
def __repr__(self) -> str:
|
|
124
|
+
return str(self)
|
|
125
|
+
|
|
115
126
|
def __eq__(self, other):
|
|
116
127
|
if not isinstance(other, VectaraTool):
|
|
117
128
|
return False
|
|
@@ -508,17 +519,29 @@ class VectaraToolFactory:
|
|
|
508
519
|
raw_input={"args": args, "kwargs": kwargs},
|
|
509
520
|
raw_output={"response": msg},
|
|
510
521
|
)
|
|
511
|
-
tool_output = "Matching documents:\n"
|
|
512
522
|
unique_ids = set()
|
|
523
|
+
docs = []
|
|
513
524
|
for doc in response:
|
|
514
525
|
if doc.id_ in unique_ids:
|
|
515
526
|
continue
|
|
516
527
|
unique_ids.add(doc.id_)
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
528
|
+
docs.append((doc.id_, doc.metadata))
|
|
529
|
+
tool_output = "Matching documents:\n"
|
|
530
|
+
if summarize:
|
|
531
|
+
summaries_dict = asyncio.run(
|
|
532
|
+
summarize_documents(
|
|
533
|
+
self.vectara_corpus_key,
|
|
534
|
+
self.vectara_api_key,
|
|
535
|
+
list(unique_ids)
|
|
536
|
+
)
|
|
537
|
+
)
|
|
538
|
+
for doc_id, metadata in docs:
|
|
539
|
+
summary = summaries_dict.get(doc_id, "")
|
|
540
|
+
tool_output += f"document_id: '{doc_id}'\nmetadata: '{metadata}'\nsummary: '{summary}'\n\n"
|
|
541
|
+
else:
|
|
542
|
+
for doc in docs:
|
|
521
543
|
tool_output += f"document_id: '{doc.id_}'\nmetadata: '{doc.metadata}'\n\n"
|
|
544
|
+
|
|
522
545
|
out = ToolOutput(
|
|
523
546
|
tool_name=search_function.__name__,
|
|
524
547
|
content=tool_output,
|
|
@@ -529,12 +552,14 @@ class VectaraToolFactory:
|
|
|
529
552
|
|
|
530
553
|
base_params = [
|
|
531
554
|
inspect.Parameter("query", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
|
|
532
|
-
inspect.Parameter("top_k", inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
555
|
+
inspect.Parameter("top_k", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int),
|
|
533
556
|
inspect.Parameter("summarize", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=True, annotation=bool),
|
|
534
557
|
]
|
|
535
558
|
search_tool_extra_desc = tool_description + "\n" + """
|
|
559
|
+
This tool is meant to perform a search for relevant documents, it is not meant for asking questions.
|
|
536
560
|
The response includes metadata about each relevant document.
|
|
537
|
-
If summarize=True, it also includes a summary of each document
|
|
561
|
+
If summarize=True, it also includes a summary of each document, but takes a lot longer to respond,
|
|
562
|
+
so avoid using it unless necessary.
|
|
538
563
|
"""
|
|
539
564
|
|
|
540
565
|
tool = _create_tool_from_dynamic_function(
|
|
@@ -905,7 +930,7 @@ class ToolsFactory:
|
|
|
905
930
|
user: str = "postgres",
|
|
906
931
|
password: str = "Password",
|
|
907
932
|
dbname: str = "postgres",
|
|
908
|
-
max_rows: int =
|
|
933
|
+
max_rows: int = 1000,
|
|
909
934
|
) -> List[VectaraTool]:
|
|
910
935
|
"""
|
|
911
936
|
Returns a list of database tools.
|
|
@@ -923,24 +948,16 @@ class ToolsFactory:
|
|
|
923
948
|
dbname (str, optional): The database name. Defaults to "postgres".
|
|
924
949
|
You must specify either the sql_database object or the scheme, host, port, user, password, and dbname.
|
|
925
950
|
max_rows (int, optional): if specified, instructs the load_data tool to never return more than max_rows
|
|
926
|
-
rows. Defaults to
|
|
951
|
+
rows. Defaults to 1000.
|
|
927
952
|
|
|
928
953
|
Returns:
|
|
929
954
|
List[VectaraTool]: A list of VectaraTool objects.
|
|
930
955
|
"""
|
|
931
956
|
if sql_database:
|
|
932
|
-
|
|
933
|
-
tool_package_name="database",
|
|
934
|
-
tool_spec_name="DatabaseToolSpec",
|
|
935
|
-
tool_name_prefix=tool_name_prefix,
|
|
936
|
-
sql_database=sql_database,
|
|
937
|
-
)
|
|
957
|
+
dbt = DatabaseTools(sql_database=sql_database)
|
|
938
958
|
else:
|
|
939
959
|
if scheme in ["postgresql", "mysql", "sqlite", "mssql", "oracle"]:
|
|
940
|
-
|
|
941
|
-
tool_package_name="database",
|
|
942
|
-
tool_spec_name="DatabaseToolSpec",
|
|
943
|
-
tool_name_prefix=tool_name_prefix,
|
|
960
|
+
dbt = DatabaseTools(
|
|
944
961
|
scheme=scheme,
|
|
945
962
|
host=host,
|
|
946
963
|
port=port,
|
|
@@ -955,28 +972,19 @@ class ToolsFactory:
|
|
|
955
972
|
)
|
|
956
973
|
|
|
957
974
|
# Update tools with description
|
|
975
|
+
tools = dbt.to_tool_list()
|
|
976
|
+
vtools = []
|
|
958
977
|
for tool in tools:
|
|
959
978
|
if content_description:
|
|
960
979
|
tool.metadata.description = (
|
|
961
980
|
tool.metadata.description + f"The database tables include data about {content_description}."
|
|
962
981
|
)
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
sample_data_fn = DBLoadSampleData(load_data_fn_original)
|
|
973
|
-
sample_data_fn.__name__ = f"{tool_name_prefix}_load_sample_data"
|
|
974
|
-
sample_data_tool = self.create_tool(sample_data_fn, ToolType.QUERY)
|
|
975
|
-
|
|
976
|
-
load_unique_values_fn = DBLoadUniqueValues(load_data_fn_original)
|
|
977
|
-
load_unique_values_fn.__name__ = f"{tool_name_prefix}_load_unique_values"
|
|
978
|
-
load_unique_values_tool = self.create_tool(load_unique_values_fn, ToolType.QUERY)
|
|
979
|
-
|
|
980
|
-
tools[load_data_tool_index] = load_data_tool
|
|
981
|
-
tools.extend([sample_data_tool, load_unique_values_tool])
|
|
982
|
-
return tools
|
|
982
|
+
if len(tool_name_prefix) > 0:
|
|
983
|
+
tool.metadata.name = tool_name_prefix + "_" + tool.metadata.name
|
|
984
|
+
vtool = VectaraTool(
|
|
985
|
+
tool_type=ToolType.QUERY,
|
|
986
|
+
fn=tool.fn, async_fn=tool.async_fn,
|
|
987
|
+
metadata=tool.metadata
|
|
988
|
+
)
|
|
989
|
+
vtools.append(vtool)
|
|
990
|
+
return vtools
|
|
@@ -6,9 +6,9 @@ from typing import Tuple, Callable, Optional
|
|
|
6
6
|
from functools import lru_cache
|
|
7
7
|
from inspect import signature
|
|
8
8
|
import json
|
|
9
|
-
import
|
|
10
|
-
|
|
9
|
+
import asyncio
|
|
11
10
|
import tiktoken
|
|
11
|
+
import aiohttp
|
|
12
12
|
|
|
13
13
|
from llama_index.core.llms import LLM
|
|
14
14
|
from llama_index.llms.openai import OpenAI
|
|
@@ -101,13 +101,16 @@ def get_llm(
|
|
|
101
101
|
max_tokens=max_tokens
|
|
102
102
|
)
|
|
103
103
|
elif model_provider == ModelProvider.ANTHROPIC:
|
|
104
|
-
llm = Anthropic(
|
|
104
|
+
llm = Anthropic(
|
|
105
|
+
model=model_name, temperature=0,
|
|
106
|
+
max_tokens=max_tokens, cache_idx=2,
|
|
107
|
+
)
|
|
105
108
|
elif model_provider == ModelProvider.GEMINI:
|
|
106
109
|
from llama_index.llms.gemini import Gemini
|
|
107
110
|
llm = Gemini(
|
|
108
111
|
model=model_name, temperature=0,
|
|
109
112
|
is_function_calling_model=True,
|
|
110
|
-
max_tokens=max_tokens
|
|
113
|
+
max_tokens=max_tokens,
|
|
111
114
|
)
|
|
112
115
|
elif model_provider == ModelProvider.TOGETHER:
|
|
113
116
|
from llama_index.llms.together import TogetherLLM
|
|
@@ -159,7 +162,7 @@ def remove_self_from_signature(func):
|
|
|
159
162
|
func.__signature__ = new_sig
|
|
160
163
|
return func
|
|
161
164
|
|
|
162
|
-
def summarize_vectara_document(corpus_key: str, vectara_api_key, doc_id: str) -> str:
|
|
165
|
+
async def summarize_vectara_document(corpus_key: str, vectara_api_key: str, doc_id: str) -> str:
|
|
163
166
|
"""
|
|
164
167
|
Summarize a document in a Vectara corpus using the Vectara API.
|
|
165
168
|
"""
|
|
@@ -175,8 +178,32 @@ def summarize_vectara_document(corpus_key: str, vectara_api_key, doc_id: str) ->
|
|
|
175
178
|
'Accept': 'application/json',
|
|
176
179
|
'x-api-key': vectara_api_key
|
|
177
180
|
}
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
181
|
+
timeout = aiohttp.ClientTimeout(total=60)
|
|
182
|
+
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
183
|
+
async with session.post(url, headers=headers, data=payload) as response:
|
|
184
|
+
if response.status != 200:
|
|
185
|
+
error_json = await response.json()
|
|
186
|
+
return (
|
|
187
|
+
f"Vectara Summarization failed with error code {response.status}, "
|
|
188
|
+
f"error={error_json['messages'][0]}"
|
|
189
|
+
)
|
|
190
|
+
data = await response.json()
|
|
191
|
+
return data["summary"]
|
|
182
192
|
return json.loads(response.text)["summary"]
|
|
193
|
+
|
|
194
|
+
async def summarize_documents(
|
|
195
|
+
vectara_corpus_key: str,
|
|
196
|
+
vectara_api_key: str,
|
|
197
|
+
doc_ids: list[str]
|
|
198
|
+
) -> dict[str, str]:
|
|
199
|
+
"""
|
|
200
|
+
Summarize multiple documents in a Vectara corpus using the Vectara API.
|
|
201
|
+
"""
|
|
202
|
+
if not doc_ids:
|
|
203
|
+
return {}
|
|
204
|
+
tasks = [
|
|
205
|
+
summarize_vectara_document(vectara_corpus_key, vectara_api_key, doc_id)
|
|
206
|
+
for doc_id in doc_ids
|
|
207
|
+
]
|
|
208
|
+
summaries = await asyncio.gather(*tasks, return_exceptions=True)
|
|
209
|
+
return dict(zip(doc_ids, summaries))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: vectara_agentic
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.8
|
|
4
4
|
Summary: A Python package for creating AI Assistants and AI Agents with Vectara
|
|
5
5
|
Home-page: https://github.com/vectara/py-vectara-agentic
|
|
6
6
|
Author: Ofer Mendelevitch
|
|
@@ -16,7 +16,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
16
16
|
Requires-Python: >=3.10
|
|
17
17
|
Description-Content-Type: text/markdown
|
|
18
18
|
License-File: LICENSE
|
|
19
|
-
Requires-Dist: llama-index==0.12.
|
|
19
|
+
Requires-Dist: llama-index==0.12.26
|
|
20
20
|
Requires-Dist: llama-index-indices-managed-vectara==0.4.2
|
|
21
21
|
Requires-Dist: llama-index-agent-llm-compiler==0.3.0
|
|
22
22
|
Requires-Dist: llama-index-agent-lats==0.3.0
|
|
@@ -1,96 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
This module contains the code to extend and improve DatabaseToolSpec
|
|
3
|
-
Specifically adding load_sample_data and load_unique_values methods, as well as
|
|
4
|
-
making sure the load_data method returns a list of text values from the database, not Document[] objects.
|
|
5
|
-
"""
|
|
6
|
-
from abc import ABC
|
|
7
|
-
from typing import Callable, Any
|
|
8
|
-
|
|
9
|
-
#
|
|
10
|
-
# Additional database tool
|
|
11
|
-
#
|
|
12
|
-
class DBTool(ABC):
|
|
13
|
-
"""
|
|
14
|
-
A base class for vectara-agentic database tools extensions
|
|
15
|
-
"""
|
|
16
|
-
def __init__(self, load_data_fn: Callable, max_rows: int = 1000):
|
|
17
|
-
self.load_data_fn = load_data_fn
|
|
18
|
-
self.max_rows = max_rows
|
|
19
|
-
|
|
20
|
-
class DBLoadData(DBTool):
|
|
21
|
-
"""
|
|
22
|
-
A tool to Run SQL query on the database and return the result.
|
|
23
|
-
"""
|
|
24
|
-
def __call__(self, query: str) -> Any:
|
|
25
|
-
"""Query and load data from the Database, returning a list of Documents.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
query (str): an SQL query to filter tables and rows.
|
|
29
|
-
|
|
30
|
-
Returns:
|
|
31
|
-
List[text]: a list of text values from the database.
|
|
32
|
-
"""
|
|
33
|
-
count_query = f"SELECT COUNT(*) FROM ({query})"
|
|
34
|
-
try:
|
|
35
|
-
count_rows = self.load_data_fn(count_query)
|
|
36
|
-
except Exception as e:
|
|
37
|
-
return [f"Error ({str(e)}) occurred while counting number of rows"]
|
|
38
|
-
num_rows = int(count_rows[0].text)
|
|
39
|
-
if num_rows > self.max_rows:
|
|
40
|
-
return [
|
|
41
|
-
f"The query is expected to return more than {self.max_rows} rows. "
|
|
42
|
-
"Please refactor your query to make it return less rows. "
|
|
43
|
-
]
|
|
44
|
-
try:
|
|
45
|
-
res = self.load_data_fn(query)
|
|
46
|
-
except Exception as e:
|
|
47
|
-
return [f"Error ({str(e)}) occurred while executing the query {query}"]
|
|
48
|
-
return [d.text for d in res]
|
|
49
|
-
|
|
50
|
-
class DBLoadSampleData(DBTool):
|
|
51
|
-
"""
|
|
52
|
-
A tool to load a sample of data from the specified database table.
|
|
53
|
-
|
|
54
|
-
This tool fetches the first num_rows (default 25) rows from the given table
|
|
55
|
-
using a provided database query function.
|
|
56
|
-
"""
|
|
57
|
-
def __call__(self, table_name: str, num_rows: int = 25) -> Any:
|
|
58
|
-
"""
|
|
59
|
-
Fetches the first num_rows rows from the specified database table.
|
|
60
|
-
|
|
61
|
-
Args:
|
|
62
|
-
table_name (str): The name of the database table.
|
|
63
|
-
|
|
64
|
-
Returns:
|
|
65
|
-
Any: The result of the database query.
|
|
66
|
-
"""
|
|
67
|
-
try:
|
|
68
|
-
res = self.load_data_fn(f"SELECT * FROM {table_name} LIMIT {num_rows}")
|
|
69
|
-
except Exception as e:
|
|
70
|
-
return [f"Error ({str(e)}) occurred while loading sample data for table {table_name}"]
|
|
71
|
-
return res
|
|
72
|
-
|
|
73
|
-
class DBLoadUniqueValues(DBTool):
|
|
74
|
-
"""
|
|
75
|
-
A tool to list all unique values for each column in a set of columns of a database table.
|
|
76
|
-
"""
|
|
77
|
-
def __call__(self, table_name: str, columns: list[str], num_vals: int = 200) -> Any:
|
|
78
|
-
"""
|
|
79
|
-
Fetches the first num_vals unique values from the specified columns of the database table.
|
|
80
|
-
|
|
81
|
-
Args:
|
|
82
|
-
table_name (str): The name of the database table.
|
|
83
|
-
columns (list[str]): The list of columns to fetch unique values from.
|
|
84
|
-
num_vals (int): The number of unique values to fetch for each column. Default is 200.
|
|
85
|
-
|
|
86
|
-
Returns:
|
|
87
|
-
Any: the result of the database query
|
|
88
|
-
"""
|
|
89
|
-
res = {}
|
|
90
|
-
try:
|
|
91
|
-
for column in columns:
|
|
92
|
-
unique_vals = self.load_data_fn(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
|
|
93
|
-
res[column] = [d.text for d in unique_vals]
|
|
94
|
-
except Exception as e:
|
|
95
|
-
return {f"Error ({str(e)}) occurred while loading unique values for table {table_name}"}
|
|
96
|
-
return res
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{vectara_agentic-0.2.7 → vectara_agentic-0.2.8}/vectara_agentic.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|