vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -8,15 +8,15 @@ import requests
|
|
|
8
8
|
from ..advanced import VannaAdvanced
|
|
9
9
|
from ..base import VannaBase
|
|
10
10
|
from ..types import (
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
11
|
+
DataFrameJSON,
|
|
12
|
+
NewOrganization,
|
|
13
|
+
OrganizationList,
|
|
14
|
+
Question,
|
|
15
|
+
QuestionSQLPair,
|
|
16
|
+
Status,
|
|
17
|
+
StatusWithId,
|
|
18
|
+
StringData,
|
|
19
|
+
TrainingData,
|
|
20
20
|
)
|
|
21
21
|
from ..utils import sanitize_model_name
|
|
22
22
|
|
|
@@ -85,18 +85,25 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
|
85
85
|
}
|
|
86
86
|
"""
|
|
87
87
|
|
|
88
|
-
response = requests.post(
|
|
88
|
+
response = requests.post(
|
|
89
|
+
self._graphql_endpoint, headers=self._graphql_headers, json={"query": query}
|
|
90
|
+
)
|
|
89
91
|
response_json = response.json()
|
|
90
|
-
if
|
|
91
|
-
|
|
92
|
-
|
|
92
|
+
if (
|
|
93
|
+
response.status_code == 200
|
|
94
|
+
and "data" in response_json
|
|
95
|
+
and "get_all_sql_functions" in response_json["data"]
|
|
96
|
+
):
|
|
97
|
+
self.log(response_json["data"]["get_all_sql_functions"])
|
|
98
|
+
resp = response_json["data"]["get_all_sql_functions"]
|
|
93
99
|
|
|
94
100
|
print(resp)
|
|
95
101
|
|
|
96
102
|
return resp
|
|
97
103
|
else:
|
|
98
|
-
raise Exception(
|
|
99
|
-
|
|
104
|
+
raise Exception(
|
|
105
|
+
f"Query failed to run by returning code of {response.status_code}. {response.text}"
|
|
106
|
+
)
|
|
100
107
|
|
|
101
108
|
def get_function(self, question: str, additional_data: dict = {}) -> dict:
|
|
102
109
|
query = """
|
|
@@ -121,21 +128,38 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
|
121
128
|
}
|
|
122
129
|
}
|
|
123
130
|
"""
|
|
124
|
-
static_function_arguments = [
|
|
125
|
-
|
|
126
|
-
|
|
131
|
+
static_function_arguments = [
|
|
132
|
+
{"name": key, "value": str(value)} for key, value in additional_data.items()
|
|
133
|
+
]
|
|
134
|
+
variables = {
|
|
135
|
+
"question": question,
|
|
136
|
+
"staticFunctionArguments": static_function_arguments,
|
|
137
|
+
}
|
|
138
|
+
response = requests.post(
|
|
139
|
+
self._graphql_endpoint,
|
|
140
|
+
headers=self._graphql_headers,
|
|
141
|
+
json={"query": query, "variables": variables},
|
|
142
|
+
)
|
|
127
143
|
response_json = response.json()
|
|
128
|
-
if
|
|
129
|
-
|
|
130
|
-
|
|
144
|
+
if (
|
|
145
|
+
response.status_code == 200
|
|
146
|
+
and "data" in response_json
|
|
147
|
+
and "get_and_instantiate_function" in response_json["data"]
|
|
148
|
+
):
|
|
149
|
+
self.log(response_json["data"]["get_and_instantiate_function"])
|
|
150
|
+
resp = response_json["data"]["get_and_instantiate_function"]
|
|
131
151
|
|
|
132
152
|
print(resp)
|
|
133
153
|
|
|
134
154
|
return resp
|
|
135
155
|
else:
|
|
136
|
-
raise Exception(
|
|
156
|
+
raise Exception(
|
|
157
|
+
f"Query failed to run by returning code of {response.status_code}. {response.text}"
|
|
158
|
+
)
|
|
137
159
|
|
|
138
|
-
def create_function(
|
|
160
|
+
def create_function(
|
|
161
|
+
self, question: str, sql: str, plotly_code: str, **kwargs
|
|
162
|
+
) -> dict:
|
|
139
163
|
query = """
|
|
140
164
|
mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
|
|
141
165
|
generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
|
|
@@ -153,16 +177,27 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
|
153
177
|
}
|
|
154
178
|
"""
|
|
155
179
|
variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
|
|
156
|
-
response = requests.post(
|
|
180
|
+
response = requests.post(
|
|
181
|
+
self._graphql_endpoint,
|
|
182
|
+
headers=self._graphql_headers,
|
|
183
|
+
json={"query": query, "variables": variables},
|
|
184
|
+
)
|
|
157
185
|
response_json = response.json()
|
|
158
|
-
if
|
|
159
|
-
|
|
186
|
+
if (
|
|
187
|
+
response.status_code == 200
|
|
188
|
+
and "data" in response_json
|
|
189
|
+
and response_json["data"] is not None
|
|
190
|
+
and "generate_and_create_sql_function" in response_json["data"]
|
|
191
|
+
):
|
|
192
|
+
resp = response_json["data"]["generate_and_create_sql_function"]
|
|
160
193
|
|
|
161
194
|
print(resp)
|
|
162
195
|
|
|
163
196
|
return resp
|
|
164
197
|
else:
|
|
165
|
-
raise Exception(
|
|
198
|
+
raise Exception(
|
|
199
|
+
f"Query failed to run by returning code of {response.status_code}. {response.text}"
|
|
200
|
+
)
|
|
166
201
|
|
|
167
202
|
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
|
|
168
203
|
"""
|
|
@@ -187,41 +222,64 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
|
187
222
|
"""
|
|
188
223
|
|
|
189
224
|
SQLFunctionUpdate = {
|
|
190
|
-
|
|
225
|
+
"function_name",
|
|
226
|
+
"description",
|
|
227
|
+
"arguments",
|
|
228
|
+
"sql_template",
|
|
229
|
+
"post_processing_code_template",
|
|
191
230
|
}
|
|
192
231
|
|
|
193
232
|
# Define the expected keys for each argument in the arguments list
|
|
194
|
-
ArgumentKeys = {
|
|
233
|
+
ArgumentKeys = {
|
|
234
|
+
"name",
|
|
235
|
+
"general_type",
|
|
236
|
+
"description",
|
|
237
|
+
"is_user_editable",
|
|
238
|
+
"available_values",
|
|
239
|
+
}
|
|
195
240
|
|
|
196
241
|
# Function to validate and transform arguments
|
|
197
242
|
def validate_arguments(args):
|
|
198
243
|
return [
|
|
199
|
-
{key: arg[key] for key in arg if key in ArgumentKeys}
|
|
200
|
-
for arg in args
|
|
244
|
+
{key: arg[key] for key in arg if key in ArgumentKeys} for arg in args
|
|
201
245
|
]
|
|
202
246
|
|
|
203
247
|
# Keep only the keys that conform to the SQLFunctionUpdate GraphQL input type
|
|
204
|
-
updated_function = {
|
|
248
|
+
updated_function = {
|
|
249
|
+
key: value
|
|
250
|
+
for key, value in updated_function.items()
|
|
251
|
+
if key in SQLFunctionUpdate
|
|
252
|
+
}
|
|
205
253
|
|
|
206
254
|
# Special handling for 'arguments' to ensure they conform to the spec
|
|
207
|
-
if
|
|
208
|
-
updated_function[
|
|
255
|
+
if "arguments" in updated_function:
|
|
256
|
+
updated_function["arguments"] = validate_arguments(
|
|
257
|
+
updated_function["arguments"]
|
|
258
|
+
)
|
|
209
259
|
|
|
210
260
|
variables = {
|
|
211
|
-
"input": {
|
|
212
|
-
"old_function_name": old_function_name,
|
|
213
|
-
**updated_function
|
|
214
|
-
}
|
|
261
|
+
"input": {"old_function_name": old_function_name, **updated_function}
|
|
215
262
|
}
|
|
216
263
|
|
|
217
264
|
print("variables", variables)
|
|
218
265
|
|
|
219
|
-
response = requests.post(
|
|
266
|
+
response = requests.post(
|
|
267
|
+
self._graphql_endpoint,
|
|
268
|
+
headers=self._graphql_headers,
|
|
269
|
+
json={"query": mutation, "variables": variables},
|
|
270
|
+
)
|
|
220
271
|
response_json = response.json()
|
|
221
|
-
if
|
|
222
|
-
|
|
272
|
+
if (
|
|
273
|
+
response.status_code == 200
|
|
274
|
+
and "data" in response_json
|
|
275
|
+
and response_json["data"] is not None
|
|
276
|
+
and "update_sql_function" in response_json["data"]
|
|
277
|
+
):
|
|
278
|
+
return response_json["data"]["update_sql_function"]
|
|
223
279
|
else:
|
|
224
|
-
raise Exception(
|
|
280
|
+
raise Exception(
|
|
281
|
+
f"Mutation failed to run by returning code of {response.status_code}. {response.text}"
|
|
282
|
+
)
|
|
225
283
|
|
|
226
284
|
def delete_function(self, function_name: str) -> bool:
|
|
227
285
|
mutation = """
|
|
@@ -230,12 +288,23 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
|
230
288
|
}
|
|
231
289
|
"""
|
|
232
290
|
variables = {"function_name": function_name}
|
|
233
|
-
response = requests.post(
|
|
291
|
+
response = requests.post(
|
|
292
|
+
self._graphql_endpoint,
|
|
293
|
+
headers=self._graphql_headers,
|
|
294
|
+
json={"query": mutation, "variables": variables},
|
|
295
|
+
)
|
|
234
296
|
response_json = response.json()
|
|
235
|
-
if
|
|
236
|
-
|
|
297
|
+
if (
|
|
298
|
+
response.status_code == 200
|
|
299
|
+
and "data" in response_json
|
|
300
|
+
and response_json["data"] is not None
|
|
301
|
+
and "delete_sql_function" in response_json["data"]
|
|
302
|
+
):
|
|
303
|
+
return response_json["data"]["delete_sql_function"]
|
|
237
304
|
else:
|
|
238
|
-
raise Exception(
|
|
305
|
+
raise Exception(
|
|
306
|
+
f"Mutation failed to run by returning code of {response.status_code}. {response.text}"
|
|
307
|
+
)
|
|
239
308
|
|
|
240
309
|
def create_model(self, model: str, **kwargs) -> bool:
|
|
241
310
|
"""
|
|
@@ -80,13 +80,12 @@ class Vllm(VannaBase):
|
|
|
80
80
|
}
|
|
81
81
|
|
|
82
82
|
if self.auth_key is not None:
|
|
83
|
-
headers = {
|
|
84
|
-
|
|
85
|
-
|
|
83
|
+
headers = {
|
|
84
|
+
"Content-Type": "application/json",
|
|
85
|
+
"Authorization": f"Bearer {self.auth_key}",
|
|
86
86
|
}
|
|
87
87
|
|
|
88
|
-
response = requests.post(url, headers=headers,json=data)
|
|
89
|
-
|
|
88
|
+
response = requests.post(url, headers=headers, json=data)
|
|
90
89
|
|
|
91
90
|
else:
|
|
92
91
|
response = requests.post(url, json=data)
|
|
@@ -95,4 +94,4 @@ class Vllm(VannaBase):
|
|
|
95
94
|
|
|
96
95
|
self.log(response.text)
|
|
97
96
|
|
|
98
|
-
return response_dict[
|
|
97
|
+
return response_dict["choices"][0]["message"]["content"]
|
|
@@ -6,7 +6,6 @@ from vanna.base import VannaBase
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class WeaviateDatabase(VannaBase):
|
|
9
|
-
|
|
10
9
|
def __init__(self, config=None):
|
|
11
10
|
"""
|
|
12
11
|
Initialize the VannaEnhanced class with the provided configuration.
|
|
@@ -42,30 +41,35 @@ class WeaviateDatabase(VannaBase):
|
|
|
42
41
|
self.training_data_cluster = {
|
|
43
42
|
"sql": "SQLTrainingDataEntry",
|
|
44
43
|
"ddl": "DDLEntry",
|
|
45
|
-
"doc": "DocumentationEntry"
|
|
44
|
+
"doc": "DocumentationEntry",
|
|
46
45
|
}
|
|
47
46
|
|
|
48
47
|
self._create_collections_if_not_exist()
|
|
49
48
|
|
|
50
49
|
def _create_collections_if_not_exist(self):
|
|
51
50
|
properties_dict = {
|
|
52
|
-
self.training_data_cluster[
|
|
53
|
-
wvc.config.Property(
|
|
51
|
+
self.training_data_cluster["ddl"]: [
|
|
52
|
+
wvc.config.Property(
|
|
53
|
+
name="description", data_type=wvc.config.DataType.TEXT
|
|
54
|
+
),
|
|
54
55
|
],
|
|
55
|
-
self.training_data_cluster[
|
|
56
|
-
wvc.config.Property(
|
|
56
|
+
self.training_data_cluster["doc"]: [
|
|
57
|
+
wvc.config.Property(
|
|
58
|
+
name="description", data_type=wvc.config.DataType.TEXT
|
|
59
|
+
),
|
|
57
60
|
],
|
|
58
|
-
self.training_data_cluster[
|
|
61
|
+
self.training_data_cluster["sql"]: [
|
|
59
62
|
wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT),
|
|
60
|
-
wvc.config.Property(
|
|
61
|
-
|
|
63
|
+
wvc.config.Property(
|
|
64
|
+
name="natural_language_question", data_type=wvc.config.DataType.TEXT
|
|
65
|
+
),
|
|
66
|
+
],
|
|
62
67
|
}
|
|
63
68
|
|
|
64
69
|
for cluster, properties in properties_dict.items():
|
|
65
70
|
if not self.weaviate_client.collections.exists(cluster):
|
|
66
71
|
self.weaviate_client.collections.create(
|
|
67
|
-
name=cluster,
|
|
68
|
-
properties=properties
|
|
72
|
+
name=cluster, properties=properties
|
|
69
73
|
)
|
|
70
74
|
|
|
71
75
|
def _initialize_weaviate_client(self):
|
|
@@ -74,28 +78,26 @@ class WeaviateDatabase(VannaBase):
|
|
|
74
78
|
cluster_url=self.weaviate_url,
|
|
75
79
|
auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key),
|
|
76
80
|
additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
|
|
77
|
-
skip_init_checks=True
|
|
81
|
+
skip_init_checks=True,
|
|
78
82
|
)
|
|
79
83
|
else:
|
|
80
84
|
return weaviate.connect_to_local(
|
|
81
85
|
port=self.weaviate_port,
|
|
82
86
|
grpc_port=self.weaviate_grpc_port,
|
|
83
87
|
additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
|
|
84
|
-
skip_init_checks=True
|
|
88
|
+
skip_init_checks=True,
|
|
85
89
|
)
|
|
86
90
|
|
|
87
91
|
def generate_embedding(self, data: str, **kwargs):
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
+
embedding_model = TextEmbedding(model_name=self.fastembed_model)
|
|
93
|
+
embedding = next(embedding_model.embed(data))
|
|
94
|
+
return embedding.tolist()
|
|
92
95
|
|
|
93
96
|
def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str:
|
|
94
97
|
self.weaviate_client.connect()
|
|
95
|
-
response = self.weaviate_client.collections.get(
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
)
|
|
98
|
+
response = self.weaviate_client.collections.get(
|
|
99
|
+
self.training_data_cluster[cluster_key]
|
|
100
|
+
).data.insert(properties=data_object, vector=vector)
|
|
99
101
|
self.weaviate_client.close()
|
|
100
102
|
return response
|
|
101
103
|
|
|
@@ -103,31 +105,37 @@ class WeaviateDatabase(VannaBase):
|
|
|
103
105
|
data_object = {
|
|
104
106
|
"description": ddl,
|
|
105
107
|
}
|
|
106
|
-
response = self._insert_data(
|
|
107
|
-
return f
|
|
108
|
+
response = self._insert_data("ddl", data_object, self.generate_embedding(ddl))
|
|
109
|
+
return f"{response}-ddl"
|
|
108
110
|
|
|
109
111
|
def add_documentation(self, doc: str, **kwargs) -> str:
|
|
110
112
|
data_object = {
|
|
111
113
|
"description": doc,
|
|
112
114
|
}
|
|
113
|
-
response = self._insert_data(
|
|
114
|
-
return f
|
|
115
|
+
response = self._insert_data("doc", data_object, self.generate_embedding(doc))
|
|
116
|
+
return f"{response}-doc"
|
|
115
117
|
|
|
116
118
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
117
119
|
data_object = {
|
|
118
120
|
"sql": sql,
|
|
119
121
|
"natural_language_question": question,
|
|
120
122
|
}
|
|
121
|
-
response = self._insert_data(
|
|
122
|
-
|
|
123
|
+
response = self._insert_data(
|
|
124
|
+
"sql", data_object, self.generate_embedding(question)
|
|
125
|
+
)
|
|
126
|
+
return f"{response}-sql"
|
|
123
127
|
|
|
124
|
-
def _query_collection(
|
|
128
|
+
def _query_collection(
|
|
129
|
+
self, cluster_key: str, vector_input: list, return_properties: list
|
|
130
|
+
) -> list:
|
|
125
131
|
self.weaviate_client.connect()
|
|
126
|
-
collection = self.weaviate_client.collections.get(
|
|
132
|
+
collection = self.weaviate_client.collections.get(
|
|
133
|
+
self.training_data_cluster[cluster_key]
|
|
134
|
+
)
|
|
127
135
|
response = collection.query.near_vector(
|
|
128
136
|
near_vector=vector_input,
|
|
129
137
|
limit=self.n_results,
|
|
130
|
-
return_properties=return_properties
|
|
138
|
+
return_properties=return_properties,
|
|
131
139
|
)
|
|
132
140
|
response_list = [item.properties for item in response.objects]
|
|
133
141
|
self.weaviate_client.close()
|
|
@@ -135,18 +143,23 @@ class WeaviateDatabase(VannaBase):
|
|
|
135
143
|
|
|
136
144
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
137
145
|
vector_input = self.generate_embedding(question)
|
|
138
|
-
response_list = self._query_collection(
|
|
146
|
+
response_list = self._query_collection("ddl", vector_input, ["description"])
|
|
139
147
|
return [item["description"] for item in response_list]
|
|
140
148
|
|
|
141
149
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
142
150
|
vector_input = self.generate_embedding(question)
|
|
143
|
-
response_list = self._query_collection(
|
|
151
|
+
response_list = self._query_collection("doc", vector_input, ["description"])
|
|
144
152
|
return [item["description"] for item in response_list]
|
|
145
153
|
|
|
146
154
|
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
147
155
|
vector_input = self.generate_embedding(question)
|
|
148
|
-
response_list = self._query_collection(
|
|
149
|
-
|
|
156
|
+
response_list = self._query_collection(
|
|
157
|
+
"sql", vector_input, ["sql", "natural_language_question"]
|
|
158
|
+
)
|
|
159
|
+
return [
|
|
160
|
+
{"question": item["natural_language_question"], "sql": item["sql"]}
|
|
161
|
+
for item in response_list
|
|
162
|
+
]
|
|
150
163
|
|
|
151
164
|
def get_training_data(self, **kwargs) -> list:
|
|
152
165
|
self.weaviate_client.connect()
|
|
@@ -163,13 +176,19 @@ class WeaviateDatabase(VannaBase):
|
|
|
163
176
|
self.weaviate_client.connect()
|
|
164
177
|
success = False
|
|
165
178
|
if id.endswith("-sql"):
|
|
166
|
-
id = id.replace(
|
|
167
|
-
success = self.weaviate_client.collections.get(
|
|
179
|
+
id = id.replace("-sql", "")
|
|
180
|
+
success = self.weaviate_client.collections.get(
|
|
181
|
+
self.training_data_cluster["sql"]
|
|
182
|
+
).data.delete_by_id(id)
|
|
168
183
|
elif id.endswith("-ddl"):
|
|
169
|
-
id = id.replace(
|
|
170
|
-
success = self.weaviate_client.collections.get(
|
|
184
|
+
id = id.replace("-ddl", "")
|
|
185
|
+
success = self.weaviate_client.collections.get(
|
|
186
|
+
self.training_data_cluster["ddl"]
|
|
187
|
+
).data.delete_by_id(id)
|
|
171
188
|
elif id.endswith("-doc"):
|
|
172
|
-
id = id.replace(
|
|
173
|
-
success = self.weaviate_client.collections.get(
|
|
189
|
+
id = id.replace("-doc", "")
|
|
190
|
+
success = self.weaviate_client.collections.get(
|
|
191
|
+
self.training_data_cluster["doc"]
|
|
192
|
+
).data.delete_by_id(id)
|
|
174
193
|
self.weaviate_client.close()
|
|
175
194
|
return success
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from xinference_client.client.restful.restful_client import (
|
|
2
|
-
|
|
3
|
-
|
|
2
|
+
Client,
|
|
3
|
+
RESTfulChatModelHandle,
|
|
4
4
|
)
|
|
5
5
|
|
|
6
6
|
from ..base import VannaBase
|
|
@@ -43,11 +43,11 @@ class Xinference(VannaBase):
|
|
|
43
43
|
|
|
44
44
|
xinference_model = self.xinference_client.get_model(model_uid)
|
|
45
45
|
if isinstance(xinference_model, RESTfulChatModelHandle):
|
|
46
|
-
print(
|
|
47
|
-
f"Using model_uid {model_uid} for {num_tokens} tokens (approx)"
|
|
48
|
-
)
|
|
46
|
+
print(f"Using model_uid {model_uid} for {num_tokens} tokens (approx)")
|
|
49
47
|
|
|
50
48
|
response = xinference_model.chat(prompt)
|
|
51
49
|
return response["choices"][0]["message"]["content"]
|
|
52
50
|
else:
|
|
53
|
-
raise NotImplementedError(
|
|
51
|
+
raise NotImplementedError(
|
|
52
|
+
f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle"
|
|
53
|
+
)
|
vanna/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Server implementations for the Vanna Agents framework.
|
|
3
|
+
|
|
4
|
+
This module provides Flask and FastAPI server factories for serving
|
|
5
|
+
Vanna agents over HTTP with SSE, WebSocket, and polling endpoints.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .base import ChatHandler, ChatRequest, ChatStreamChunk
|
|
9
|
+
from .cli.server_runner import ExampleAgentLoader
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ChatHandler",
|
|
13
|
+
"ChatRequest",
|
|
14
|
+
"ChatStreamChunk",
|
|
15
|
+
"ExampleAgentLoader",
|
|
16
|
+
]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base server components for the Vanna Agents framework.
|
|
3
|
+
|
|
4
|
+
This module provides framework-agnostic components for handling chat
|
|
5
|
+
requests and responses.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .chat_handler import ChatHandler
|
|
9
|
+
from .models import ChatRequest, ChatStreamChunk, ChatResponse
|
|
10
|
+
from .templates import INDEX_HTML
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"ChatHandler",
|
|
14
|
+
"ChatRequest",
|
|
15
|
+
"ChatStreamChunk",
|
|
16
|
+
"ChatResponse",
|
|
17
|
+
"INDEX_HTML",
|
|
18
|
+
]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Framework-agnostic chat handling logic.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from typing import AsyncGenerator, List
|
|
7
|
+
|
|
8
|
+
from ...core import Agent
|
|
9
|
+
from .models import ChatRequest, ChatResponse, ChatStreamChunk
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ChatHandler:
|
|
13
|
+
"""Core chat handling logic - framework agnostic."""
|
|
14
|
+
|
|
15
|
+
def __init__(
|
|
16
|
+
self,
|
|
17
|
+
agent: Agent,
|
|
18
|
+
):
|
|
19
|
+
"""Initialize chat handler.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
agent: The agent to handle chat requests
|
|
23
|
+
"""
|
|
24
|
+
self.agent = agent
|
|
25
|
+
|
|
26
|
+
async def handle_stream(
|
|
27
|
+
self, request: ChatRequest
|
|
28
|
+
) -> AsyncGenerator[ChatStreamChunk, None]:
|
|
29
|
+
"""Stream chat responses.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
request: Chat request
|
|
33
|
+
|
|
34
|
+
Yields:
|
|
35
|
+
Chat stream chunks
|
|
36
|
+
"""
|
|
37
|
+
conversation_id = request.conversation_id or self._generate_conversation_id()
|
|
38
|
+
# Use request_id from client for tracking, or use the one generated internally
|
|
39
|
+
request_id = request.request_id or str(uuid.uuid4())
|
|
40
|
+
|
|
41
|
+
async for component in self.agent.send_message(
|
|
42
|
+
request_context=request.request_context,
|
|
43
|
+
message=request.message,
|
|
44
|
+
conversation_id=conversation_id,
|
|
45
|
+
):
|
|
46
|
+
yield ChatStreamChunk.from_component(component, conversation_id, request_id)
|
|
47
|
+
|
|
48
|
+
async def handle_poll(self, request: ChatRequest) -> ChatResponse:
|
|
49
|
+
"""Handle polling-based chat.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
request: Chat request
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Complete chat response
|
|
56
|
+
"""
|
|
57
|
+
chunks = []
|
|
58
|
+
async for chunk in self.handle_stream(request):
|
|
59
|
+
chunks.append(chunk)
|
|
60
|
+
|
|
61
|
+
return ChatResponse.from_chunks(chunks)
|
|
62
|
+
|
|
63
|
+
def _generate_conversation_id(self) -> str:
|
|
64
|
+
"""Generate new conversation ID."""
|
|
65
|
+
return f"conv_{uuid.uuid4().hex[:8]}"
|