vanna 0.7.9__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +167 -395
- vanna/agents/__init__.py +7 -0
- vanna/capabilities/__init__.py +17 -0
- vanna/capabilities/agent_memory/__init__.py +21 -0
- vanna/capabilities/agent_memory/base.py +103 -0
- vanna/capabilities/agent_memory/models.py +53 -0
- vanna/capabilities/file_system/__init__.py +14 -0
- vanna/capabilities/file_system/base.py +71 -0
- vanna/capabilities/file_system/models.py +25 -0
- vanna/capabilities/sql_runner/__init__.py +13 -0
- vanna/capabilities/sql_runner/base.py +37 -0
- vanna/capabilities/sql_runner/models.py +13 -0
- vanna/components/__init__.py +92 -0
- vanna/components/base.py +11 -0
- vanna/components/rich/__init__.py +83 -0
- vanna/components/rich/containers/__init__.py +7 -0
- vanna/components/rich/containers/card.py +20 -0
- vanna/components/rich/data/__init__.py +9 -0
- vanna/components/rich/data/chart.py +17 -0
- vanna/components/rich/data/dataframe.py +93 -0
- vanna/components/rich/feedback/__init__.py +21 -0
- vanna/components/rich/feedback/badge.py +16 -0
- vanna/components/rich/feedback/icon_text.py +14 -0
- vanna/components/rich/feedback/log_viewer.py +41 -0
- vanna/components/rich/feedback/notification.py +19 -0
- vanna/components/rich/feedback/progress.py +37 -0
- vanna/components/rich/feedback/status_card.py +28 -0
- vanna/components/rich/feedback/status_indicator.py +14 -0
- vanna/components/rich/interactive/__init__.py +21 -0
- vanna/components/rich/interactive/button.py +95 -0
- vanna/components/rich/interactive/task_list.py +58 -0
- vanna/components/rich/interactive/ui_state.py +93 -0
- vanna/components/rich/specialized/__init__.py +7 -0
- vanna/components/rich/specialized/artifact.py +20 -0
- vanna/components/rich/text.py +16 -0
- vanna/components/simple/__init__.py +15 -0
- vanna/components/simple/image.py +15 -0
- vanna/components/simple/link.py +15 -0
- vanna/components/simple/text.py +11 -0
- vanna/core/__init__.py +193 -0
- vanna/core/_compat.py +19 -0
- vanna/core/agent/__init__.py +10 -0
- vanna/core/agent/agent.py +1407 -0
- vanna/core/agent/config.py +123 -0
- vanna/core/audit/__init__.py +28 -0
- vanna/core/audit/base.py +299 -0
- vanna/core/audit/models.py +131 -0
- vanna/core/component_manager.py +329 -0
- vanna/core/components.py +53 -0
- vanna/core/enhancer/__init__.py +11 -0
- vanna/core/enhancer/base.py +94 -0
- vanna/core/enhancer/default.py +118 -0
- vanna/core/enricher/__init__.py +10 -0
- vanna/core/enricher/base.py +59 -0
- vanna/core/errors.py +47 -0
- vanna/core/evaluation/__init__.py +81 -0
- vanna/core/evaluation/base.py +186 -0
- vanna/core/evaluation/dataset.py +254 -0
- vanna/core/evaluation/evaluators.py +376 -0
- vanna/core/evaluation/report.py +289 -0
- vanna/core/evaluation/runner.py +313 -0
- vanna/core/filter/__init__.py +10 -0
- vanna/core/filter/base.py +67 -0
- vanna/core/lifecycle/__init__.py +10 -0
- vanna/core/lifecycle/base.py +83 -0
- vanna/core/llm/__init__.py +16 -0
- vanna/core/llm/base.py +40 -0
- vanna/core/llm/models.py +61 -0
- vanna/core/middleware/__init__.py +10 -0
- vanna/core/middleware/base.py +69 -0
- vanna/core/observability/__init__.py +11 -0
- vanna/core/observability/base.py +88 -0
- vanna/core/observability/models.py +47 -0
- vanna/core/recovery/__init__.py +11 -0
- vanna/core/recovery/base.py +84 -0
- vanna/core/recovery/models.py +32 -0
- vanna/core/registry.py +278 -0
- vanna/core/rich_component.py +156 -0
- vanna/core/simple_component.py +27 -0
- vanna/core/storage/__init__.py +14 -0
- vanna/core/storage/base.py +46 -0
- vanna/core/storage/models.py +46 -0
- vanna/core/system_prompt/__init__.py +13 -0
- vanna/core/system_prompt/base.py +36 -0
- vanna/core/system_prompt/default.py +157 -0
- vanna/core/tool/__init__.py +18 -0
- vanna/core/tool/base.py +70 -0
- vanna/core/tool/models.py +84 -0
- vanna/core/user/__init__.py +17 -0
- vanna/core/user/base.py +29 -0
- vanna/core/user/models.py +25 -0
- vanna/core/user/request_context.py +70 -0
- vanna/core/user/resolver.py +42 -0
- vanna/core/validation.py +164 -0
- vanna/core/workflow/__init__.py +12 -0
- vanna/core/workflow/base.py +254 -0
- vanna/core/workflow/default.py +789 -0
- vanna/examples/__init__.py +1 -0
- vanna/examples/__main__.py +44 -0
- vanna/examples/anthropic_quickstart.py +80 -0
- vanna/examples/artifact_example.py +293 -0
- vanna/examples/claude_sqlite_example.py +236 -0
- vanna/examples/coding_agent_example.py +300 -0
- vanna/examples/custom_system_prompt_example.py +174 -0
- vanna/examples/default_workflow_handler_example.py +208 -0
- vanna/examples/email_auth_example.py +340 -0
- vanna/examples/evaluation_example.py +269 -0
- vanna/examples/extensibility_example.py +262 -0
- vanna/examples/minimal_example.py +67 -0
- vanna/examples/mock_auth_example.py +227 -0
- vanna/examples/mock_custom_tool.py +311 -0
- vanna/examples/mock_quickstart.py +79 -0
- vanna/examples/mock_quota_example.py +145 -0
- vanna/examples/mock_rich_components_demo.py +396 -0
- vanna/examples/mock_sqlite_example.py +223 -0
- vanna/examples/openai_quickstart.py +83 -0
- vanna/examples/primitive_components_demo.py +305 -0
- vanna/examples/quota_lifecycle_example.py +139 -0
- vanna/examples/visualization_example.py +251 -0
- vanna/integrations/__init__.py +17 -0
- vanna/integrations/anthropic/__init__.py +9 -0
- vanna/integrations/anthropic/llm.py +270 -0
- vanna/integrations/azureopenai/__init__.py +9 -0
- vanna/integrations/azureopenai/llm.py +329 -0
- vanna/integrations/azuresearch/__init__.py +7 -0
- vanna/integrations/azuresearch/agent_memory.py +413 -0
- vanna/integrations/bigquery/__init__.py +5 -0
- vanna/integrations/bigquery/sql_runner.py +81 -0
- vanna/integrations/chromadb/__init__.py +104 -0
- vanna/integrations/chromadb/agent_memory.py +416 -0
- vanna/integrations/clickhouse/__init__.py +5 -0
- vanna/integrations/clickhouse/sql_runner.py +82 -0
- vanna/integrations/duckdb/__init__.py +5 -0
- vanna/integrations/duckdb/sql_runner.py +65 -0
- vanna/integrations/faiss/__init__.py +7 -0
- vanna/integrations/faiss/agent_memory.py +431 -0
- vanna/integrations/google/__init__.py +9 -0
- vanna/integrations/google/gemini.py +370 -0
- vanna/integrations/hive/__init__.py +5 -0
- vanna/integrations/hive/sql_runner.py +87 -0
- vanna/integrations/local/__init__.py +17 -0
- vanna/integrations/local/agent_memory/__init__.py +7 -0
- vanna/integrations/local/agent_memory/in_memory.py +285 -0
- vanna/integrations/local/audit.py +59 -0
- vanna/integrations/local/file_system.py +242 -0
- vanna/integrations/local/file_system_conversation_store.py +255 -0
- vanna/integrations/local/storage.py +62 -0
- vanna/integrations/marqo/__init__.py +7 -0
- vanna/integrations/marqo/agent_memory.py +354 -0
- vanna/integrations/milvus/__init__.py +7 -0
- vanna/integrations/milvus/agent_memory.py +458 -0
- vanna/integrations/mock/__init__.py +9 -0
- vanna/integrations/mock/llm.py +65 -0
- vanna/integrations/mssql/__init__.py +5 -0
- vanna/integrations/mssql/sql_runner.py +66 -0
- vanna/integrations/mysql/__init__.py +5 -0
- vanna/integrations/mysql/sql_runner.py +92 -0
- vanna/integrations/ollama/__init__.py +7 -0
- vanna/integrations/ollama/llm.py +252 -0
- vanna/integrations/openai/__init__.py +10 -0
- vanna/integrations/openai/llm.py +267 -0
- vanna/integrations/openai/responses.py +163 -0
- vanna/integrations/opensearch/__init__.py +7 -0
- vanna/integrations/opensearch/agent_memory.py +411 -0
- vanna/integrations/oracle/__init__.py +5 -0
- vanna/integrations/oracle/sql_runner.py +75 -0
- vanna/integrations/pinecone/__init__.py +7 -0
- vanna/integrations/pinecone/agent_memory.py +329 -0
- vanna/integrations/plotly/__init__.py +5 -0
- vanna/integrations/plotly/chart_generator.py +313 -0
- vanna/integrations/postgres/__init__.py +9 -0
- vanna/integrations/postgres/sql_runner.py +112 -0
- vanna/integrations/premium/agent_memory/__init__.py +7 -0
- vanna/integrations/premium/agent_memory/premium.py +186 -0
- vanna/integrations/presto/__init__.py +5 -0
- vanna/integrations/presto/sql_runner.py +107 -0
- vanna/integrations/qdrant/__init__.py +7 -0
- vanna/integrations/qdrant/agent_memory.py +461 -0
- vanna/integrations/snowflake/__init__.py +5 -0
- vanna/integrations/snowflake/sql_runner.py +147 -0
- vanna/integrations/sqlite/__init__.py +9 -0
- vanna/integrations/sqlite/sql_runner.py +65 -0
- vanna/integrations/weaviate/__init__.py +7 -0
- vanna/integrations/weaviate/agent_memory.py +428 -0
- vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
- vanna/legacy/__init__.py +403 -0
- vanna/legacy/adapter.py +463 -0
- vanna/{advanced → legacy/advanced}/__init__.py +3 -1
- vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
- vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
- vanna/{base → legacy/base}/base.py +224 -217
- vanna/legacy/bedrock/__init__.py +1 -0
- vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
- vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
- vanna/legacy/cohere/__init__.py +2 -0
- vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
- vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
- vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
- vanna/legacy/faiss/__init__.py +1 -0
- vanna/{faiss → legacy/faiss}/faiss.py +113 -59
- vanna/{flask → legacy/flask}/__init__.py +84 -43
- vanna/{flask → legacy/flask}/assets.py +5 -5
- vanna/{flask → legacy/flask}/auth.py +5 -4
- vanna/{google → legacy/google}/bigquery_vector.py +75 -42
- vanna/{google → legacy/google}/gemini_chat.py +7 -3
- vanna/{hf → legacy/hf}/hf.py +0 -1
- vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
- vanna/{mock → legacy/mock}/llm.py +0 -1
- vanna/legacy/mock/vectordb.py +67 -0
- vanna/legacy/ollama/ollama.py +110 -0
- vanna/{openai → legacy/openai}/openai_chat.py +2 -6
- vanna/legacy/opensearch/opensearch_vector.py +369 -0
- vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
- vanna/legacy/oracle/oracle_vector.py +584 -0
- vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
- vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
- vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
- vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
- vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
- vanna/{remote.py → legacy/remote.py} +28 -26
- vanna/{utils.py → legacy/utils.py} +6 -11
- vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
- vanna/{vllm → legacy/vllm}/vllm.py +5 -6
- vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
- vanna/{xinference → legacy/xinference}/xinference.py +6 -6
- vanna/py.typed +0 -0
- vanna/servers/__init__.py +16 -0
- vanna/servers/__main__.py +8 -0
- vanna/servers/base/__init__.py +18 -0
- vanna/servers/base/chat_handler.py +65 -0
- vanna/servers/base/models.py +111 -0
- vanna/servers/base/rich_chat_handler.py +141 -0
- vanna/servers/base/templates.py +331 -0
- vanna/servers/cli/__init__.py +7 -0
- vanna/servers/cli/server_runner.py +204 -0
- vanna/servers/fastapi/__init__.py +7 -0
- vanna/servers/fastapi/app.py +163 -0
- vanna/servers/fastapi/routes.py +183 -0
- vanna/servers/flask/__init__.py +7 -0
- vanna/servers/flask/app.py +132 -0
- vanna/servers/flask/routes.py +137 -0
- vanna/tools/__init__.py +41 -0
- vanna/tools/agent_memory.py +322 -0
- vanna/tools/file_system.py +879 -0
- vanna/tools/python.py +222 -0
- vanna/tools/run_sql.py +165 -0
- vanna/tools/visualize_data.py +195 -0
- vanna/utils/__init__.py +0 -0
- vanna/web_components/__init__.py +44 -0
- vanna-2.0.0.dist-info/METADATA +485 -0
- vanna-2.0.0.dist-info/RECORD +289 -0
- vanna-2.0.0.dist-info/entry_points.txt +3 -0
- vanna/bedrock/__init__.py +0 -1
- vanna/cohere/__init__.py +0 -2
- vanna/faiss/__init__.py +0 -1
- vanna/mock/vectordb.py +0 -55
- vanna/ollama/ollama.py +0 -103
- vanna/opensearch/opensearch_vector.py +0 -392
- vanna/opensearch/opensearch_vector_semantic.py +0 -175
- vanna/oracle/oracle_vector.py +0 -585
- vanna/qianfan/Qianfan_Chat.py +0 -165
- vanna/qianfan/Qianfan_embeddings.py +0 -36
- vanna/qianwen/QianwenAI_chat.py +0 -133
- vanna-0.7.9.dist-info/METADATA +0 -408
- vanna-0.7.9.dist-info/RECORD +0 -79
- /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
- /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
- /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
- /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
- /vanna/{base → legacy/base}/__init__.py +0 -0
- /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
- /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
- /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
- /vanna/{google → legacy/google}/__init__.py +0 -0
- /vanna/{hf → legacy/hf}/__init__.py +0 -0
- /vanna/{local.py → legacy/local.py} +0 -0
- /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
- /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
- /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
- /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
- /vanna/{mock → legacy/mock}/__init__.py +0 -0
- /vanna/{mock → legacy/mock}/embedding.py +0 -0
- /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/__init__.py +0 -0
- /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
- /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
- /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
- /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
- /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
- /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
- /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
- /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
- /vanna/{types → legacy/types}/__init__.py +0 -0
- /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
- /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
- /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
- /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
- {vanna-0.7.9.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -136,7 +136,7 @@ class VannaBase(ABC):
|
|
|
136
136
|
llm_response = self.submit_prompt(prompt, **kwargs)
|
|
137
137
|
self.log(title="LLM Response", message=llm_response)
|
|
138
138
|
|
|
139
|
-
if
|
|
139
|
+
if "intermediate_sql" in llm_response:
|
|
140
140
|
if not allow_llm_to_see_data:
|
|
141
141
|
return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
|
|
142
142
|
|
|
@@ -152,7 +152,11 @@ class VannaBase(ABC):
|
|
|
152
152
|
question=question,
|
|
153
153
|
question_sql_list=question_sql_list,
|
|
154
154
|
ddl_list=ddl_list,
|
|
155
|
-
doc_list=doc_list
|
|
155
|
+
doc_list=doc_list
|
|
156
|
+
+ [
|
|
157
|
+
f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n"
|
|
158
|
+
+ df.to_markdown()
|
|
159
|
+
],
|
|
156
160
|
**kwargs,
|
|
157
161
|
)
|
|
158
162
|
self.log(title="Final SQL Prompt", message=prompt)
|
|
@@ -161,7 +165,6 @@ class VannaBase(ABC):
|
|
|
161
165
|
except Exception as e:
|
|
162
166
|
return f"Error running intermediate SQL: {e}"
|
|
163
167
|
|
|
164
|
-
|
|
165
168
|
return self.extract_sql(llm_response)
|
|
166
169
|
|
|
167
170
|
def extract_sql(self, llm_response: str) -> str:
|
|
@@ -182,6 +185,7 @@ class VannaBase(ABC):
|
|
|
182
185
|
"""
|
|
183
186
|
|
|
184
187
|
import re
|
|
188
|
+
|
|
185
189
|
"""
|
|
186
190
|
Extracts the SQL query from the LLM response, handling various formats including:
|
|
187
191
|
- WITH clause
|
|
@@ -191,7 +195,9 @@ class VannaBase(ABC):
|
|
|
191
195
|
"""
|
|
192
196
|
|
|
193
197
|
# Match CREATE TABLE ... AS SELECT
|
|
194
|
-
sqls = re.findall(
|
|
198
|
+
sqls = re.findall(
|
|
199
|
+
r"\bCREATE\s+TABLE\b.*?\bAS\b.*?;", llm_response, re.DOTALL | re.IGNORECASE
|
|
200
|
+
)
|
|
195
201
|
if sqls:
|
|
196
202
|
sql = sqls[-1]
|
|
197
203
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
@@ -212,7 +218,9 @@ class VannaBase(ABC):
|
|
|
212
218
|
return sql
|
|
213
219
|
|
|
214
220
|
# Match ```sql ... ``` blocks
|
|
215
|
-
sqls = re.findall(
|
|
221
|
+
sqls = re.findall(
|
|
222
|
+
r"```sql\s*\n(.*?)```", llm_response, re.DOTALL | re.IGNORECASE
|
|
223
|
+
)
|
|
216
224
|
if sqls:
|
|
217
225
|
sql = sqls[-1].strip()
|
|
218
226
|
self.log(title="Extracted SQL", message=f"{sql}")
|
|
@@ -246,7 +254,7 @@ class VannaBase(ABC):
|
|
|
246
254
|
parsed = sqlparse.parse(sql)
|
|
247
255
|
|
|
248
256
|
for statement in parsed:
|
|
249
|
-
if statement.get_type() ==
|
|
257
|
+
if statement.get_type() == "SELECT":
|
|
250
258
|
return True
|
|
251
259
|
|
|
252
260
|
return False
|
|
@@ -268,12 +276,14 @@ class VannaBase(ABC):
|
|
|
268
276
|
bool: True if a chart should be generated, False otherwise.
|
|
269
277
|
"""
|
|
270
278
|
|
|
271
|
-
if len(df) > 1 and df.select_dtypes(include=[
|
|
279
|
+
if len(df) > 1 and df.select_dtypes(include=["number"]).shape[1] > 0:
|
|
272
280
|
return True
|
|
273
281
|
|
|
274
282
|
return False
|
|
275
283
|
|
|
276
|
-
def generate_rewritten_question(
|
|
284
|
+
def generate_rewritten_question(
|
|
285
|
+
self, last_question: str, new_question: str, **kwargs
|
|
286
|
+
) -> str:
|
|
277
287
|
"""
|
|
278
288
|
**Example:**
|
|
279
289
|
```python
|
|
@@ -294,8 +304,15 @@ class VannaBase(ABC):
|
|
|
294
304
|
return new_question
|
|
295
305
|
|
|
296
306
|
prompt = [
|
|
297
|
-
self.system_message(
|
|
298
|
-
|
|
307
|
+
self.system_message(
|
|
308
|
+
"Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."
|
|
309
|
+
),
|
|
310
|
+
self.user_message(
|
|
311
|
+
"First question: "
|
|
312
|
+
+ last_question
|
|
313
|
+
+ "\nSecond question: "
|
|
314
|
+
+ new_question
|
|
315
|
+
),
|
|
299
316
|
]
|
|
300
317
|
|
|
301
318
|
return self.submit_prompt(prompt=prompt, **kwargs)
|
|
@@ -326,8 +343,8 @@ class VannaBase(ABC):
|
|
|
326
343
|
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.head(25).to_markdown()}\n\n"
|
|
327
344
|
),
|
|
328
345
|
self.user_message(
|
|
329
|
-
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
330
|
-
self._response_language()
|
|
346
|
+
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
347
|
+
+ self._response_language()
|
|
331
348
|
),
|
|
332
349
|
]
|
|
333
350
|
|
|
@@ -371,8 +388,8 @@ class VannaBase(ABC):
|
|
|
371
388
|
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
372
389
|
),
|
|
373
390
|
self.user_message(
|
|
374
|
-
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
375
|
-
self._response_language()
|
|
391
|
+
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
392
|
+
+ self._response_language()
|
|
376
393
|
),
|
|
377
394
|
]
|
|
378
395
|
|
|
@@ -568,7 +585,7 @@ class VannaBase(ABC):
|
|
|
568
585
|
|
|
569
586
|
def get_sql_prompt(
|
|
570
587
|
self,
|
|
571
|
-
initial_prompt
|
|
588
|
+
initial_prompt: str,
|
|
572
589
|
question: str,
|
|
573
590
|
question_sql_list: list,
|
|
574
591
|
ddl_list: list,
|
|
@@ -600,8 +617,10 @@ class VannaBase(ABC):
|
|
|
600
617
|
"""
|
|
601
618
|
|
|
602
619
|
if initial_prompt is None:
|
|
603
|
-
initial_prompt =
|
|
604
|
-
|
|
620
|
+
initial_prompt = (
|
|
621
|
+
f"You are a {self.dialect} expert. "
|
|
622
|
+
+ "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
|
623
|
+
)
|
|
605
624
|
|
|
606
625
|
initial_prompt = self.add_ddl_to_prompt(
|
|
607
626
|
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
@@ -766,7 +785,7 @@ class VannaBase(ABC):
|
|
|
766
785
|
database: str,
|
|
767
786
|
role: Union[str, None] = None,
|
|
768
787
|
warehouse: Union[str, None] = None,
|
|
769
|
-
**kwargs
|
|
788
|
+
**kwargs,
|
|
770
789
|
):
|
|
771
790
|
try:
|
|
772
791
|
snowflake = __import__("snowflake.connector")
|
|
@@ -814,7 +833,7 @@ class VannaBase(ABC):
|
|
|
814
833
|
account=account,
|
|
815
834
|
database=database,
|
|
816
835
|
client_session_keep_alive=True,
|
|
817
|
-
**kwargs
|
|
836
|
+
**kwargs,
|
|
818
837
|
)
|
|
819
838
|
|
|
820
839
|
def run_sql_snowflake(sql: str) -> pd.DataFrame:
|
|
@@ -840,7 +859,7 @@ class VannaBase(ABC):
|
|
|
840
859
|
self.run_sql = run_sql_snowflake
|
|
841
860
|
self.run_sql_is_set = True
|
|
842
861
|
|
|
843
|
-
def connect_to_sqlite(self, url: str, check_same_thread: bool = False,
|
|
862
|
+
def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
|
|
844
863
|
"""
|
|
845
864
|
Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
846
865
|
|
|
@@ -865,11 +884,7 @@ class VannaBase(ABC):
|
|
|
865
884
|
url = path
|
|
866
885
|
|
|
867
886
|
# Connect to the database
|
|
868
|
-
conn = sqlite3.connect(
|
|
869
|
-
url,
|
|
870
|
-
check_same_thread=check_same_thread,
|
|
871
|
-
**kwargs
|
|
872
|
-
)
|
|
887
|
+
conn = sqlite3.connect(url, check_same_thread=check_same_thread, **kwargs)
|
|
873
888
|
|
|
874
889
|
def run_sql_sqlite(sql: str):
|
|
875
890
|
return pd.read_sql_query(sql, conn)
|
|
@@ -885,9 +900,8 @@ class VannaBase(ABC):
|
|
|
885
900
|
user: str = None,
|
|
886
901
|
password: str = None,
|
|
887
902
|
port: int = None,
|
|
888
|
-
**kwargs
|
|
903
|
+
**kwargs,
|
|
889
904
|
):
|
|
890
|
-
|
|
891
905
|
"""
|
|
892
906
|
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
893
907
|
**Example:**
|
|
@@ -956,15 +970,20 @@ class VannaBase(ABC):
|
|
|
956
970
|
user=user,
|
|
957
971
|
password=password,
|
|
958
972
|
port=port,
|
|
959
|
-
**kwargs
|
|
973
|
+
**kwargs,
|
|
960
974
|
)
|
|
961
975
|
except psycopg2.Error as e:
|
|
962
976
|
raise ValidationError(e)
|
|
963
977
|
|
|
964
978
|
def connect_to_db():
|
|
965
|
-
return psycopg2.connect(
|
|
966
|
-
|
|
967
|
-
|
|
979
|
+
return psycopg2.connect(
|
|
980
|
+
host=host,
|
|
981
|
+
dbname=dbname,
|
|
982
|
+
user=user,
|
|
983
|
+
password=password,
|
|
984
|
+
port=port,
|
|
985
|
+
**kwargs,
|
|
986
|
+
)
|
|
968
987
|
|
|
969
988
|
def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
|
|
970
989
|
conn = None
|
|
@@ -997,14 +1016,13 @@ class VannaBase(ABC):
|
|
|
997
1016
|
raise ValidationError(e)
|
|
998
1017
|
|
|
999
1018
|
except Exception as e:
|
|
1000
|
-
|
|
1001
|
-
|
|
1019
|
+
conn.rollback()
|
|
1020
|
+
raise e
|
|
1002
1021
|
|
|
1003
1022
|
self.dialect = "PostgreSQL"
|
|
1004
1023
|
self.run_sql_is_set = True
|
|
1005
1024
|
self.run_sql = run_sql_postgres
|
|
1006
1025
|
|
|
1007
|
-
|
|
1008
1026
|
def connect_to_mysql(
|
|
1009
1027
|
self,
|
|
1010
1028
|
host: str = None,
|
|
@@ -1012,9 +1030,8 @@ class VannaBase(ABC):
|
|
|
1012
1030
|
user: str = None,
|
|
1013
1031
|
password: str = None,
|
|
1014
1032
|
port: int = None,
|
|
1015
|
-
**kwargs
|
|
1033
|
+
**kwargs,
|
|
1016
1034
|
):
|
|
1017
|
-
|
|
1018
1035
|
try:
|
|
1019
1036
|
import pymysql.cursors
|
|
1020
1037
|
except ImportError:
|
|
@@ -1063,7 +1080,7 @@ class VannaBase(ABC):
|
|
|
1063
1080
|
database=dbname,
|
|
1064
1081
|
port=port,
|
|
1065
1082
|
cursorclass=pymysql.cursors.DictCursor,
|
|
1066
|
-
**kwargs
|
|
1083
|
+
**kwargs,
|
|
1067
1084
|
)
|
|
1068
1085
|
except pymysql.Error as e:
|
|
1069
1086
|
raise ValidationError(e)
|
|
@@ -1100,9 +1117,8 @@ class VannaBase(ABC):
|
|
|
1100
1117
|
user: str = None,
|
|
1101
1118
|
password: str = None,
|
|
1102
1119
|
port: int = None,
|
|
1103
|
-
**kwargs
|
|
1120
|
+
**kwargs,
|
|
1104
1121
|
):
|
|
1105
|
-
|
|
1106
1122
|
try:
|
|
1107
1123
|
import clickhouse_connect
|
|
1108
1124
|
except ImportError:
|
|
@@ -1150,7 +1166,7 @@ class VannaBase(ABC):
|
|
|
1150
1166
|
username=user,
|
|
1151
1167
|
password=password,
|
|
1152
1168
|
database=dbname,
|
|
1153
|
-
**kwargs
|
|
1169
|
+
**kwargs,
|
|
1154
1170
|
)
|
|
1155
1171
|
print(conn)
|
|
1156
1172
|
except Exception as e:
|
|
@@ -1173,13 +1189,8 @@ class VannaBase(ABC):
|
|
|
1173
1189
|
self.run_sql = run_sql_clickhouse
|
|
1174
1190
|
|
|
1175
1191
|
def connect_to_oracle(
|
|
1176
|
-
self,
|
|
1177
|
-
user: str = None,
|
|
1178
|
-
password: str = None,
|
|
1179
|
-
dsn: str = None,
|
|
1180
|
-
**kwargs
|
|
1192
|
+
self, user: str = None, password: str = None, dsn: str = None, **kwargs
|
|
1181
1193
|
):
|
|
1182
|
-
|
|
1183
1194
|
"""
|
|
1184
1195
|
Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1185
1196
|
**Example:**
|
|
@@ -1199,7 +1210,6 @@ class VannaBase(ABC):
|
|
|
1199
1210
|
try:
|
|
1200
1211
|
import oracledb
|
|
1201
1212
|
except ImportError:
|
|
1202
|
-
|
|
1203
1213
|
raise DependencyError(
|
|
1204
1214
|
"You need to install required dependencies to execute this method,"
|
|
1205
1215
|
" run command: \npip install oracledb"
|
|
@@ -1209,7 +1219,9 @@ class VannaBase(ABC):
|
|
|
1209
1219
|
dsn = os.getenv("DSN")
|
|
1210
1220
|
|
|
1211
1221
|
if not dsn:
|
|
1212
|
-
raise ImproperlyConfigured(
|
|
1222
|
+
raise ImproperlyConfigured(
|
|
1223
|
+
"Please set your Oracle dsn which should include host:port/sid"
|
|
1224
|
+
)
|
|
1213
1225
|
|
|
1214
1226
|
if not user:
|
|
1215
1227
|
user = os.getenv("USER")
|
|
@@ -1226,12 +1238,7 @@ class VannaBase(ABC):
|
|
|
1226
1238
|
conn = None
|
|
1227
1239
|
|
|
1228
1240
|
try:
|
|
1229
|
-
conn = oracledb.connect(
|
|
1230
|
-
user=user,
|
|
1231
|
-
password=password,
|
|
1232
|
-
dsn=dsn,
|
|
1233
|
-
**kwargs
|
|
1234
|
-
)
|
|
1241
|
+
conn = oracledb.connect(user=user, password=password, dsn=dsn, **kwargs)
|
|
1235
1242
|
except oracledb.Error as e:
|
|
1236
1243
|
raise ValidationError(e)
|
|
1237
1244
|
|
|
@@ -1239,7 +1246,9 @@ class VannaBase(ABC):
|
|
|
1239
1246
|
if conn:
|
|
1240
1247
|
try:
|
|
1241
1248
|
sql = sql.rstrip()
|
|
1242
|
-
if sql.endswith(
|
|
1249
|
+
if sql.endswith(
|
|
1250
|
+
";"
|
|
1251
|
+
): # fix for a known problem with Oracle db where an extra ; will cause an error.
|
|
1243
1252
|
sql = sql[:-1]
|
|
1244
1253
|
|
|
1245
1254
|
cs = conn.cursor()
|
|
@@ -1264,10 +1273,7 @@ class VannaBase(ABC):
|
|
|
1264
1273
|
self.run_sql = run_sql_oracle
|
|
1265
1274
|
|
|
1266
1275
|
def connect_to_bigquery(
|
|
1267
|
-
self,
|
|
1268
|
-
cred_file_path: str = None,
|
|
1269
|
-
project_id: str = None,
|
|
1270
|
-
**kwargs
|
|
1276
|
+
self, cred_file_path: str = None, project_id: str = None, **kwargs
|
|
1271
1277
|
):
|
|
1272
1278
|
"""
|
|
1273
1279
|
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
@@ -1316,7 +1322,7 @@ class VannaBase(ABC):
|
|
|
1316
1322
|
if not cred_file_path:
|
|
1317
1323
|
try:
|
|
1318
1324
|
conn = bigquery.Client(project=project_id)
|
|
1319
|
-
except:
|
|
1325
|
+
except Exception:
|
|
1320
1326
|
print("Could not found any google cloud implicit credentials")
|
|
1321
1327
|
else:
|
|
1322
1328
|
# Validate file path and pemissions
|
|
@@ -1331,11 +1337,9 @@ class VannaBase(ABC):
|
|
|
1331
1337
|
|
|
1332
1338
|
try:
|
|
1333
1339
|
conn = bigquery.Client(
|
|
1334
|
-
project=project_id,
|
|
1335
|
-
credentials=credentials,
|
|
1336
|
-
**kwargs
|
|
1340
|
+
project=project_id, credentials=credentials, **kwargs
|
|
1337
1341
|
)
|
|
1338
|
-
except:
|
|
1342
|
+
except Exception:
|
|
1339
1343
|
raise ImproperlyConfigured(
|
|
1340
1344
|
"Could not connect to bigquery please correct credentials"
|
|
1341
1345
|
)
|
|
@@ -1447,20 +1451,21 @@ class VannaBase(ABC):
|
|
|
1447
1451
|
self.dialect = "T-SQL / Microsoft SQL Server"
|
|
1448
1452
|
self.run_sql = run_sql_mssql
|
|
1449
1453
|
self.run_sql_is_set = True
|
|
1454
|
+
|
|
1450
1455
|
def connect_to_presto(
|
|
1451
1456
|
self,
|
|
1452
1457
|
host: str,
|
|
1453
|
-
catalog: str =
|
|
1454
|
-
schema: str =
|
|
1458
|
+
catalog: str = "hive",
|
|
1459
|
+
schema: str = "default",
|
|
1455
1460
|
user: str = None,
|
|
1456
1461
|
password: str = None,
|
|
1457
1462
|
port: int = None,
|
|
1458
1463
|
combined_pem_path: str = None,
|
|
1459
|
-
protocol: str =
|
|
1464
|
+
protocol: str = "https",
|
|
1460
1465
|
requests_kwargs: dict = None,
|
|
1461
|
-
**kwargs
|
|
1466
|
+
**kwargs,
|
|
1462
1467
|
):
|
|
1463
|
-
|
|
1468
|
+
"""
|
|
1464
1469
|
Connect to a Presto database using the specified parameters.
|
|
1465
1470
|
|
|
1466
1471
|
Args:
|
|
@@ -1480,101 +1485,103 @@ class VannaBase(ABC):
|
|
|
1480
1485
|
|
|
1481
1486
|
Returns:
|
|
1482
1487
|
None
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1488
|
+
"""
|
|
1489
|
+
try:
|
|
1490
|
+
from pyhive import presto
|
|
1491
|
+
except ImportError:
|
|
1492
|
+
raise DependencyError(
|
|
1493
|
+
"You need to install required dependencies to execute this method,"
|
|
1494
|
+
" run command: \npip install pyhive"
|
|
1495
|
+
)
|
|
1491
1496
|
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
if not host:
|
|
1496
|
-
raise ImproperlyConfigured("Please set your presto host")
|
|
1497
|
-
|
|
1498
|
-
if not catalog:
|
|
1499
|
-
catalog = os.getenv("PRESTO_CATALOG")
|
|
1500
|
-
|
|
1501
|
-
if not catalog:
|
|
1502
|
-
raise ImproperlyConfigured("Please set your presto catalog")
|
|
1503
|
-
|
|
1504
|
-
if not user:
|
|
1505
|
-
user = os.getenv("PRESTO_USER")
|
|
1506
|
-
|
|
1507
|
-
if not user:
|
|
1508
|
-
raise ImproperlyConfigured("Please set your presto user")
|
|
1509
|
-
|
|
1510
|
-
if not password:
|
|
1511
|
-
password = os.getenv("PRESTO_PASSWORD")
|
|
1512
|
-
|
|
1513
|
-
if not port:
|
|
1514
|
-
port = os.getenv("PRESTO_PORT")
|
|
1515
|
-
|
|
1516
|
-
if not port:
|
|
1517
|
-
raise ImproperlyConfigured("Please set your presto port")
|
|
1518
|
-
|
|
1519
|
-
conn = None
|
|
1520
|
-
|
|
1521
|
-
try:
|
|
1522
|
-
if requests_kwargs is None and combined_pem_path is not None:
|
|
1523
|
-
# use the combined pem file to verify the SSL connection
|
|
1524
|
-
requests_kwargs = {
|
|
1525
|
-
'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
|
|
1526
|
-
}
|
|
1527
|
-
conn = presto.Connection(host=host,
|
|
1528
|
-
username=user,
|
|
1529
|
-
password=password,
|
|
1530
|
-
catalog=catalog,
|
|
1531
|
-
schema=schema,
|
|
1532
|
-
port=port,
|
|
1533
|
-
protocol=protocol,
|
|
1534
|
-
requests_kwargs=requests_kwargs,
|
|
1535
|
-
**kwargs)
|
|
1536
|
-
except presto.Error as e:
|
|
1537
|
-
raise ValidationError(e)
|
|
1538
|
-
|
|
1539
|
-
def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
|
|
1540
|
-
if conn:
|
|
1541
|
-
try:
|
|
1542
|
-
sql = sql.rstrip()
|
|
1543
|
-
# fix for a known problem with presto db where an extra ; will cause an error.
|
|
1544
|
-
if sql.endswith(';'):
|
|
1545
|
-
sql = sql[:-1]
|
|
1546
|
-
cs = conn.cursor()
|
|
1547
|
-
cs.execute(sql)
|
|
1548
|
-
results = cs.fetchall()
|
|
1497
|
+
if not host:
|
|
1498
|
+
host = os.getenv("PRESTO_HOST")
|
|
1549
1499
|
|
|
1550
|
-
|
|
1551
|
-
|
|
1552
|
-
results, columns=[desc[0] for desc in cs.description]
|
|
1553
|
-
)
|
|
1554
|
-
return df
|
|
1500
|
+
if not host:
|
|
1501
|
+
raise ImproperlyConfigured("Please set your presto host")
|
|
1555
1502
|
|
|
1556
|
-
|
|
1557
|
-
|
|
1503
|
+
if not catalog:
|
|
1504
|
+
catalog = os.getenv("PRESTO_CATALOG")
|
|
1505
|
+
|
|
1506
|
+
if not catalog:
|
|
1507
|
+
raise ImproperlyConfigured("Please set your presto catalog")
|
|
1508
|
+
|
|
1509
|
+
if not user:
|
|
1510
|
+
user = os.getenv("PRESTO_USER")
|
|
1511
|
+
|
|
1512
|
+
if not user:
|
|
1513
|
+
raise ImproperlyConfigured("Please set your presto user")
|
|
1514
|
+
|
|
1515
|
+
if not password:
|
|
1516
|
+
password = os.getenv("PRESTO_PASSWORD")
|
|
1517
|
+
|
|
1518
|
+
if not port:
|
|
1519
|
+
port = os.getenv("PRESTO_PORT")
|
|
1520
|
+
|
|
1521
|
+
if not port:
|
|
1522
|
+
raise ImproperlyConfigured("Please set your presto port")
|
|
1523
|
+
|
|
1524
|
+
conn = None
|
|
1525
|
+
|
|
1526
|
+
try:
|
|
1527
|
+
if requests_kwargs is None and combined_pem_path is not None:
|
|
1528
|
+
# use the combined pem file to verify the SSL connection
|
|
1529
|
+
requests_kwargs = {
|
|
1530
|
+
"verify": combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
|
|
1531
|
+
}
|
|
1532
|
+
conn = presto.Connection(
|
|
1533
|
+
host=host,
|
|
1534
|
+
username=user,
|
|
1535
|
+
password=password,
|
|
1536
|
+
catalog=catalog,
|
|
1537
|
+
schema=schema,
|
|
1538
|
+
port=port,
|
|
1539
|
+
protocol=protocol,
|
|
1540
|
+
requests_kwargs=requests_kwargs,
|
|
1541
|
+
**kwargs,
|
|
1542
|
+
)
|
|
1543
|
+
except presto.Error as e:
|
|
1558
1544
|
raise ValidationError(e)
|
|
1559
1545
|
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1546
|
+
def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
|
|
1547
|
+
if conn:
|
|
1548
|
+
try:
|
|
1549
|
+
sql = sql.rstrip()
|
|
1550
|
+
# fix for a known problem with presto db where an extra ; will cause an error.
|
|
1551
|
+
if sql.endswith(";"):
|
|
1552
|
+
sql = sql[:-1]
|
|
1553
|
+
cs = conn.cursor()
|
|
1554
|
+
cs.execute(sql)
|
|
1555
|
+
results = cs.fetchall()
|
|
1563
1556
|
|
|
1564
|
-
|
|
1565
|
-
|
|
1557
|
+
# Create a pandas dataframe from the results
|
|
1558
|
+
df = pd.DataFrame(
|
|
1559
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
1560
|
+
)
|
|
1561
|
+
return df
|
|
1562
|
+
|
|
1563
|
+
except presto.Error as e:
|
|
1564
|
+
print(e)
|
|
1565
|
+
raise ValidationError(e)
|
|
1566
|
+
|
|
1567
|
+
except Exception as e:
|
|
1568
|
+
print(e)
|
|
1569
|
+
raise e
|
|
1570
|
+
|
|
1571
|
+
self.run_sql_is_set = True
|
|
1572
|
+
self.run_sql = run_sql_presto
|
|
1566
1573
|
|
|
1567
1574
|
def connect_to_hive(
|
|
1568
1575
|
self,
|
|
1569
1576
|
host: str = None,
|
|
1570
|
-
dbname: str =
|
|
1577
|
+
dbname: str = "default",
|
|
1571
1578
|
user: str = None,
|
|
1572
1579
|
password: str = None,
|
|
1573
1580
|
port: int = None,
|
|
1574
|
-
auth: str =
|
|
1575
|
-
**kwargs
|
|
1581
|
+
auth: str = "CUSTOM",
|
|
1582
|
+
**kwargs,
|
|
1576
1583
|
):
|
|
1577
|
-
|
|
1584
|
+
"""
|
|
1578
1585
|
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1579
1586
|
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1580
1587
|
|
|
@@ -1588,78 +1595,80 @@ class VannaBase(ABC):
|
|
|
1588
1595
|
|
|
1589
1596
|
Returns:
|
|
1590
1597
|
None
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
try:
|
|
1594
|
-
from pyhive import hive
|
|
1595
|
-
except ImportError:
|
|
1596
|
-
raise DependencyError(
|
|
1597
|
-
"You need to install required dependencies to execute this method,"
|
|
1598
|
-
" run command: \npip install pyhive"
|
|
1599
|
-
)
|
|
1600
|
-
|
|
1601
|
-
if not host:
|
|
1602
|
-
host = os.getenv("HIVE_HOST")
|
|
1598
|
+
"""
|
|
1603
1599
|
|
|
1604
|
-
|
|
1605
|
-
|
|
1600
|
+
try:
|
|
1601
|
+
from pyhive import hive
|
|
1602
|
+
except ImportError:
|
|
1603
|
+
raise DependencyError(
|
|
1604
|
+
"You need to install required dependencies to execute this method,"
|
|
1605
|
+
" run command: \npip install pyhive"
|
|
1606
|
+
)
|
|
1606
1607
|
|
|
1607
|
-
|
|
1608
|
-
|
|
1608
|
+
if not host:
|
|
1609
|
+
host = os.getenv("HIVE_HOST")
|
|
1609
1610
|
|
|
1610
|
-
|
|
1611
|
-
|
|
1611
|
+
if not host:
|
|
1612
|
+
raise ImproperlyConfigured("Please set your hive host")
|
|
1612
1613
|
|
|
1613
|
-
|
|
1614
|
-
|
|
1614
|
+
if not dbname:
|
|
1615
|
+
dbname = os.getenv("HIVE_DATABASE")
|
|
1615
1616
|
|
|
1616
|
-
|
|
1617
|
-
|
|
1617
|
+
if not dbname:
|
|
1618
|
+
raise ImproperlyConfigured("Please set your hive database")
|
|
1618
1619
|
|
|
1619
|
-
|
|
1620
|
-
|
|
1620
|
+
if not user:
|
|
1621
|
+
user = os.getenv("HIVE_USER")
|
|
1621
1622
|
|
|
1622
|
-
|
|
1623
|
-
|
|
1623
|
+
if not user:
|
|
1624
|
+
raise ImproperlyConfigured("Please set your hive user")
|
|
1624
1625
|
|
|
1625
|
-
|
|
1626
|
-
|
|
1626
|
+
if not password:
|
|
1627
|
+
password = os.getenv("HIVE_PASSWORD")
|
|
1627
1628
|
|
|
1628
|
-
|
|
1629
|
+
if not port:
|
|
1630
|
+
port = os.getenv("HIVE_PORT")
|
|
1629
1631
|
|
|
1630
|
-
|
|
1631
|
-
|
|
1632
|
-
username=user,
|
|
1633
|
-
password=password,
|
|
1634
|
-
database=dbname,
|
|
1635
|
-
port=port,
|
|
1636
|
-
auth=auth)
|
|
1637
|
-
except hive.Error as e:
|
|
1638
|
-
raise ValidationError(e)
|
|
1632
|
+
if not port:
|
|
1633
|
+
raise ImproperlyConfigured("Please set your hive port")
|
|
1639
1634
|
|
|
1640
|
-
|
|
1641
|
-
if conn:
|
|
1642
|
-
try:
|
|
1643
|
-
cs = conn.cursor()
|
|
1644
|
-
cs.execute(sql)
|
|
1645
|
-
results = cs.fetchall()
|
|
1635
|
+
conn = None
|
|
1646
1636
|
|
|
1647
|
-
|
|
1648
|
-
|
|
1649
|
-
|
|
1637
|
+
try:
|
|
1638
|
+
conn = hive.Connection(
|
|
1639
|
+
host=host,
|
|
1640
|
+
username=user,
|
|
1641
|
+
password=password,
|
|
1642
|
+
database=dbname,
|
|
1643
|
+
port=port,
|
|
1644
|
+
auth=auth,
|
|
1650
1645
|
)
|
|
1651
|
-
|
|
1652
|
-
|
|
1653
|
-
except hive.Error as e:
|
|
1654
|
-
print(e)
|
|
1646
|
+
except hive.Error as e:
|
|
1655
1647
|
raise ValidationError(e)
|
|
1656
1648
|
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
|
|
1649
|
+
def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
|
|
1650
|
+
if conn:
|
|
1651
|
+
try:
|
|
1652
|
+
cs = conn.cursor()
|
|
1653
|
+
cs.execute(sql)
|
|
1654
|
+
results = cs.fetchall()
|
|
1655
|
+
|
|
1656
|
+
# Create a pandas dataframe from the results
|
|
1657
|
+
df = pd.DataFrame(
|
|
1658
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
1659
|
+
)
|
|
1660
|
+
return df
|
|
1660
1661
|
|
|
1661
|
-
|
|
1662
|
-
|
|
1662
|
+
except hive.Error as e:
|
|
1663
|
+
print(e)
|
|
1664
|
+
raise ValidationError(e)
|
|
1665
|
+
|
|
1666
|
+
except Exception as e:
|
|
1667
|
+
print(e)
|
|
1668
|
+
raise e
|
|
1669
|
+
|
|
1670
|
+
self.run_sql_is_set = True
|
|
1671
|
+
self.run_sql = run_sql_hive
|
|
1663
1672
|
|
|
1664
1673
|
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
|
|
1665
1674
|
"""
|
|
@@ -1717,22 +1726,23 @@ class VannaBase(ABC):
|
|
|
1717
1726
|
question = input("Enter a question: ")
|
|
1718
1727
|
|
|
1719
1728
|
try:
|
|
1720
|
-
sql = self.generate_sql(
|
|
1729
|
+
sql = self.generate_sql(
|
|
1730
|
+
question=question, allow_llm_to_see_data=allow_llm_to_see_data
|
|
1731
|
+
)
|
|
1721
1732
|
except Exception as e:
|
|
1722
1733
|
print(e)
|
|
1723
1734
|
return None, None, None
|
|
1724
1735
|
|
|
1725
1736
|
if print_results:
|
|
1726
1737
|
try:
|
|
1727
|
-
|
|
1738
|
+
from IPython.display import Code, display
|
|
1739
|
+
|
|
1728
1740
|
display(Code(sql))
|
|
1729
1741
|
except Exception as e:
|
|
1730
1742
|
print(sql)
|
|
1731
1743
|
|
|
1732
1744
|
if self.run_sql_is_set is False:
|
|
1733
|
-
print(
|
|
1734
|
-
"If you want to run the SQL query, connect to a database first."
|
|
1735
|
-
)
|
|
1745
|
+
print("If you want to run the SQL query, connect to a database first.")
|
|
1736
1746
|
|
|
1737
1747
|
if print_results:
|
|
1738
1748
|
return None
|
|
@@ -1776,6 +1786,7 @@ class VannaBase(ABC):
|
|
|
1776
1786
|
fig.show()
|
|
1777
1787
|
except Exception as e:
|
|
1778
1788
|
# Print stack trace
|
|
1789
|
+
traceback.print_stack()
|
|
1779
1790
|
traceback.print_exc()
|
|
1780
1791
|
print("Couldn't run plotly code: ", e)
|
|
1781
1792
|
if print_results:
|
|
@@ -1891,12 +1902,8 @@ class VannaBase(ABC):
|
|
|
1891
1902
|
table_column = df.columns[
|
|
1892
1903
|
df.columns.str.lower().str.contains("table_name")
|
|
1893
1904
|
].to_list()[0]
|
|
1894
|
-
columns = [database_column,
|
|
1895
|
-
|
|
1896
|
-
table_column]
|
|
1897
|
-
candidates = ["column_name",
|
|
1898
|
-
"data_type",
|
|
1899
|
-
"comment"]
|
|
1905
|
+
columns = [database_column, schema_column, table_column]
|
|
1906
|
+
candidates = ["column_name", "data_type", "comment"]
|
|
1900
1907
|
matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True)
|
|
1901
1908
|
columns += df.columns[matches].to_list()
|
|
1902
1909
|
|