vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +461 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +247 -223
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0.dist-info/METADATA +485 -0
  251. vanna-2.0.0.dist-info/RECORD +289 -0
  252. vanna-2.0.0.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.8.dist-info/METADATA +0 -408
  265. vanna-0.7.8.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -136,7 +136,7 @@ class VannaBase(ABC):
136
136
  llm_response = self.submit_prompt(prompt, **kwargs)
137
137
  self.log(title="LLM Response", message=llm_response)
138
138
 
139
- if 'intermediate_sql' in llm_response:
139
+ if "intermediate_sql" in llm_response:
140
140
  if not allow_llm_to_see_data:
141
141
  return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
142
142
 
@@ -152,7 +152,11 @@ class VannaBase(ABC):
152
152
  question=question,
153
153
  question_sql_list=question_sql_list,
154
154
  ddl_list=ddl_list,
155
- doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
155
+ doc_list=doc_list
156
+ + [
157
+ f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n"
158
+ + df.to_markdown()
159
+ ],
156
160
  **kwargs,
157
161
  )
158
162
  self.log(title="Final SQL Prompt", message=prompt)
@@ -161,7 +165,6 @@ class VannaBase(ABC):
161
165
  except Exception as e:
162
166
  return f"Error running intermediate SQL: {e}"
163
167
 
164
-
165
168
  return self.extract_sql(llm_response)
166
169
 
167
170
  def extract_sql(self, llm_response: str) -> str:
@@ -181,30 +184,52 @@ class VannaBase(ABC):
181
184
  str: The extracted SQL query.
182
185
  """
183
186
 
184
- # If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
185
- sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
187
+ import re
188
+
189
+ """
190
+ Extracts the SQL query from the LLM response, handling various formats including:
191
+ - WITH clause
192
+ - SELECT statement
193
+ - CREATE TABLE AS SELECT
194
+ - Markdown code blocks
195
+ """
196
+
197
+ # Match CREATE TABLE ... AS SELECT
198
+ sqls = re.findall(
199
+ r"\bCREATE\s+TABLE\b.*?\bAS\b.*?;", llm_response, re.DOTALL | re.IGNORECASE
200
+ )
186
201
  if sqls:
187
202
  sql = sqls[-1]
188
203
  self.log(title="Extracted SQL", message=f"{sql}")
189
204
  return sql
190
205
 
191
- # If the llm_response is not markdown formatted, extract last sql by finding select and ; in the response
192
- sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
206
+ # Match WITH clause (CTEs)
207
+ sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
193
208
  if sqls:
194
209
  sql = sqls[-1]
195
210
  self.log(title="Extracted SQL", message=f"{sql}")
196
211
  return sql
197
212
 
198
- # If the llm_response contains a markdown code block, with or without the sql tag, extract the last sql from it
199
- sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
213
+ # Match SELECT ... ;
214
+ sqls = re.findall(r"\bSELECT\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
200
215
  if sqls:
201
216
  sql = sqls[-1]
202
217
  self.log(title="Extracted SQL", message=f"{sql}")
203
218
  return sql
204
219
 
205
- sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
220
+ # Match ```sql ... ``` blocks
221
+ sqls = re.findall(
222
+ r"```sql\s*\n(.*?)```", llm_response, re.DOTALL | re.IGNORECASE
223
+ )
206
224
  if sqls:
207
- sql = sqls[-1]
225
+ sql = sqls[-1].strip()
226
+ self.log(title="Extracted SQL", message=f"{sql}")
227
+ return sql
228
+
229
+ # Match any ``` ... ``` code blocks
230
+ sqls = re.findall(r"```(.*?)```", llm_response, re.DOTALL | re.IGNORECASE)
231
+ if sqls:
232
+ sql = sqls[-1].strip()
208
233
  self.log(title="Extracted SQL", message=f"{sql}")
209
234
  return sql
210
235
 
@@ -229,7 +254,7 @@ class VannaBase(ABC):
229
254
  parsed = sqlparse.parse(sql)
230
255
 
231
256
  for statement in parsed:
232
- if statement.get_type() == 'SELECT':
257
+ if statement.get_type() == "SELECT":
233
258
  return True
234
259
 
235
260
  return False
@@ -251,12 +276,14 @@ class VannaBase(ABC):
251
276
  bool: True if a chart should be generated, False otherwise.
252
277
  """
253
278
 
254
- if len(df) > 1 and df.select_dtypes(include=['number']).shape[1] > 0:
279
+ if len(df) > 1 and df.select_dtypes(include=["number"]).shape[1] > 0:
255
280
  return True
256
281
 
257
282
  return False
258
283
 
259
- def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
284
+ def generate_rewritten_question(
285
+ self, last_question: str, new_question: str, **kwargs
286
+ ) -> str:
260
287
  """
261
288
  **Example:**
262
289
  ```python
@@ -277,8 +304,15 @@ class VannaBase(ABC):
277
304
  return new_question
278
305
 
279
306
  prompt = [
280
- self.system_message("Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
281
- self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
307
+ self.system_message(
308
+ "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."
309
+ ),
310
+ self.user_message(
311
+ "First question: "
312
+ + last_question
313
+ + "\nSecond question: "
314
+ + new_question
315
+ ),
282
316
  ]
283
317
 
284
318
  return self.submit_prompt(prompt=prompt, **kwargs)
@@ -309,8 +343,8 @@ class VannaBase(ABC):
309
343
  f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.head(25).to_markdown()}\n\n"
310
344
  ),
311
345
  self.user_message(
312
- f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
313
- self._response_language()
346
+ f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
347
+ + self._response_language()
314
348
  ),
315
349
  ]
316
350
 
@@ -354,8 +388,8 @@ class VannaBase(ABC):
354
388
  f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
355
389
  ),
356
390
  self.user_message(
357
- "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." +
358
- self._response_language()
391
+ "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
392
+ + self._response_language()
359
393
  ),
360
394
  ]
361
395
 
@@ -551,7 +585,7 @@ class VannaBase(ABC):
551
585
 
552
586
  def get_sql_prompt(
553
587
  self,
554
- initial_prompt : str,
588
+ initial_prompt: str,
555
589
  question: str,
556
590
  question_sql_list: list,
557
591
  ddl_list: list,
@@ -583,8 +617,10 @@ class VannaBase(ABC):
583
617
  """
584
618
 
585
619
  if initial_prompt is None:
586
- initial_prompt = f"You are a {self.dialect} expert. " + \
587
- "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
620
+ initial_prompt = (
621
+ f"You are a {self.dialect} expert. "
622
+ + "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
623
+ )
588
624
 
589
625
  initial_prompt = self.add_ddl_to_prompt(
590
626
  initial_prompt, ddl_list, max_tokens=self.max_tokens
@@ -749,7 +785,7 @@ class VannaBase(ABC):
749
785
  database: str,
750
786
  role: Union[str, None] = None,
751
787
  warehouse: Union[str, None] = None,
752
- **kwargs
788
+ **kwargs,
753
789
  ):
754
790
  try:
755
791
  snowflake = __import__("snowflake.connector")
@@ -797,7 +833,7 @@ class VannaBase(ABC):
797
833
  account=account,
798
834
  database=database,
799
835
  client_session_keep_alive=True,
800
- **kwargs
836
+ **kwargs,
801
837
  )
802
838
 
803
839
  def run_sql_snowflake(sql: str) -> pd.DataFrame:
@@ -823,7 +859,7 @@ class VannaBase(ABC):
823
859
  self.run_sql = run_sql_snowflake
824
860
  self.run_sql_is_set = True
825
861
 
826
- def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
862
+ def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
827
863
  """
828
864
  Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
829
865
 
@@ -848,11 +884,7 @@ class VannaBase(ABC):
848
884
  url = path
849
885
 
850
886
  # Connect to the database
851
- conn = sqlite3.connect(
852
- url,
853
- check_same_thread=check_same_thread,
854
- **kwargs
855
- )
887
+ conn = sqlite3.connect(url, check_same_thread=check_same_thread, **kwargs)
856
888
 
857
889
  def run_sql_sqlite(sql: str):
858
890
  return pd.read_sql_query(sql, conn)
@@ -868,9 +900,8 @@ class VannaBase(ABC):
868
900
  user: str = None,
869
901
  password: str = None,
870
902
  port: int = None,
871
- **kwargs
903
+ **kwargs,
872
904
  ):
873
-
874
905
  """
875
906
  Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
876
907
  **Example:**
@@ -939,15 +970,20 @@ class VannaBase(ABC):
939
970
  user=user,
940
971
  password=password,
941
972
  port=port,
942
- **kwargs
973
+ **kwargs,
943
974
  )
944
975
  except psycopg2.Error as e:
945
976
  raise ValidationError(e)
946
977
 
947
978
  def connect_to_db():
948
- return psycopg2.connect(host=host, dbname=dbname,
949
- user=user, password=password, port=port, **kwargs)
950
-
979
+ return psycopg2.connect(
980
+ host=host,
981
+ dbname=dbname,
982
+ user=user,
983
+ password=password,
984
+ port=port,
985
+ **kwargs,
986
+ )
951
987
 
952
988
  def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
953
989
  conn = None
@@ -980,14 +1016,13 @@ class VannaBase(ABC):
980
1016
  raise ValidationError(e)
981
1017
 
982
1018
  except Exception as e:
983
- conn.rollback()
984
- raise e
1019
+ conn.rollback()
1020
+ raise e
985
1021
 
986
1022
  self.dialect = "PostgreSQL"
987
1023
  self.run_sql_is_set = True
988
1024
  self.run_sql = run_sql_postgres
989
1025
 
990
-
991
1026
  def connect_to_mysql(
992
1027
  self,
993
1028
  host: str = None,
@@ -995,9 +1030,8 @@ class VannaBase(ABC):
995
1030
  user: str = None,
996
1031
  password: str = None,
997
1032
  port: int = None,
998
- **kwargs
1033
+ **kwargs,
999
1034
  ):
1000
-
1001
1035
  try:
1002
1036
  import pymysql.cursors
1003
1037
  except ImportError:
@@ -1046,7 +1080,7 @@ class VannaBase(ABC):
1046
1080
  database=dbname,
1047
1081
  port=port,
1048
1082
  cursorclass=pymysql.cursors.DictCursor,
1049
- **kwargs
1083
+ **kwargs,
1050
1084
  )
1051
1085
  except pymysql.Error as e:
1052
1086
  raise ValidationError(e)
@@ -1083,9 +1117,8 @@ class VannaBase(ABC):
1083
1117
  user: str = None,
1084
1118
  password: str = None,
1085
1119
  port: int = None,
1086
- **kwargs
1120
+ **kwargs,
1087
1121
  ):
1088
-
1089
1122
  try:
1090
1123
  import clickhouse_connect
1091
1124
  except ImportError:
@@ -1133,7 +1166,7 @@ class VannaBase(ABC):
1133
1166
  username=user,
1134
1167
  password=password,
1135
1168
  database=dbname,
1136
- **kwargs
1169
+ **kwargs,
1137
1170
  )
1138
1171
  print(conn)
1139
1172
  except Exception as e:
@@ -1156,13 +1189,8 @@ class VannaBase(ABC):
1156
1189
  self.run_sql = run_sql_clickhouse
1157
1190
 
1158
1191
  def connect_to_oracle(
1159
- self,
1160
- user: str = None,
1161
- password: str = None,
1162
- dsn: str = None,
1163
- **kwargs
1192
+ self, user: str = None, password: str = None, dsn: str = None, **kwargs
1164
1193
  ):
1165
-
1166
1194
  """
1167
1195
  Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1168
1196
  **Example:**
@@ -1182,7 +1210,6 @@ class VannaBase(ABC):
1182
1210
  try:
1183
1211
  import oracledb
1184
1212
  except ImportError:
1185
-
1186
1213
  raise DependencyError(
1187
1214
  "You need to install required dependencies to execute this method,"
1188
1215
  " run command: \npip install oracledb"
@@ -1192,7 +1219,9 @@ class VannaBase(ABC):
1192
1219
  dsn = os.getenv("DSN")
1193
1220
 
1194
1221
  if not dsn:
1195
- raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
1222
+ raise ImproperlyConfigured(
1223
+ "Please set your Oracle dsn which should include host:port/sid"
1224
+ )
1196
1225
 
1197
1226
  if not user:
1198
1227
  user = os.getenv("USER")
@@ -1209,12 +1238,7 @@ class VannaBase(ABC):
1209
1238
  conn = None
1210
1239
 
1211
1240
  try:
1212
- conn = oracledb.connect(
1213
- user=user,
1214
- password=password,
1215
- dsn=dsn,
1216
- **kwargs
1217
- )
1241
+ conn = oracledb.connect(user=user, password=password, dsn=dsn, **kwargs)
1218
1242
  except oracledb.Error as e:
1219
1243
  raise ValidationError(e)
1220
1244
 
@@ -1222,7 +1246,9 @@ class VannaBase(ABC):
1222
1246
  if conn:
1223
1247
  try:
1224
1248
  sql = sql.rstrip()
1225
- if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
1249
+ if sql.endswith(
1250
+ ";"
1251
+ ): # fix for a known problem with Oracle db where an extra ; will cause an error.
1226
1252
  sql = sql[:-1]
1227
1253
 
1228
1254
  cs = conn.cursor()
@@ -1247,10 +1273,7 @@ class VannaBase(ABC):
1247
1273
  self.run_sql = run_sql_oracle
1248
1274
 
1249
1275
  def connect_to_bigquery(
1250
- self,
1251
- cred_file_path: str = None,
1252
- project_id: str = None,
1253
- **kwargs
1276
+ self, cred_file_path: str = None, project_id: str = None, **kwargs
1254
1277
  ):
1255
1278
  """
1256
1279
  Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
@@ -1299,7 +1322,7 @@ class VannaBase(ABC):
1299
1322
  if not cred_file_path:
1300
1323
  try:
1301
1324
  conn = bigquery.Client(project=project_id)
1302
- except:
1325
+ except Exception:
1303
1326
  print("Could not found any google cloud implicit credentials")
1304
1327
  else:
1305
1328
  # Validate file path and pemissions
@@ -1314,11 +1337,9 @@ class VannaBase(ABC):
1314
1337
 
1315
1338
  try:
1316
1339
  conn = bigquery.Client(
1317
- project=project_id,
1318
- credentials=credentials,
1319
- **kwargs
1340
+ project=project_id, credentials=credentials, **kwargs
1320
1341
  )
1321
- except:
1342
+ except Exception:
1322
1343
  raise ImproperlyConfigured(
1323
1344
  "Could not connect to bigquery please correct credentials"
1324
1345
  )
@@ -1430,20 +1451,21 @@ class VannaBase(ABC):
1430
1451
  self.dialect = "T-SQL / Microsoft SQL Server"
1431
1452
  self.run_sql = run_sql_mssql
1432
1453
  self.run_sql_is_set = True
1454
+
1433
1455
  def connect_to_presto(
1434
1456
  self,
1435
1457
  host: str,
1436
- catalog: str = 'hive',
1437
- schema: str = 'default',
1458
+ catalog: str = "hive",
1459
+ schema: str = "default",
1438
1460
  user: str = None,
1439
1461
  password: str = None,
1440
1462
  port: int = None,
1441
1463
  combined_pem_path: str = None,
1442
- protocol: str = 'https',
1464
+ protocol: str = "https",
1443
1465
  requests_kwargs: dict = None,
1444
- **kwargs
1466
+ **kwargs,
1445
1467
  ):
1446
- """
1468
+ """
1447
1469
  Connect to a Presto database using the specified parameters.
1448
1470
 
1449
1471
  Args:
@@ -1463,101 +1485,103 @@ class VannaBase(ABC):
1463
1485
 
1464
1486
  Returns:
1465
1487
  None
1466
- """
1467
- try:
1468
- from pyhive import presto
1469
- except ImportError:
1470
- raise DependencyError(
1471
- "You need to install required dependencies to execute this method,"
1472
- " run command: \npip install pyhive"
1473
- )
1488
+ """
1489
+ try:
1490
+ from pyhive import presto
1491
+ except ImportError:
1492
+ raise DependencyError(
1493
+ "You need to install required dependencies to execute this method,"
1494
+ " run command: \npip install pyhive"
1495
+ )
1474
1496
 
1475
- if not host:
1476
- host = os.getenv("PRESTO_HOST")
1477
-
1478
- if not host:
1479
- raise ImproperlyConfigured("Please set your presto host")
1480
-
1481
- if not catalog:
1482
- catalog = os.getenv("PRESTO_CATALOG")
1483
-
1484
- if not catalog:
1485
- raise ImproperlyConfigured("Please set your presto catalog")
1486
-
1487
- if not user:
1488
- user = os.getenv("PRESTO_USER")
1489
-
1490
- if not user:
1491
- raise ImproperlyConfigured("Please set your presto user")
1492
-
1493
- if not password:
1494
- password = os.getenv("PRESTO_PASSWORD")
1495
-
1496
- if not port:
1497
- port = os.getenv("PRESTO_PORT")
1498
-
1499
- if not port:
1500
- raise ImproperlyConfigured("Please set your presto port")
1501
-
1502
- conn = None
1503
-
1504
- try:
1505
- if requests_kwargs is None and combined_pem_path is not None:
1506
- # use the combined pem file to verify the SSL connection
1507
- requests_kwargs = {
1508
- 'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
1509
- }
1510
- conn = presto.Connection(host=host,
1511
- username=user,
1512
- password=password,
1513
- catalog=catalog,
1514
- schema=schema,
1515
- port=port,
1516
- protocol=protocol,
1517
- requests_kwargs=requests_kwargs,
1518
- **kwargs)
1519
- except presto.Error as e:
1520
- raise ValidationError(e)
1521
-
1522
- def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
1523
- if conn:
1524
- try:
1525
- sql = sql.rstrip()
1526
- # fix for a known problem with presto db where an extra ; will cause an error.
1527
- if sql.endswith(';'):
1528
- sql = sql[:-1]
1529
- cs = conn.cursor()
1530
- cs.execute(sql)
1531
- results = cs.fetchall()
1497
+ if not host:
1498
+ host = os.getenv("PRESTO_HOST")
1532
1499
 
1533
- # Create a pandas dataframe from the results
1534
- df = pd.DataFrame(
1535
- results, columns=[desc[0] for desc in cs.description]
1536
- )
1537
- return df
1500
+ if not host:
1501
+ raise ImproperlyConfigured("Please set your presto host")
1538
1502
 
1539
- except presto.Error as e:
1540
- print(e)
1503
+ if not catalog:
1504
+ catalog = os.getenv("PRESTO_CATALOG")
1505
+
1506
+ if not catalog:
1507
+ raise ImproperlyConfigured("Please set your presto catalog")
1508
+
1509
+ if not user:
1510
+ user = os.getenv("PRESTO_USER")
1511
+
1512
+ if not user:
1513
+ raise ImproperlyConfigured("Please set your presto user")
1514
+
1515
+ if not password:
1516
+ password = os.getenv("PRESTO_PASSWORD")
1517
+
1518
+ if not port:
1519
+ port = os.getenv("PRESTO_PORT")
1520
+
1521
+ if not port:
1522
+ raise ImproperlyConfigured("Please set your presto port")
1523
+
1524
+ conn = None
1525
+
1526
+ try:
1527
+ if requests_kwargs is None and combined_pem_path is not None:
1528
+ # use the combined pem file to verify the SSL connection
1529
+ requests_kwargs = {
1530
+ "verify": combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
1531
+ }
1532
+ conn = presto.Connection(
1533
+ host=host,
1534
+ username=user,
1535
+ password=password,
1536
+ catalog=catalog,
1537
+ schema=schema,
1538
+ port=port,
1539
+ protocol=protocol,
1540
+ requests_kwargs=requests_kwargs,
1541
+ **kwargs,
1542
+ )
1543
+ except presto.Error as e:
1541
1544
  raise ValidationError(e)
1542
1545
 
1543
- except Exception as e:
1544
- print(e)
1545
- raise e
1546
+ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
1547
+ if conn:
1548
+ try:
1549
+ sql = sql.rstrip()
1550
+ # fix for a known problem with presto db where an extra ; will cause an error.
1551
+ if sql.endswith(";"):
1552
+ sql = sql[:-1]
1553
+ cs = conn.cursor()
1554
+ cs.execute(sql)
1555
+ results = cs.fetchall()
1556
+
1557
+ # Create a pandas dataframe from the results
1558
+ df = pd.DataFrame(
1559
+ results, columns=[desc[0] for desc in cs.description]
1560
+ )
1561
+ return df
1546
1562
 
1547
- self.run_sql_is_set = True
1548
- self.run_sql = run_sql_presto
1563
+ except presto.Error as e:
1564
+ print(e)
1565
+ raise ValidationError(e)
1566
+
1567
+ except Exception as e:
1568
+ print(e)
1569
+ raise e
1570
+
1571
+ self.run_sql_is_set = True
1572
+ self.run_sql = run_sql_presto
1549
1573
 
1550
1574
  def connect_to_hive(
1551
1575
  self,
1552
1576
  host: str = None,
1553
- dbname: str = 'default',
1577
+ dbname: str = "default",
1554
1578
  user: str = None,
1555
1579
  password: str = None,
1556
1580
  port: int = None,
1557
- auth: str = 'CUSTOM',
1558
- **kwargs
1581
+ auth: str = "CUSTOM",
1582
+ **kwargs,
1559
1583
  ):
1560
- """
1584
+ """
1561
1585
  Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1562
1586
  Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
1563
1587
 
@@ -1571,78 +1595,80 @@ class VannaBase(ABC):
1571
1595
 
1572
1596
  Returns:
1573
1597
  None
1574
- """
1575
-
1576
- try:
1577
- from pyhive import hive
1578
- except ImportError:
1579
- raise DependencyError(
1580
- "You need to install required dependencies to execute this method,"
1581
- " run command: \npip install pyhive"
1582
- )
1583
-
1584
- if not host:
1585
- host = os.getenv("HIVE_HOST")
1598
+ """
1586
1599
 
1587
- if not host:
1588
- raise ImproperlyConfigured("Please set your hive host")
1600
+ try:
1601
+ from pyhive import hive
1602
+ except ImportError:
1603
+ raise DependencyError(
1604
+ "You need to install required dependencies to execute this method,"
1605
+ " run command: \npip install pyhive"
1606
+ )
1589
1607
 
1590
- if not dbname:
1591
- dbname = os.getenv("HIVE_DATABASE")
1608
+ if not host:
1609
+ host = os.getenv("HIVE_HOST")
1592
1610
 
1593
- if not dbname:
1594
- raise ImproperlyConfigured("Please set your hive database")
1611
+ if not host:
1612
+ raise ImproperlyConfigured("Please set your hive host")
1595
1613
 
1596
- if not user:
1597
- user = os.getenv("HIVE_USER")
1614
+ if not dbname:
1615
+ dbname = os.getenv("HIVE_DATABASE")
1598
1616
 
1599
- if not user:
1600
- raise ImproperlyConfigured("Please set your hive user")
1617
+ if not dbname:
1618
+ raise ImproperlyConfigured("Please set your hive database")
1601
1619
 
1602
- if not password:
1603
- password = os.getenv("HIVE_PASSWORD")
1620
+ if not user:
1621
+ user = os.getenv("HIVE_USER")
1604
1622
 
1605
- if not port:
1606
- port = os.getenv("HIVE_PORT")
1623
+ if not user:
1624
+ raise ImproperlyConfigured("Please set your hive user")
1607
1625
 
1608
- if not port:
1609
- raise ImproperlyConfigured("Please set your hive port")
1626
+ if not password:
1627
+ password = os.getenv("HIVE_PASSWORD")
1610
1628
 
1611
- conn = None
1629
+ if not port:
1630
+ port = os.getenv("HIVE_PORT")
1612
1631
 
1613
- try:
1614
- conn = hive.Connection(host=host,
1615
- username=user,
1616
- password=password,
1617
- database=dbname,
1618
- port=port,
1619
- auth=auth)
1620
- except hive.Error as e:
1621
- raise ValidationError(e)
1632
+ if not port:
1633
+ raise ImproperlyConfigured("Please set your hive port")
1622
1634
 
1623
- def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
1624
- if conn:
1625
- try:
1626
- cs = conn.cursor()
1627
- cs.execute(sql)
1628
- results = cs.fetchall()
1635
+ conn = None
1629
1636
 
1630
- # Create a pandas dataframe from the results
1631
- df = pd.DataFrame(
1632
- results, columns=[desc[0] for desc in cs.description]
1637
+ try:
1638
+ conn = hive.Connection(
1639
+ host=host,
1640
+ username=user,
1641
+ password=password,
1642
+ database=dbname,
1643
+ port=port,
1644
+ auth=auth,
1633
1645
  )
1634
- return df
1635
-
1636
- except hive.Error as e:
1637
- print(e)
1646
+ except hive.Error as e:
1638
1647
  raise ValidationError(e)
1639
1648
 
1640
- except Exception as e:
1641
- print(e)
1642
- raise e
1649
+ def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]:
1650
+ if conn:
1651
+ try:
1652
+ cs = conn.cursor()
1653
+ cs.execute(sql)
1654
+ results = cs.fetchall()
1643
1655
 
1644
- self.run_sql_is_set = True
1645
- self.run_sql = run_sql_hive
1656
+ # Create a pandas dataframe from the results
1657
+ df = pd.DataFrame(
1658
+ results, columns=[desc[0] for desc in cs.description]
1659
+ )
1660
+ return df
1661
+
1662
+ except hive.Error as e:
1663
+ print(e)
1664
+ raise ValidationError(e)
1665
+
1666
+ except Exception as e:
1667
+ print(e)
1668
+ raise e
1669
+
1670
+ self.run_sql_is_set = True
1671
+ self.run_sql = run_sql_hive
1646
1672
 
1647
1673
  def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
1648
1674
  """
@@ -1700,22 +1726,23 @@ class VannaBase(ABC):
1700
1726
  question = input("Enter a question: ")
1701
1727
 
1702
1728
  try:
1703
- sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
1729
+ sql = self.generate_sql(
1730
+ question=question, allow_llm_to_see_data=allow_llm_to_see_data
1731
+ )
1704
1732
  except Exception as e:
1705
1733
  print(e)
1706
1734
  return None, None, None
1707
1735
 
1708
1736
  if print_results:
1709
1737
  try:
1710
- Code = __import__("IPython.display", fromList=["Code"]).Code
1738
+ from IPython.display import Code, display
1739
+
1711
1740
  display(Code(sql))
1712
1741
  except Exception as e:
1713
1742
  print(sql)
1714
1743
 
1715
1744
  if self.run_sql_is_set is False:
1716
- print(
1717
- "If you want to run the SQL query, connect to a database first."
1718
- )
1745
+ print("If you want to run the SQL query, connect to a database first.")
1719
1746
 
1720
1747
  if print_results:
1721
1748
  return None
@@ -1759,6 +1786,7 @@ class VannaBase(ABC):
1759
1786
  fig.show()
1760
1787
  except Exception as e:
1761
1788
  # Print stack trace
1789
+ traceback.print_stack()
1762
1790
  traceback.print_exc()
1763
1791
  print("Couldn't run plotly code: ", e)
1764
1792
  if print_results:
@@ -1874,12 +1902,8 @@ class VannaBase(ABC):
1874
1902
  table_column = df.columns[
1875
1903
  df.columns.str.lower().str.contains("table_name")
1876
1904
  ].to_list()[0]
1877
- columns = [database_column,
1878
- schema_column,
1879
- table_column]
1880
- candidates = ["column_name",
1881
- "data_type",
1882
- "comment"]
1905
+ columns = [database_column, schema_column, table_column]
1906
+ candidates = ["column_name", "data_type", "comment"]
1883
1907
  matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True)
1884
1908
  columns += df.columns[matches].to_list()
1885
1909