vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +247 -223
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.8.dist-info/METADATA +0 -408
- vanna-0.7.8.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -136,7 +136,7 @@ class VannaBase(ABC):
|
|
|
136
136
|
llm_response = self.submit_prompt(prompt, **kwargs)
|
|
137
137
|
self.log(title="LLM Response", message=llm_response)
|
|
138
138
|
|
|
139
|
-
if
|
|
139
|
+
if "intermediate_sql" in llm_response:
|
|
140
140
|
if not allow_llm_to_see_data:
|
|
141
141
|
return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
|
|
142
142
|
|
|
@@ -152,7 +152,11 @@ class VannaBase(ABC):
|
|
|
152
152
|
question=question,
|
|
153
153
|
question_sql_list=question_sql_list,
|
|
154
154
|
ddl_list=ddl_list,
|
|
155
|
-
doc_list=doc_list
|
|
155
|
+
doc_list=doc_list
|
|
156
|
+
+ [
|
|
157
|
+
f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n"
|
|
158
|
+
+ df.to_markdown()
|
|
159
|
+
],
|
|
156
160
|
**kwargs,
|
|
157
161
|
)
|
|
158
162
|
self.log(title="Final SQL Prompt", message=prompt)
|
|
@@ -161,7 +165,6 @@ class VannaBase(ABC):
|
|
|
161
165
|
except Exception as e:
|
|
162
166
|
return f"Error running intermediate SQL: {e}"
|
|
163
167
|
|
|
164
|
-
|
|
165
168
|
return self.extract_sql(llm_response)
|
|
166
169
|
|
|
167
170
|
def extract_sql(self, llm_response: str) -> str:
|
|
@@ -181,30 +184,52 @@ class VannaBase(ABC):
|
|
|
181
184
|
str: The extracted SQL query.
|
|
182
185
|
"""
|
|
183
186
|
|
|
184
|
-
|
|
185
|
-
|
|
187
|
+
import re
|
|
188
|
+
|
|
189
|
+
"""
|
|
190
|
+
Extracts the SQL query from the LLM response, handling various formats including:
|
|
191
|
+
- WITH clause
|
|
192
|
+
- SELECT statement
|
|
193
|
+
- CREATE TABLE AS SELECT
|
|
194
|
+
- Markdown code blocks
|
|
195
|
+
"""
|
|
196
|
+
|
|
197
|
+
# Match CREATE TABLE ... AS SELECT
|
|
198
|
+
sqls = re.findall(
|
|
199
|
+
r"\bCREATE\s+TABLE\b.*?\bAS\b.*?;", llm_response, re.DOTALL | re.IGNORECASE
|
|
200
|
+
)
|
|
186
201
|
if sqls:
|
|
187
202
|
sql = sqls[-1]
|
|
188
203
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
189
204
|
return sql
|
|
190
205
|
|
|
191
|
-
#
|
|
192
|
-
sqls = re.findall(r"
|
|
206
|
+
# Match WITH clause (CTEs)
|
|
207
|
+
sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
|
|
193
208
|
if sqls:
|
|
194
209
|
sql = sqls[-1]
|
|
195
210
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
196
211
|
return sql
|
|
197
212
|
|
|
198
|
-
#
|
|
199
|
-
sqls = re.findall(r"
|
|
213
|
+
# Match SELECT ... ;
|
|
214
|
+
sqls = re.findall(r"\bSELECT\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
|
|
200
215
|
if sqls:
|
|
201
216
|
sql = sqls[-1]
|
|
202
217
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
203
218
|
return sql
|
|
204
219
|
|
|
205
|
-
|
|
220
|
+
# Match ```sql ... ``` blocks
|
|
221
|
+
sqls = re.findall(
|
|
222
|
+
r"```sql\s*\n(.*?)```", llm_response, re.DOTALL | re.IGNORECASE
|
|
223
|
+
)
|
|
206
224
|
if sqls:
|
|
207
|
-
sql = sqls[-1]
|
|
225
|
+
sql = sqls[-1].strip()
|
|
226
|
+
self.log(title="Extracted SQL", message=f"{sql}")
|
|
227
|
+
return sql
|
|
228
|
+
|
|
229
|
+
# Match any ``` ... ``` code blocks
|
|
230
|
+
sqls = re.findall(r"```(.*?)```", llm_response, re.DOTALL | re.IGNORECASE)
|
|
231
|
+
if sqls:
|
|
232
|
+
sql = sqls[-1].strip()
|
|
208
233
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
209
234
|
return sql
|
|
210
235
|
|
|
@@ -229,7 +254,7 @@ class VannaBase(ABC):
|
|
|
229
254
|
parsed = sqlparse.parse(sql)
|
|
230
255
|
|
|
231
256
|
for statement in parsed:
|
|
232
|
-
if statement.get_type() ==
|
|
257
|
+
if statement.get_type() == "SELECT":
|
|
233
258
|
return True
|
|
234
259
|
|
|
235
260
|
return False
|
|
@@ -251,12 +276,14 @@ class VannaBase(ABC):
|
|
|
251
276
|
bool: True if a chart should be generated, False otherwise.
|
|
252
277
|
"""
|
|
253
278
|
|
|
254
|
-
if len(df) > 1 and df.select_dtypes(include=[
|
|
279
|
+
if len(df) > 1 and df.select_dtypes(include=["number"]).shape[1] > 0:
|
|
255
280
|
return True
|
|
256
281
|
|
|
257
282
|
return False
|
|
258
283
|
|
|
259
|
-
def generate_rewritten_question(
|
|
284
|
+
def generate_rewritten_question(
|
|
285
|
+
self, last_question: str, new_question: str, **kwargs
|
|
286
|
+
) -> str:
|
|
260
287
|
"""
|
|
261
288
|
**Example:**
|
|
262
289
|
```python
|
|
@@ -277,8 +304,15 @@ class VannaBase(ABC):
|
|
|
277
304
|
return new_question
|
|
278
305
|
|
|
279
306
|
prompt = [
|
|
280
|
-
self.system_message(
|
|
281
|
-
|
|
307
|
+
self.system_message(
|
|
308
|
+
"Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."
|
|
309
|
+
),
|
|
310
|
+
self.user_message(
|
|
311
|
+
"First question: "
|
|
312
|
+
+ last_question
|
|
313
|
+
+ "\nSecond question: "
|
|
314
|
+
+ new_question
|
|
315
|
+
),
|
|
282
316
|
]
|
|
283
317
|
|
|
284
318
|
return self.submit_prompt(prompt=prompt, **kwargs)
|
|
@@ -309,8 +343,8 @@ class VannaBase(ABC):
|
|
|
309
343
|
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.head(25).to_markdown()}\n\n"
|
|
310
344
|
),
|
|
311
345
|
self.user_message(
|
|
312
|
-
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
313
|
-
self._response_language()
|
|
346
|
+
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
347
|
+
+ self._response_language()
|
|
314
348
|
),
|
|
315
349
|
]
|
|
316
350
|
|
|
@@ -354,8 +388,8 @@ class VannaBase(ABC):
|
|
|
354
388
|
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
355
389
|
),
|
|
356
390
|
self.user_message(
|
|
357
|
-
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
358
|
-
self._response_language()
|
|
391
|
+
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
392
|
+
+ self._response_language()
|
|
359
393
|
),
|
|
360
394
|
]
|
|
361
395
|
|
|
@@ -551,7 +585,7 @@ class VannaBase(ABC):
|
|
|
551
585
|
|
|
552
586
|
def get_sql_prompt(
|
|
553
587
|
self,
|
|
554
|
-
initial_prompt
|
|
588
|
+
initial_prompt: str,
|
|
555
589
|
question: str,
|
|
556
590
|
question_sql_list: list,
|
|
557
591
|
ddl_list: list,
|
|
@@ -583,8 +617,10 @@ class VannaBase(ABC):
|
|
|
583
617
|
"""
|
|
584
618
|
|
|
585
619
|
if initial_prompt is None:
|
|
586
|
-
initial_prompt =
|
|
587
|
-
|
|
620
|
+
initial_prompt = (
|
|
621
|
+
f"You are a {self.dialect} expert. "
|
|
622
|
+
+ "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
|
623
|
+
)
|
|
588
624
|
|
|
589
625
|
initial_prompt = self.add_ddl_to_prompt(
|
|
590
626
|
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
@@ -749,7 +785,7 @@ class VannaBase(ABC):
|
|
|
749
785
|
database: str,
|
|
750
786
|
role: Union[str, None] = None,
|
|
751
787
|
warehouse: Union[str, None] = None,
|
|
752
|
-
**kwargs
|
|
788
|
+
**kwargs,
|
|
753
789
|
):
|
|
754
790
|
try:
|
|
755
791
|
snowflake = __import__("snowflake.connector")
|
|
@@ -797,7 +833,7 @@ class VannaBase(ABC):
|
|
|
797
833
|
account=account,
|
|
798
834
|
database=database,
|
|
799
835
|
client_session_keep_alive=True,
|
|
800
|
-
**kwargs
|
|
836
|
+
**kwargs,
|
|
801
837
|
)
|
|
802
838
|
|
|
803
839
|
def run_sql_snowflake(sql: str) -> pd.DataFrame:
|
|
@@ -823,7 +859,7 @@ class VannaBase(ABC):
|
|
|
823
859
|
self.run_sql = run_sql_snowflake
|
|
824
860
|
self.run_sql_is_set = True
|
|
825
861
|
|
|
826
|
-
def connect_to_sqlite(self, url: str, check_same_thread: bool = False,
|
|
862
|
+
def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
|
|
827
863
|
"""
|
|
828
864
|
Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
829
865
|
|
|
@@ -848,11 +884,7 @@ class VannaBase(ABC):
|
|
|
848
884
|
url = path
|
|
849
885
|
|
|
850
886
|
# Connect to the database
|
|
851
|
-
conn = sqlite3.connect(
|
|
852
|
-
url,
|
|
853
|
-
check_same_thread=check_same_thread,
|
|
854
|
-
**kwargs
|
|
855
|
-
)
|
|
887
|
+
conn = sqlite3.connect(url, check_same_thread=check_same_thread, **kwargs)
|
|
856
888
|
|
|
857
889
|
def run_sql_sqlite(sql: str):
|
|
858
890
|
return pd.read_sql_query(sql, conn)
|
|
@@ -868,9 +900,8 @@ class VannaBase(ABC):
|
|
|
868
900
|
user: str = None,
|
|
869
901
|
password: str = None,
|
|
870
902
|
port: int = None,
|
|
871
|
-
**kwargs
|
|
903
|
+
**kwargs,
|
|
872
904
|
):
|
|
873
|
-
|
|
874
905
|
"""
|
|
875
906
|
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
876
907
|
**Example:**
|
|
@@ -939,15 +970,20 @@ class VannaBase(ABC):
|
|
|
939
970
|
user=user,
|
|
940
971
|
password=password,
|
|
941
972
|
port=port,
|
|
942
|
-
**kwargs
|
|
973
|
+
**kwargs,
|
|
943
974
|
)
|
|
944
975
|
except psycopg2.Error as e:
|
|
945
976
|
raise ValidationError(e)
|
|
946
977
|
|
|
947
978
|
def connect_to_db():
|
|
948
|
-
return psycopg2.connect(
|
|
949
|
-
|
|
950
|
-
|
|
979
|
+
return psycopg2.connect(
|
|
980
|
+
host=host,
|
|
981
|
+
dbname=dbname,
|
|
982
|
+
user=user,
|
|
983
|
+
password=password,
|
|
984
|
+
port=port,
|
|
985
|
+
**kwargs,
|
|
986
|
+
)
|
|
951
987
|
|
|
952
988
|
def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
|
|
953
989
|
conn = None
|
|
@@ -980,14 +1016,13 @@ class VannaBase(ABC):
|
|
|
980
1016
|
raise ValidationError(e)
|
|
981
1017
|
|
|
982
1018
|
except Exception as e:
|
|
983
|
-
|
|
984
|
-
|
|
1019
|
+
conn.rollback()
|
|
1020
|
+
raise e
|
|
985
1021
|
|
|
986
1022
|
self.dialect = "PostgreSQL"
|
|
987
1023
|
self.run_sql_is_set = True
|
|
988
1024
|
self.run_sql = run_sql_postgres
|
|
989
1025
|
|
|
990
|
-
|
|
991
1026
|
def connect_to_mysql(
|
|
992
1027
|
self,
|
|
993
1028
|
host: str = None,
|
|
@@ -995,9 +1030,8 @@ class VannaBase(ABC):
|
|
|
995
1030
|
user: str = None,
|
|
996
1031
|
password: str = None,
|
|
997
1032
|
port: int = None,
|
|
998
|
-
**kwargs
|
|
1033
|
+
**kwargs,
|
|
999
1034
|
):
|
|
1000
|
-
|
|
1001
1035
|
try:
|
|
1002
1036
|
import pymysql.cursors
|
|
1003
1037
|
except ImportError:
|
|
@@ -1046,7 +1080,7 @@ class VannaBase(ABC):
|
|
|
1046
1080
|
database=dbname,
|
|
1047
1081
|
port=port,
|
|
1048
1082
|
cursorclass=pymysql.cursors.DictCursor,
|
|
1049
|
-
**kwargs
|
|
1083
|
+
**kwargs,
|
|
1050
1084
|
)
|
|
1051
1085
|
except pymysql.Error as e:
|
|
1052
1086
|
raise ValidationError(e)
|
|
@@ -1083,9 +1117,8 @@ class VannaBase(ABC):
|
|
|
1083
1117
|
user: str = None,
|
|
1084
1118
|
password: str = None,
|
|
1085
1119
|
port: int = None,
|
|
1086
|
-
**kwargs
|
|
1120
|
+
**kwargs,
|
|
1087
1121
|
):
|
|
1088
|
-
|
|
1089
1122
|
try:
|
|
1090
1123
|
import clickhouse_connect
|
|
1091
1124
|
except ImportError:
|
|
@@ -1133,7 +1166,7 @@ class VannaBase(ABC):
|
|
|
1133
1166
|
username=user,
|
|
1134
1167
|
password=password,
|
|
1135
1168
|
database=dbname,
|
|
1136
|
-
**kwargs
|
|
1169
|
+
**kwargs,
|
|
1137
1170
|
)
|
|
1138
1171
|
print(conn)
|
|
1139
1172
|
except Exception as e:
|
|
@@ -1156,13 +1189,8 @@ class VannaBase(ABC):
|
|
|
1156
1189
|
self.run_sql = run_sql_clickhouse
|
|
1157
1190
|
|
|
1158
1191
|
def connect_to_oracle(
|
|
1159
|
-
self,
|
|
1160
|
-
user: str = None,
|
|
1161
|
-
password: str = None,
|
|
1162
|
-
dsn: str = None,
|
|
1163
|
-
**kwargs
|
|
1192
|
+
self, user: str = None, password: str = None, dsn: str = None, **kwargs
|
|
1164
1193
|
):
|
|
1165
|
-
|
|
1166
1194
|
"""
|
|
1167
1195
|
Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1168
1196
|
**Example:**
|
|
@@ -1182,7 +1210,6 @@ class VannaBase(ABC):
|
|
|
1182
1210
|
try:
|
|
1183
1211
|
import oracledb
|
|
1184
1212
|
except ImportError:
|
|
1185
|
-
|
|
1186
1213
|
raise DependencyError(
|
|
1187
1214
|
"You need to install required dependencies to execute this method,"
|
|
1188
1215
|
" run command: \npip install oracledb"
|
|
@@ -1192,7 +1219,9 @@ class VannaBase(ABC):
|
|
|
1192
1219
|
dsn = os.getenv("DSN")
|
|
1193
1220
|
|
|
1194
1221
|
if not dsn:
|
|
1195
|
-
raise ImproperlyConfigured(
|
|
1222
|
+
raise ImproperlyConfigured(
|
|
1223
|
+
"Please set your Oracle dsn which should include host:port/sid"
|
|
1224
|
+
)
|
|
1196
1225
|
|
|
1197
1226
|
if not user:
|
|
1198
1227
|
user = os.getenv("USER")
|
|
@@ -1209,12 +1238,7 @@ class VannaBase(ABC):
|
|
|
1209
1238
|
conn = None
|
|
1210
1239
|
|
|
1211
1240
|
try:
|
|
1212
|
-
conn = oracledb.connect(
|
|
1213
|
-
user=user,
|
|
1214
|
-
password=password,
|
|
1215
|
-
dsn=dsn,
|
|
1216
|
-
**kwargs
|
|
1217
|
-
)
|
|
1241
|
+
conn = oracledb.connect(user=user, password=password, dsn=dsn, **kwargs)
|
|
1218
1242
|
except oracledb.Error as e:
|
|
1219
1243
|
raise ValidationError(e)
|
|
1220
1244
|
|
|
@@ -1222,7 +1246,9 @@ class VannaBase(ABC):
|
|
|
1222
1246
|
if conn:
|
|
1223
1247
|
try:
|
|
1224
1248
|
sql = sql.rstrip()
|
|
1225
|
-
if sql.endswith(
|
|
1249
|
+
if sql.endswith(
|
|
1250
|
+
";"
|
|
1251
|
+
): # fix for a known problem with Oracle db where an extra ; will cause an error.
|
|
1226
1252
|
sql = sql[:-1]
|
|
1227
1253
|
|
|
1228
1254
|
cs = conn.cursor()
|
|
@@ -1247,10 +1273,7 @@ class VannaBase(ABC):
|
|
|
1247
1273
|
self.run_sql = run_sql_oracle
|
|
1248
1274
|
|
|
1249
1275
|
def connect_to_bigquery(
|
|
1250
|
-
self,
|
|
1251
|
-
cred_file_path: str = None,
|
|
1252
|
-
project_id: str = None,
|
|
1253
|
-
**kwargs
|
|
1276
|
+
self, cred_file_path: str = None, project_id: str = None, **kwargs
|
|
1254
1277
|
):
|
|
1255
1278
|
"""
|
|
1256
1279
|
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
@@ -1299,7 +1322,7 @@ class VannaBase(ABC):
|
|
|
1299
1322
|
if not cred_file_path:
|
|
1300
1323
|
try:
|
|
1301
1324
|
conn = bigquery.Client(project=project_id)
|
|
1302
|
-
except:
|
|
1325
|
+
except Exception:
|
|
1303
1326
|
print("Could not found any google cloud implicit credentials")
|
|
1304
1327
|
else:
|
|
1305
1328
|
# Validate file path and pemissions
|
|
@@ -1314,11 +1337,9 @@ class VannaBase(ABC):
|
|
|
1314
1337
|
|
|
1315
1338
|
try:
|
|
1316
1339
|
conn = bigquery.Client(
|
|
1317
|
-
project=project_id,
|
|
1318
|
-
credentials=credentials,
|
|
1319
|
-
**kwargs
|
|
1340
|
+
project=project_id, credentials=credentials, **kwargs
|
|
1320
1341
|
)
|
|
1321
|
-
except:
|
|
1342
|
+
except Exception:
|
|
1322
1343
|
raise ImproperlyConfigured(
|
|
1323
1344
|
"Could not connect to bigquery please correct credentials"
|
|
1324
1345
|
)
|
|
@@ -1430,20 +1451,21 @@ class VannaBase(ABC):
|
|
|
1430
1451
|
self.dialect = "T-SQL / Microsoft SQL Server"
|
|
1431
1452
|
self.run_sql = run_sql_mssql
|
|
1432
1453
|
self.run_sql_is_set = True
|
|
1454
|
+
|
|
1433
1455
|
def connect_to_presto(
|
|
1434
1456
|
self,
|
|
1435
1457
|
host: str,
|
|
1436
|
-
catalog: str =
|
|
1437
|
-
schema: str =
|
|
1458
|
+
catalog: str = "hive",
|
|
1459
|
+
schema: str = "default",
|
|
1438
1460
|
user: str = None,
|
|
1439
1461
|
password: str = None,
|
|
1440
1462
|
port: int = None,
|
|
1441
1463
|
combined_pem_path: str = None,
|
|
1442
|
-
protocol: str =
|
|
1464
|
+
protocol: str = "https",
|
|
1443
1465
|
requests_kwargs: dict = None,
|
|
1444
|
-
**kwargs
|
|
1466
|
+
**kwargs,
|
|
1445
1467
|
):
|
|
1446
|
-
|
|
1468
|
+
"""
|
|
1447
1469
|
Connect to a Presto database using the specified parameters.
|
|
1448
1470
|
|
|
1449
1471
|
Args:
|
|
@@ -1463,101 +1485,103 @@ class VannaBase(ABC):
|
|
|
1463
1485
|
|
|
1464
1486
|
Returns:
|
|
1465
1487
|
None
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1488
|
+
"""
|
|
1489
|
+
try:
|
|
1490
|
+
from pyhive import presto
|
|
1491
|
+
except ImportError:
|
|
1492
|
+
raise DependencyError(
|
|
1493
|
+
"You need to install required dependencies to execute this method,"
|
|
1494
|
+
" run command: \npip install pyhive"
|
|
1495
|
+
)
|
|
1474
1496
|
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
if not host:
|
|
1479
|
-
raise ImproperlyConfigured("Please set your presto host")
|
|
1480
|
-
|
|
1481
|
-
if not catalog:
|
|
1482
|
-
catalog = os.getenv("PRESTO_CATALOG")
|
|
1483
|
-
|
|
1484
|
-
if not catalog:
|
|
1485
|
-
raise ImproperlyConfigured("Please set your presto catalog")
|
|
1486
|
-
|
|
1487
|
-
if not user:
|
|
1488
|
-
user = os.getenv("PRESTO_USER")
|
|
1489
|
-
|
|
1490
|
-
if not user:
|
|
1491
|
-
raise ImproperlyConfigured("Please set your presto user")
|
|
1492
|
-
|
|
1493
|
-
if not password:
|
|
1494
|
-
password = os.getenv("PRESTO_PASSWORD")
|
|
1495
|
-
|
|
1496
|
-
if not port:
|
|
1497
|
-
port = os.getenv("PRESTO_PORT")
|
|
1498
|
-
|
|
1499
|
-
if not port:
|
|
1500
|
-
raise ImproperlyConfigured("Please set your presto port")
|
|
1501
|
-
|
|
1502
|
-
conn = None
|
|
1503
|
-
|
|
1504
|
-
try:
|
|
1505
|
-
if requests_kwargs is None and combined_pem_path is not None:
|
|
1506
|
-
# use the combined pem file to verify the SSL connection
|
|
1507
|
-
requests_kwargs = {
|
|
1508
|
-
'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
|
|
1509
|
-
}
|
|
1510
|
-
conn = presto.Connection(host=host,
|
|
1511
|
-
username=user,
|
|
1512
|
-
password=password,
|
|
1513
|
-
catalog=catalog,
|
|
1514
|
-
schema=schema,
|
|
1515
|
-
port=port,
|
|
1516
|
-
protocol=protocol,
|
|
1517
|
-
requests_kwargs=requests_kwargs,
|
|
1518
|
-
**kwargs)
|
|
1519
|
-
except presto.Error as e:
|
|
1520
|
-
raise ValidationError(e)
|
|
1521
|
-
|
|
1522
|
-
def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
|
|
1523
|
-
if conn:
|
|
1524
|
-
try:
|
|
1525
|
-
sql = sql.rstrip()
|
|
1526
|
-
# fix for a known problem with presto db where an extra ; will cause an error.
|
|
1527
|
-
if sql.endswith(';'):
|
|
1528
|
-
sql = sql[:-1]
|
|
1529
|
-
cs = conn.cursor()
|
|
1530
|
-
cs.execute(sql)
|
|
1531
|
-
results = cs.fetchall()
|
|
1497
|
+
if not host:
|
|
1498
|
+
host = os.getenv("PRESTO_HOST")
|
|
1532
1499
|
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
results, columns=[desc[0] for desc in cs.description]
|
|
1536
|
-
)
|
|
1537
|
-
return df
|
|
1500
|
+
if not host:
|
|
1501
|
+
raise ImproperlyConfigured("Please set your presto host")
|
|
1538
1502
|
|
|
1539
|
-
|
|
1540
|
-
|
|
1503
|
+
if not catalog:
|
|
1504
|
+
catalog = os.getenv("PRESTO_CATALOG")
|
|
1505
|
+
|
|
1506
|
+
if not catalog:
|
|
1507
|
+
raise ImproperlyConfigured("Please set your presto catalog")
|
|
1508
|
+
|
|
1509
|
+
if not user:
|
|
1510
|
+
user = os.getenv("PRESTO_USER")
|
|
1511
|
+
|
|
1512
|
+
if not user:
|
|
1513
|
+
raise ImproperlyConfigured("Please set your presto user")
|
|
1514
|
+
|
|
1515
|
+
if not password:
|
|
1516
|
+
password = os.getenv("PRESTO_PASSWORD")
|
|
1517
|
+
|
|
1518
|
+
if not port:
|
|
1519
|
+
port = os.getenv("PRESTO_PORT")
|
|
1520
|
+
|
|
1521
|
+
if not port:
|
|
1522
|
+
raise ImproperlyConfigured("Please set your presto port")
|
|
1523
|
+
|
|
1524
|
+
conn = None
|
|
1525
|
+
|
|
1526
|
+
try:
|
|
1527
|
+
if requests_kwargs is None and combined_pem_path is not None:
|
|
1528
|
+
# use the combined pem file to verify the SSL connection
|
|
1529
|
+
requests_kwargs = {
|
|
1530
|
+
"verify": combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
|
|
1531
|
+
}
|
|
1532
|
+
conn = presto.Connection(
|
|
1533
|
+
host=host,
|
|
1534
|
+
username=user,
|
|
1535
|
+
password=password,
|
|
1536
|
+
catalog=catalog,
|
|
1537
|
+
schema=schema,
|
|
1538
|
+
port=port,
|
|
1539
|
+
protocol=protocol,
|
|
1540
|
+
requests_kwargs=requests_kwargs,
|
|
1541
|
+
**kwargs,
|
|
1542
|
+
)
|
|
1543
|
+
except presto.Error as e:
|
|
1541
1544
|
raise ValidationError(e)
|
|
1542
1545
|
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
+
def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
|
|
1547
|
+
if conn:
|
|
1548
|
+
try:
|
|
1549
|
+
sql = sql.rstrip()
|
|
1550
|
+
# fix for a known problem with presto db where an extra ; will cause an error.
|
|
1551
|
+
if sql.endswith(";"):
|
|
1552
|
+
sql = sql[:-1]
|
|
1553
|
+
cs = conn.cursor()
|
|
1554
|
+
cs.execute(sql)
|
|
1555
|
+
results = cs.fetchall()
|
|
1556
|
+
|
|
1557
|
+
# Create a pandas dataframe from the results
|
|
1558
|
+
df = pd.DataFrame(
|
|
1559
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
1560
|
+
)
|
|
1561
|
+
return df
|
|
1546
1562
|
|
|
1547
|
-
|
|
1548
|
-
|
|
1563
|
+
except presto.Error as e:
|
|
1564
|
+
print(e)
|
|
1565
|
+
raise ValidationError(e)
|
|
1566
|
+
|
|
1567
|
+
except Exception as e:
|
|
1568
|
+
print(e)
|
|
1569
|
+
raise e
|
|
1570
|
+
|
|
1571
|
+
self.run_sql_is_set = True
|
|
1572
|
+
self.run_sql = run_sql_presto
|
|
1549
1573
|
|
|
1550
1574
|
def connect_to_hive(
|
|
1551
1575
|
self,
|
|
1552
1576
|
host: str = None,
|
|
1553
|
-
dbname: str =
|
|
1577
|
+
dbname: str = "default",
|
|
1554
1578
|
user: str = None,
|
|
1555
1579
|
password: str = None,
|
|
1556
1580
|
port: int = None,
|
|
1557
|
-
auth: str =
|
|
1558
|
-
**kwargs
|
|
1581
|
+
auth: str = "CUSTOM",
|
|
1582
|
+
**kwargs,
|
|
1559
1583
|
):
|
|
1560
|
-
|
|
1584
|
+
"""
|
|
1561
1585
|
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1562
1586
|
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1563
1587
|
|
|
@@ -1571,78 +1595,80 @@ class VannaBase(ABC):
|
|
|
1571
1595
|
|
|
1572
1596
|
Returns:
|
|
1573
1597
|
None
|
|
1574
|
-
|
|
1575
|
-
|
|
1576
|
-
try:
|
|
1577
|
-
from pyhive import hive
|
|
1578
|
-
except ImportError:
|
|
1579
|
-
raise DependencyError(
|
|
1580
|
-
"You need to install required dependencies to execute this method,"
|
|
1581
|
-
" run command: \npip install pyhive"
|
|
1582
|
-
)
|
|
1583
|
-
|
|
1584
|
-
if not host:
|
|
1585
|
-
host = os.getenv("HIVE_HOST")
|
|
1598
|
+
"""
|
|
1586
1599
|
|
|
1587
|
-
|
|
1588
|
-
|
|
1600
|
+
try:
|
|
1601
|
+
from pyhive import hive
|
|
1602
|
+
except ImportError:
|
|
1603
|
+
raise DependencyError(
|
|
1604
|
+
"You need to install required dependencies to execute this method,"
|
|
1605
|
+
" run command: \npip install pyhive"
|
|
1606
|
+
)
|
|
1589
1607
|
|
|
1590
|
-
|
|
1591
|
-
|
|
1608
|
+
if not host:
|
|
1609
|
+
host = os.getenv("HIVE_HOST")
|
|
1592
1610
|
|
|
1593
|
-
|
|
1594
|
-
|
|
1611
|
+
if not host:
|
|
1612
|
+
raise ImproperlyConfigured("Please set your hive host")
|
|
1595
1613
|
|
|
1596
|
-
|
|
1597
|
-
|
|
1614
|
+
if not dbname:
|
|
1615
|
+
dbname = os.getenv("HIVE_DATABASE")
|
|
1598
1616
|
|
|
1599
|
-
|
|
1600
|
-
|
|
1617
|
+
if not dbname:
|
|
1618
|
+
raise ImproperlyConfigured("Please set your hive database")
|
|
1601
1619
|
|
|
1602
|
-
|
|
1603
|
-
|
|
1620
|
+
if not user:
|
|
1621
|
+
user = os.getenv("HIVE_USER")
|
|
1604
1622
|
|
|
1605
|
-
|
|
1606
|
-
|
|
1623
|
+
if not user:
|
|
1624
|
+
raise ImproperlyConfigured("Please set your hive user")
|
|
1607
1625
|
|
|
1608
|
-
|
|
1609
|
-
|
|
1626
|
+
if not password:
|
|
1627
|
+
password = os.getenv("HIVE_PASSWORD")
|
|
1610
1628
|
|
|
1611
|
-
|
|
1629
|
+
if not port:
|
|
1630
|
+
port = os.getenv("HIVE_PORT")
|
|
1612
1631
|
|
|
1613
|
-
|
|
1614
|
-
|
|
1615
|
-
username=user,
|
|
1616
|
-
password=password,
|
|
1617
|
-
database=dbname,
|
|
1618
|
-
port=port,
|
|
1619
|
-
auth=auth)
|
|
1620
|
-
except hive.Error as e:
|
|
1621
|
-
raise ValidationError(e)
|
|
1632
|
+
if not port:
|
|
1633
|
+
raise ImproperlyConfigured("Please set your hive port")
|
|
1622
1634
|
|
|
1623
|
-
|
|
1624
|
-
if conn:
|
|
1625
|
-
try:
|
|
1626
|
-
cs = conn.cursor()
|
|
1627
|
-
cs.execute(sql)
|
|
1628
|
-
results = cs.fetchall()
|
|
1635
|
+
conn = None
|
|
1629
1636
|
|
|
1630
|
-
|
|
1631
|
-
|
|
1632
|
-
|
|
1637
|
+
try:
|
|
1638
|
+
conn = hive.Connection(
|
|
1639
|
+
host=host,
|
|
1640
|
+
username=user,
|
|
1641
|
+
password=password,
|
|
1642
|
+
database=dbname,
|
|
1643
|
+
port=port,
|
|
1644
|
+
auth=auth,
|
|
1633
1645
|
)
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
except hive.Error as e:
|
|
1637
|
-
print(e)
|
|
1646
|
+
except hive.Error as e:
|
|
1638
1647
|
raise ValidationError(e)
|
|
1639
1648
|
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1649
|
+
def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
|
|
1650
|
+
if conn:
|
|
1651
|
+
try:
|
|
1652
|
+
cs = conn.cursor()
|
|
1653
|
+
cs.execute(sql)
|
|
1654
|
+
results = cs.fetchall()
|
|
1643
1655
|
|
|
1644
|
-
|
|
1645
|
-
|
|
1656
|
+
# Create a pandas dataframe from the results
|
|
1657
|
+
df = pd.DataFrame(
|
|
1658
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
1659
|
+
)
|
|
1660
|
+
return df
|
|
1661
|
+
|
|
1662
|
+
except hive.Error as e:
|
|
1663
|
+
print(e)
|
|
1664
|
+
raise ValidationError(e)
|
|
1665
|
+
|
|
1666
|
+
except Exception as e:
|
|
1667
|
+
print(e)
|
|
1668
|
+
raise e
|
|
1669
|
+
|
|
1670
|
+
self.run_sql_is_set = True
|
|
1671
|
+
self.run_sql = run_sql_hive
|
|
1646
1672
|
|
|
1647
1673
|
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
|
|
1648
1674
|
"""
|
|
@@ -1700,22 +1726,23 @@ class VannaBase(ABC):
|
|
|
1700
1726
|
question = input("Enter a question: ")
|
|
1701
1727
|
|
|
1702
1728
|
try:
|
|
1703
|
-
sql = self.generate_sql(
|
|
1729
|
+
sql = self.generate_sql(
|
|
1730
|
+
question=question, allow_llm_to_see_data=allow_llm_to_see_data
|
|
1731
|
+
)
|
|
1704
1732
|
except Exception as e:
|
|
1705
1733
|
print(e)
|
|
1706
1734
|
return None, None, None
|
|
1707
1735
|
|
|
1708
1736
|
if print_results:
|
|
1709
1737
|
try:
|
|
1710
|
-
|
|
1738
|
+
from IPython.display import Code, display
|
|
1739
|
+
|
|
1711
1740
|
display(Code(sql))
|
|
1712
1741
|
except Exception as e:
|
|
1713
1742
|
print(sql)
|
|
1714
1743
|
|
|
1715
1744
|
if self.run_sql_is_set is False:
|
|
1716
|
-
print(
|
|
1717
|
-
"If you want to run the SQL query, connect to a database first."
|
|
1718
|
-
)
|
|
1745
|
+
print("If you want to run the SQL query, connect to a database first.")
|
|
1719
1746
|
|
|
1720
1747
|
if print_results:
|
|
1721
1748
|
return None
|
|
@@ -1759,6 +1786,7 @@ class VannaBase(ABC):
|
|
|
1759
1786
|
fig.show()
|
|
1760
1787
|
except Exception as e:
|
|
1761
1788
|
# Print stack trace
|
|
1789
|
+
traceback.print_stack()
|
|
1762
1790
|
traceback.print_exc()
|
|
1763
1791
|
print("Couldn't run plotly code: ", e)
|
|
1764
1792
|
if print_results:
|
|
@@ -1874,12 +1902,8 @@ class VannaBase(ABC):
|
|
|
1874
1902
|
table_column = df.columns[
|
|
1875
1903
|
df.columns.str.lower().str.contains("table_name")
|
|
1876
1904
|
].to_list()[0]
|
|
1877
|
-
columns = [database_column,
|
|
1878
|
-
|
|
1879
|
-
table_column]
|
|
1880
|
-
candidates = ["column_name",
|
|
1881
|
-
"data_type",
|
|
1882
|
-
"comment"]
|
|
1905
|
+
columns = [database_column, schema_column, table_column]
|
|
1906
|
+
candidates = ["column_name", "data_type", "comment"]
|
|
1883
1907
|
matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True)
|
|
1884
1908
|
columns += df.columns[matches].to_list()
|
|
1885
1909
|
|