vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (302)
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +461 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +247 -223
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0.dist-info/METADATA +485 -0
  251. vanna-2.0.0.dist-info/RECORD +289 -0
  252. vanna-2.0.0.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.8.dist-info/METADATA +0 -408
  265. vanna-0.7.8.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,10 +2,7 @@ import datetime
2
2
  import os
3
3
  import uuid
4
4
  from typing import List, Optional
5
- from vertexai.language_models import (
6
- TextEmbeddingInput,
7
- TextEmbeddingModel
8
- )
5
+ from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
9
6
 
10
7
  import pandas as pd
11
8
  from google.cloud import bigquery
@@ -18,7 +15,9 @@ class BigQuery_VectorStore(VannaBase):
18
15
  self.config = config
19
16
 
20
17
  self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
21
- self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
18
+ self.n_results_documentation = config.get(
19
+ "n_results_documentation", config.get("n_results", 10)
20
+ )
22
21
  self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
23
22
 
24
23
  if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
@@ -47,7 +46,7 @@ class BigQuery_VectorStore(VannaBase):
47
46
 
48
47
  self.conn = bigquery.Client(project=self.project_id)
49
48
 
50
- dataset_name = self.config.get('bigquery_dataset_name', 'vanna_managed')
49
+ dataset_name = self.config.get("bigquery_dataset_name", "vanna_managed")
51
50
  self.dataset_id = f"{self.project_id}.{dataset_name}"
52
51
  dataset = bigquery.Dataset(self.dataset_id)
53
52
 
@@ -101,21 +100,35 @@ class BigQuery_VectorStore(VannaBase):
101
100
  # except Exception as e:
102
101
  # print(f"Failed to create vector index: {e}")
103
102
 
104
- def store_training_data(self, training_data_type: str, question: str, content: str, embedding: List[float], **kwargs) -> str:
103
+ def store_training_data(
104
+ self,
105
+ training_data_type: str,
106
+ question: str,
107
+ content: str,
108
+ embedding: List[float],
109
+ **kwargs,
110
+ ) -> str:
105
111
  id = str(uuid.uuid4())
106
112
  created_at = datetime.datetime.now()
107
- self.conn.insert_rows_json(self.table_id, [{
108
- "id": id,
109
- "training_data_type": training_data_type,
110
- "question": question,
111
- "content": content,
112
- "embedding": embedding,
113
- "created_at": created_at.isoformat()
114
- }])
113
+ self.conn.insert_rows_json(
114
+ self.table_id,
115
+ [
116
+ {
117
+ "id": id,
118
+ "training_data_type": training_data_type,
119
+ "question": question,
120
+ "content": content,
121
+ "embedding": embedding,
122
+ "created_at": created_at.isoformat(),
123
+ }
124
+ ],
125
+ )
115
126
 
116
127
  return id
117
128
 
118
- def fetch_similar_training_data(self, training_data_type: str, question: str, n_results, **kwargs) -> pd.DataFrame:
129
+ def fetch_similar_training_data(
130
+ self, training_data_type: str, question: str, n_results, **kwargs
131
+ ) -> pd.DataFrame:
119
132
  question_embedding = self.generate_question_embedding(question)
120
133
 
121
134
  query = f"""
@@ -145,29 +158,28 @@ class BigQuery_VectorStore(VannaBase):
145
158
  embeddings = None
146
159
 
147
160
  if self.type == "VERTEX_AI":
148
- input = [TextEmbeddingInput(data, task)]
149
- model = TextEmbeddingModel.from_pretrained("text-embedding-004")
161
+ input = [TextEmbeddingInput(data, task)]
162
+ model = TextEmbeddingModel.from_pretrained("text-embedding-004")
150
163
 
151
- result = model.get_embeddings(input)
164
+ result = model.get_embeddings(input)
152
165
 
153
- if len(result) > 0:
154
- embeddings = result[0].values
166
+ if len(result) > 0:
167
+ embeddings = result[0].values
155
168
  else:
156
- # Use Gemini Consumer API
157
- result = self.genai.embed_content(
158
- model="models/text-embedding-004",
159
- content=data,
160
- task_type=task)
169
+ # Use Gemini Consumer API
170
+ result = self.genai.embed_content(
171
+ model="models/text-embedding-004", content=data, task_type=task
172
+ )
161
173
 
162
- if 'embedding' in result:
163
- embeddings = result['embedding']
174
+ if "embedding" in result:
175
+ embeddings = result["embedding"]
164
176
 
165
177
  return embeddings
166
178
 
167
179
  def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
168
180
  result = self.get_embeddings(data, "RETRIEVAL_QUERY")
169
181
 
170
- if result != None:
182
+ if result is not None:
171
183
  return result
172
184
  else:
173
185
  raise ValueError("No embeddings returned")
@@ -175,7 +187,7 @@ class BigQuery_VectorStore(VannaBase):
175
187
  def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
176
188
  result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")
177
189
 
178
- if result != None:
190
+ if result is not None:
179
191
  return result
180
192
  else:
181
193
  raise ValueError("No embeddings returned")
@@ -195,45 +207,66 @@ class BigQuery_VectorStore(VannaBase):
195
207
  return self.generate_storage_embedding(data, **kwargs)
196
208
 
197
209
  def get_similar_question_sql(self, question: str, **kwargs) -> list:
198
- df = self.fetch_similar_training_data(training_data_type="sql", question=question, n_results=self.n_results_sql)
210
+ df = self.fetch_similar_training_data(
211
+ training_data_type="sql", question=question, n_results=self.n_results_sql
212
+ )
199
213
 
200
214
  # Return a list of dictionaries with only question, sql fields. The content field needs to be renamed to sql
201
- return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(orient="records")
215
+ return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(
216
+ orient="records"
217
+ )
202
218
 
203
219
  def get_related_ddl(self, question: str, **kwargs) -> list:
204
- df = self.fetch_similar_training_data(training_data_type="ddl", question=question, n_results=self.n_results_ddl)
220
+ df = self.fetch_similar_training_data(
221
+ training_data_type="ddl", question=question, n_results=self.n_results_ddl
222
+ )
205
223
 
206
224
  # Return a list of strings of the content
207
225
  return df["content"].tolist()
208
226
 
209
227
  def get_related_documentation(self, question: str, **kwargs) -> list:
210
- df = self.fetch_similar_training_data(training_data_type="documentation", question=question, n_results=self.n_results_documentation)
228
+ df = self.fetch_similar_training_data(
229
+ training_data_type="documentation",
230
+ question=question,
231
+ n_results=self.n_results_documentation,
232
+ )
211
233
 
212
234
  # Return a list of strings of the content
213
235
  return df["content"].tolist()
214
236
 
215
237
  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
216
- doc = {
217
- "question": question,
218
- "sql": sql
219
- }
238
+ doc = {"question": question, "sql": sql}
220
239
 
221
240
  embedding = self.generate_embedding(str(doc))
222
241
 
223
- return self.store_training_data(training_data_type="sql", question=question, content=sql, embedding=embedding)
242
+ return self.store_training_data(
243
+ training_data_type="sql",
244
+ question=question,
245
+ content=sql,
246
+ embedding=embedding,
247
+ )
224
248
 
225
249
  def add_ddl(self, ddl: str, **kwargs) -> str:
226
250
  embedding = self.generate_embedding(ddl)
227
251
 
228
- return self.store_training_data(training_data_type="ddl", question="", content=ddl, embedding=embedding)
252
+ return self.store_training_data(
253
+ training_data_type="ddl", question="", content=ddl, embedding=embedding
254
+ )
229
255
 
230
256
  def add_documentation(self, documentation: str, **kwargs) -> str:
231
257
  embedding = self.generate_embedding(documentation)
232
258
 
233
- return self.store_training_data(training_data_type="documentation", question="", content=documentation, embedding=embedding)
259
+ return self.store_training_data(
260
+ training_data_type="documentation",
261
+ question="",
262
+ content=documentation,
263
+ embedding=embedding,
264
+ )
234
265
 
235
266
  def get_training_data(self, **kwargs) -> pd.DataFrame:
236
- query = f"SELECT id, training_data_type, question, content FROM `{self.table_id}`"
267
+ query = (
268
+ f"SELECT id, training_data_type, question, content FROM `{self.table_id}`"
269
+ )
237
270
 
238
271
  return self.conn.query(query).result().to_dataframe()
239
272
 
@@ -35,14 +35,18 @@ class GoogleGeminiChat(VannaBase):
35
35
  import vertexai
36
36
  from vertexai.generative_models import GenerativeModel
37
37
 
38
- json_file_path = config.get("google_credentials") # Assuming the JSON file path is provided in the config
38
+ json_file_path = config.get(
39
+ "google_credentials"
40
+ ) # Assuming the JSON file path is provided in the config
39
41
 
40
42
  if not json_file_path or not os.path.exists(json_file_path):
41
- raise FileNotFoundError(f"JSON credentials file not found at: {json_file_path}")
43
+ raise FileNotFoundError(
44
+ f"JSON credentials file not found at: {json_file_path}"
45
+ )
42
46
 
43
47
  try:
44
48
  # Validate and set the JSON file path for GOOGLE_APPLICATION_CREDENTIALS
45
- os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = json_file_path
49
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = json_file_path
46
50
 
47
51
  # Initialize VertexAI with the credentials
48
52
  credentials, _ = google.auth.default()
@@ -61,7 +61,6 @@ class Hf(VannaBase):
61
61
  return self.extract_sql_query(sql)
62
62
 
63
63
  def submit_prompt(self, prompt, **kwargs) -> str:
64
-
65
64
  input_ids = self.tokenizer.apply_chat_template(
66
65
  prompt, add_generation_prompt=True, return_tensors="pt"
67
66
  ).to(self.model.device)
@@ -33,6 +33,7 @@ class Milvus_VectorStore(VannaBase):
33
33
  For more models, please refer to:
34
34
  https://milvus.io/docs/embeddings.md
35
35
  """
36
+
36
37
  def __init__(self, config=None):
37
38
  VannaBase.__init__(self, config=config)
38
39
 
@@ -45,7 +46,9 @@ class Milvus_VectorStore(VannaBase):
45
46
  self.embedding_function = config.get("embedding_function")
46
47
  else:
47
48
  self.embedding_function = model.DefaultEmbeddingFunction()
48
- self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
49
+ self._embedding_dim = self.embedding_function.encode_documents(["foo"])[
50
+ 0
51
+ ].shape[0]
49
52
  self._create_collections()
50
53
  self.n_results = config.get("n_results", 10)
51
54
 
@@ -54,21 +57,32 @@ class Milvus_VectorStore(VannaBase):
54
57
  self._create_ddl_collection("vannaddl")
55
58
  self._create_doc_collection("vannadoc")
56
59
 
57
-
58
60
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
59
61
  return self.embedding_function.encode_documents(data).tolist()
60
62
 
61
-
62
63
  def _create_sql_collection(self, name: str):
63
64
  if not self.milvus_client.has_collection(collection_name=name):
64
65
  vannasql_schema = MilvusClient.create_schema(
65
66
  auto_id=False,
66
67
  enable_dynamic_field=False,
67
68
  )
68
- vannasql_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
69
- vannasql_schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
70
- vannasql_schema.add_field(field_name="sql", datatype=DataType.VARCHAR, max_length=65535)
71
- vannasql_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)
69
+ vannasql_schema.add_field(
70
+ field_name="id",
71
+ datatype=DataType.VARCHAR,
72
+ max_length=65535,
73
+ is_primary=True,
74
+ )
75
+ vannasql_schema.add_field(
76
+ field_name="text", datatype=DataType.VARCHAR, max_length=65535
77
+ )
78
+ vannasql_schema.add_field(
79
+ field_name="sql", datatype=DataType.VARCHAR, max_length=65535
80
+ )
81
+ vannasql_schema.add_field(
82
+ field_name="vector",
83
+ datatype=DataType.FLOAT_VECTOR,
84
+ dim=self._embedding_dim,
85
+ )
72
86
 
73
87
  vannasql_index_params = self.milvus_client.prepare_index_params()
74
88
  vannasql_index_params.add_index(
@@ -81,7 +95,7 @@ class Milvus_VectorStore(VannaBase):
81
95
  collection_name=name,
82
96
  schema=vannasql_schema,
83
97
  index_params=vannasql_index_params,
84
- consistency_level="Strong"
98
+ consistency_level="Strong",
85
99
  )
86
100
 
87
101
  def _create_ddl_collection(self, name: str):
@@ -90,9 +104,20 @@ class Milvus_VectorStore(VannaBase):
90
104
  auto_id=False,
91
105
  enable_dynamic_field=False,
92
106
  )
93
- vannaddl_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
94
- vannaddl_schema.add_field(field_name="ddl", datatype=DataType.VARCHAR, max_length=65535)
95
- vannaddl_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)
107
+ vannaddl_schema.add_field(
108
+ field_name="id",
109
+ datatype=DataType.VARCHAR,
110
+ max_length=65535,
111
+ is_primary=True,
112
+ )
113
+ vannaddl_schema.add_field(
114
+ field_name="ddl", datatype=DataType.VARCHAR, max_length=65535
115
+ )
116
+ vannaddl_schema.add_field(
117
+ field_name="vector",
118
+ datatype=DataType.FLOAT_VECTOR,
119
+ dim=self._embedding_dim,
120
+ )
96
121
 
97
122
  vannaddl_index_params = self.milvus_client.prepare_index_params()
98
123
  vannaddl_index_params.add_index(
@@ -105,7 +130,7 @@ class Milvus_VectorStore(VannaBase):
105
130
  collection_name=name,
106
131
  schema=vannaddl_schema,
107
132
  index_params=vannaddl_index_params,
108
- consistency_level="Strong"
133
+ consistency_level="Strong",
109
134
  )
110
135
 
111
136
  def _create_doc_collection(self, name: str):
@@ -114,9 +139,20 @@ class Milvus_VectorStore(VannaBase):
114
139
  auto_id=False,
115
140
  enable_dynamic_field=False,
116
141
  )
117
- vannadoc_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
118
- vannadoc_schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
119
- vannadoc_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)
142
+ vannadoc_schema.add_field(
143
+ field_name="id",
144
+ datatype=DataType.VARCHAR,
145
+ max_length=65535,
146
+ is_primary=True,
147
+ )
148
+ vannadoc_schema.add_field(
149
+ field_name="doc", datatype=DataType.VARCHAR, max_length=65535
150
+ )
151
+ vannadoc_schema.add_field(
152
+ field_name="vector",
153
+ datatype=DataType.FLOAT_VECTOR,
154
+ dim=self._embedding_dim,
155
+ )
120
156
 
121
157
  vannadoc_index_params = self.milvus_client.prepare_index_params()
122
158
  vannadoc_index_params.add_index(
@@ -129,7 +165,7 @@ class Milvus_VectorStore(VannaBase):
129
165
  collection_name=name,
130
166
  schema=vannadoc_schema,
131
167
  index_params=vannadoc_index_params,
132
- consistency_level="Strong"
168
+ consistency_level="Strong",
133
169
  )
134
170
 
135
171
  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
@@ -139,12 +175,7 @@ class Milvus_VectorStore(VannaBase):
139
175
  embedding = self.embedding_function.encode_documents([question])[0]
140
176
  self.milvus_client.insert(
141
177
  collection_name="vannasql",
142
- data={
143
- "id": _id,
144
- "text": question,
145
- "sql": sql,
146
- "vector": embedding
147
- }
178
+ data={"id": _id, "text": question, "sql": sql, "vector": embedding},
148
179
  )
149
180
  return _id
150
181
 
@@ -155,11 +186,7 @@ class Milvus_VectorStore(VannaBase):
155
186
  embedding = self.embedding_function.encode_documents([ddl])[0]
156
187
  self.milvus_client.insert(
157
188
  collection_name="vannaddl",
158
- data={
159
- "id": _id,
160
- "ddl": ddl,
161
- "vector": embedding
162
- }
189
+ data={"id": _id, "ddl": ddl, "vector": embedding},
163
190
  )
164
191
  return _id
165
192
 
@@ -170,11 +197,7 @@ class Milvus_VectorStore(VannaBase):
170
197
  embedding = self.embedding_function.encode_documents([documentation])[0]
171
198
  self.milvus_client.insert(
172
199
  collection_name="vannadoc",
173
- data={
174
- "id": _id,
175
- "doc": documentation,
176
- "vector": embedding
177
- }
200
+ data={"id": _id, "doc": documentation, "vector": embedding},
178
201
  )
179
202
  return _id
180
203
 
@@ -237,7 +260,7 @@ class Milvus_VectorStore(VannaBase):
237
260
  data=embeddings,
238
261
  limit=self.n_results,
239
262
  output_fields=["text", "sql"],
240
- search_params=search_params
263
+ search_params=search_params,
241
264
  )
242
265
  res = res[0]
243
266
 
@@ -261,7 +284,7 @@ class Milvus_VectorStore(VannaBase):
261
284
  data=embeddings,
262
285
  limit=self.n_results,
263
286
  output_fields=["ddl"],
264
- search_params=search_params
287
+ search_params=search_params,
265
288
  )
266
289
  res = res[0]
267
290
 
@@ -282,7 +305,7 @@ class Milvus_VectorStore(VannaBase):
282
305
  data=embeddings,
283
306
  limit=self.n_results,
284
307
  output_fields=["doc"],
285
- search_params=search_params
308
+ search_params=search_params,
286
309
  )
287
310
  res = res[0]
288
311
 
@@ -1,4 +1,3 @@
1
-
2
1
  from ..base import VannaBase
3
2
 
4
3
 
@@ -0,0 +1,67 @@
1
+ import pandas as pd
2
+
3
+ from ..base import VannaBase
4
+
5
+
6
class MockVectorDB(VannaBase):
    """Stateless stand-in for a vector store.

    Nothing is persisted: the ``add_*`` methods only return an identifier
    derived from their input, the ``get_related_*`` / ``get_similar_*``
    lookups always return an empty list, and ``get_training_data`` returns a
    fixed, hard-coded five-row frame.
    """

    def __init__(self, config=None):
        # ``config`` is accepted for signature compatibility and ignored.
        pass

    def _get_id(self, value: str, **kwargs) -> str:
        """Return ``hash(value)`` rendered as a string.

        NOTE: Python salts ``str`` hashes per interpreter process unless
        ``PYTHONHASHSEED`` is pinned, so the returned id is only stable
        within a single run.
        """
        return str(hash(value))

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Return the id for ``ddl``; nothing is stored."""
        return self._get_id(ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Return the id for ``doc``; nothing is stored."""
        return self._get_id(doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Return the id for ``question``; ``sql`` does not affect the id."""
        return self._get_id(question)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Always return an empty list."""
        return []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Always return an empty list."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Always return an empty list."""
        return []

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return a fixed frame with ``id``, ``training_data_type``,
        ``question`` and ``content`` columns (five canned rows)."""
        return pd.DataFrame(
            {
                "id": {
                    0: "19546-ddl",
                    1: "91597-sql",
                    2: "133976-sql",
                    3: "59851-doc",
                    4: "73046-sql",
                },
                "training_data_type": {
                    0: "ddl",
                    1: "sql",
                    2: "sql",
                    3: "documentation",
                    4: "sql",
                },
                "question": {
                    0: None,
                    1: "What are the top selling genres?",
                    2: "What are the low 7 artists by sales?",
                    3: None,
                    4: "What is the total sales for each customer?",
                },
                "content": {
                    0: "CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)",
                    1: "SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;",
                    2: "SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;",
                    3: "This is a SQLite database. For dates rememeber to use SQLite syntax.",
                    4: "SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;",
                },
            }
        )

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Pretend to delete the entry ``id`` and always report success.

        The parameter keeps the name ``id`` (shadowing the builtin) so that
        callers passing ``id=...`` by keyword keep working. ``self`` was
        missing from the original signature, which made every call on an
        instance raise ``TypeError``.
        """
        return True
@@ -0,0 +1,110 @@
1
+ import json
2
+ import re
3
+
4
+ from httpx import Timeout
5
+
6
+ from ..base import VannaBase
7
+ from ..exceptions import DependencyError
8
+
9
+
10
class Ollama(VannaBase):
    """Chat backend that submits prompts to an Ollama server.

    Config keys read by ``__init__``:
        model (required): model name; ``:latest`` is appended when the name
            carries no ``:tag``.
        ollama_host: server URL, default ``http://localhost:11434``.
        ollama_timeout: value handed to ``httpx.Timeout``, default ``240.0``.
        keep_alive: forwarded to every chat call, default ``None``.
        options: dict forwarded to every chat call, default ``{}``.
    """

    def __init__(self, config=None):
        """Create the client and pull the model if the server does not list it.

        Raises:
            DependencyError: the ``ollama`` package cannot be imported.
            ValueError: ``config`` is empty or has no ``"model"`` entry.
        """
        try:
            ollama = __import__("ollama")
        except ImportError as err:
            # Chain the original error so the failing import stays visible.
            raise DependencyError(
                "You need to install required dependencies to execute this method, run command:"
                " \npip install ollama"
            ) from err

        if not config or "model" not in config:
            raise ValueError("config must contain at least Ollama model")

        self.host = config.get("ollama_host", "http://localhost:11434")
        self.model = config["model"]
        if ":" not in self.model:
            self.model += ":latest"

        self.ollama_timeout = config.get("ollama_timeout", 240.0)

        self.ollama_client = ollama.Client(
            self.host, timeout=Timeout(self.ollama_timeout)
        )
        self.keep_alive = config.get("keep_alive", None)
        self.ollama_options = config.get("options", {})
        self.num_ctx = self.ollama_options.get("num_ctx", 2048)
        self.__pull_model_if_ne(self.ollama_client, self.model)

    @staticmethod
    def __pull_model_if_ne(ollama_client, model):
        """Pull ``model`` unless it appears in the client's model listing."""
        model_response = ollama_client.list()
        model_lists = [
            model_element["model"] for model_element in model_response.get("models", [])
        ]
        if model not in model_lists:
            ollama_client.pull(model)

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a ``system`` role chat message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a ``user`` role chat message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as an ``assistant`` role chat message."""
        return {"role": "assistant", "content": message}

    def extract_sql(self, llm_response):
        """
        Extract a SQL statement from an LLM response.

        Backslashes are stripped from the response first. Then, in order:

        1. If the response contains a ```` ```sql ```` fence, return the text
           after the fence line up to (not including) the first ``;``, ``[``
           or closing triple backtick.
        2. Otherwise, if it contains ``select`` or ``with ... as (``
           (case-insensitive), return the text from that keyword up to (not
           including) the first ``;``, ``[`` or triple backtick.
        3. Otherwise return the backslash-stripped response unchanged.

        A match requires one of the terminators to be present; reaching the
        end of the string does not terminate a match.

        Args:
        - llm_response (str): The string to search within for an SQL statement.

        Returns:
        - str: The extracted SQL, or the cleaned response if nothing matched.
        """
        # Remove ollama-generated extra characters
        llm_response = llm_response.replace("\\_", "_")
        llm_response = llm_response.replace("\\", "")

        # Regular expression to find ```sql' and capture until '```'.
        # With DOTALL, `.` already matches newlines. The previous `(.|\n)*?`
        # gave the engine two ways to consume each newline, which backtracks
        # exponentially when the fence is never terminated.
        sql = re.search(r"```sql\n(.*?)(?=;|\[|```)", llm_response, re.DOTALL)
        # Regular expression to find 'select, with (ignoring case) and capture until ';', [ (this happens in case of mistral) or end of string
        select_with = re.search(
            r"(select|(with.*?as \())(.*?)(?=;|\[|```)",
            llm_response,
            re.IGNORECASE | re.DOTALL,
        )
        if sql:
            self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
            return sql.group(1).replace("```", "")
        elif select_with:
            self.log(
                f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}"
            )
            return select_with.group(0)
        else:
            return llm_response

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` (a list of chat messages) to the model, non-streaming.

        Logs the call parameters, the prompt and the raw response, then
        returns the ``content`` of the response message.
        """
        self.log(
            f"Ollama parameters:\n"
            f"model={self.model},\n"
            f"options={self.ollama_options},\n"
            f"keep_alive={self.keep_alive}"
        )
        self.log(f"Prompt Content:\n{json.dumps(prompt, ensure_ascii=False)}")
        response_dict = self.ollama_client.chat(
            model=self.model,
            messages=prompt,
            stream=False,
            options=self.ollama_options,
            keep_alive=self.keep_alive,
        )

        self.log(f"Ollama Response:\n{str(response_dict)}")

        return response_dict["message"]["content"]
@@ -65,9 +65,7 @@ class OpenAI_Chat(VannaBase):
65
65
 
66
66
  if kwargs.get("model", None) is not None:
67
67
  model = kwargs.get("model", None)
68
- print(
69
- f"Using model {model} for {num_tokens} tokens (approx)"
70
- )
68
+ print(f"Using model {model} for {num_tokens} tokens (approx)")
71
69
  response = self.client.chat.completions.create(
72
70
  model=model,
73
71
  messages=prompt,
@@ -76,9 +74,7 @@ class OpenAI_Chat(VannaBase):
76
74
  )
77
75
  elif kwargs.get("engine", None) is not None:
78
76
  engine = kwargs.get("engine", None)
79
- print(
80
- f"Using model {engine} for {num_tokens} tokens (approx)"
81
- )
77
+ print(f"Using model {engine} for {num_tokens} tokens (approx)")
82
78
  response = self.client.chat.completions.create(
83
79
  engine=engine,
84
80
  messages=prompt,