vanna 0.7.8__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +461 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +247 -223
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0.dist-info/METADATA +485 -0
  251. vanna-2.0.0.dist-info/RECORD +289 -0
  252. vanna-2.0.0.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.8.dist-info/METADATA +0 -408
  265. vanna-0.7.8.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.8.dist-info → vanna-2.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1 @@
1
+ from .bedrock_converse import Bedrock_Converse
@@ -6,6 +6,7 @@ try:
6
6
  except ImportError:
7
7
  raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models")
8
8
 
9
+
9
10
  class Bedrock_Converse(VannaBase):
10
11
  def __init__(self, client=None, config=None):
11
12
  VannaBase.__init__(self, config=config)
@@ -13,29 +14,27 @@ class Bedrock_Converse(VannaBase):
13
14
  # default parameters
14
15
  self.temperature = 0.0
15
16
  self.max_tokens = 500
16
-
17
+
17
18
  if client is None:
18
19
  raise ValueError(
19
20
  "A valid Bedrock runtime client must be provided to invoke Bedrock models"
20
21
  )
21
22
  else:
22
23
  self.client = client
23
-
24
+
24
25
  if config is None:
25
26
  raise ValueError(
26
27
  "Config is required with model_id and inference parameters"
27
28
  )
28
-
29
+
29
30
  if "modelId" not in config:
30
- raise ValueError(
31
- "config must contain a modelId to invoke"
32
- )
31
+ raise ValueError("config must contain a modelId to invoke")
33
32
  else:
34
33
  self.model = config["modelId"]
35
-
34
+
36
35
  if "temperature" in config:
37
36
  self.temperature = config["temperature"]
38
-
37
+
39
38
  if "max_tokens" in config:
40
39
  self.max_tokens = config["max_tokens"]
41
40
 
@@ -51,7 +50,7 @@ class Bedrock_Converse(VannaBase):
51
50
  def submit_prompt(self, prompt, **kwargs) -> str:
52
51
  inference_config = {
53
52
  "temperature": self.temperature,
54
- "maxTokens": self.max_tokens
53
+ "maxTokens": self.max_tokens,
55
54
  }
56
55
  additional_model_fields = {
57
56
  "top_p": 1, # setting top_p value for nucleus sampling
@@ -64,13 +63,15 @@ class Bedrock_Converse(VannaBase):
64
63
  if role == "system":
65
64
  system_message = prompt_message["content"]
66
65
  else:
67
- no_system_prompt.append({"role": role, "content":[{"text": prompt_message["content"]}]})
66
+ no_system_prompt.append(
67
+ {"role": role, "content": [{"text": prompt_message["content"]}]}
68
+ )
68
69
 
69
70
  converse_api_params = {
70
71
  "modelId": self.model,
71
72
  "messages": no_system_prompt,
72
73
  "inferenceConfig": inference_config,
73
- "additionalModelRequestFields": additional_model_fields
74
+ "additionalModelRequestFields": additional_model_fields,
74
75
  }
75
76
 
76
77
  if system_message:
@@ -82,4 +83,4 @@ class Bedrock_Converse(VannaBase):
82
83
  return text_content
83
84
  except ClientError as err:
84
85
  message = err.response["Error"]["Message"]
85
- raise Exception(f"A Bedrock client error occurred: {message}")
86
+ raise Exception(f"A Bedrock client error occurred: {message}")
@@ -23,7 +23,9 @@ class ChromaDB_VectorStore(VannaBase):
23
23
  curr_client = config.get("client", "persistent")
24
24
  collection_metadata = config.get("collection_metadata", None)
25
25
  self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
26
- self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
26
+ self.n_results_documentation = config.get(
27
+ "n_results_documentation", config.get("n_results", 10)
28
+ )
27
29
  self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
28
30
 
29
31
  if curr_client == "persistent":
@@ -0,0 +1,2 @@
1
+ from .cohere_chat import Cohere_Chat
2
+ from .cohere_embeddings import Cohere_Embeddings
@@ -25,15 +25,17 @@ class Cohere_Chat(VannaBase):
25
25
 
26
26
  # Check for API key in environment variable
27
27
  api_key = os.getenv("COHERE_API_KEY")
28
-
28
+
29
29
  # Check for API key in config
30
30
  if config is not None and "api_key" in config:
31
31
  api_key = config["api_key"]
32
-
32
+
33
33
  # Validate API key
34
34
  if not api_key:
35
- raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.")
36
-
35
+ raise ValueError(
36
+ "Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable."
37
+ )
38
+
37
39
  # Initialize client with validated API key
38
40
  self.client = OpenAI(
39
41
  base_url="https://api.cohere.ai/compatibility/v1",
@@ -41,7 +43,10 @@ class Cohere_Chat(VannaBase):
41
43
  )
42
44
 
43
45
  def system_message(self, message: str) -> any:
44
- return {"role": "developer", "content": message} # Cohere uses 'developer' for system role
46
+ return {
47
+ "role": "developer",
48
+ "content": message,
49
+ } # Cohere uses 'developer' for system role
45
50
 
46
51
  def user_message(self, message: str) -> any:
47
52
  return {"role": "user", "content": message}
@@ -74,21 +79,21 @@ class Cohere_Chat(VannaBase):
74
79
  messages=prompt,
75
80
  temperature=self.temperature,
76
81
  )
77
-
82
+
78
83
  # Check if response has expected structure
79
- if not response or not hasattr(response, 'choices') or not response.choices:
84
+ if not response or not hasattr(response, "choices") or not response.choices:
80
85
  raise ValueError("Received empty or malformed response from API")
81
-
82
- if not response.choices[0] or not hasattr(response.choices[0], 'message'):
86
+
87
+ if not response.choices[0] or not hasattr(response.choices[0], "message"):
83
88
  raise ValueError("Response is missing expected 'message' field")
84
-
85
- if not hasattr(response.choices[0].message, 'content'):
89
+
90
+ if not hasattr(response.choices[0].message, "content"):
86
91
  raise ValueError("Response message is missing expected 'content' field")
87
-
92
+
88
93
  return response.choices[0].message.content
89
-
94
+
90
95
  except Exception as e:
91
96
  # Log the error and raise a more informative exception
92
97
  error_msg = f"Error processing Cohere chat response: {str(e)}"
93
98
  print(error_msg)
94
- raise Exception(error_msg)
99
+ raise Exception(error_msg)
@@ -8,10 +8,10 @@ from ..base import VannaBase
8
8
  class Cohere_Embeddings(VannaBase):
9
9
  def __init__(self, client=None, config=None):
10
10
  VannaBase.__init__(self, config=config)
11
-
11
+
12
12
  # Default embedding model
13
13
  self.model = "embed-multilingual-v3.0"
14
-
14
+
15
15
  if config is not None and "model" in config:
16
16
  self.model = config["model"]
17
17
 
@@ -21,15 +21,17 @@ class Cohere_Embeddings(VannaBase):
21
21
 
22
22
  # Check for API key in environment variable
23
23
  api_key = os.getenv("COHERE_API_KEY")
24
-
24
+
25
25
  # Check for API key in config
26
26
  if config is not None and "api_key" in config:
27
27
  api_key = config["api_key"]
28
-
28
+
29
29
  # Validate API key
30
30
  if not api_key:
31
- raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.")
32
-
31
+ raise ValueError(
32
+ "Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable."
33
+ )
34
+
33
35
  # Initialize client with validated API key
34
36
  self.client = OpenAI(
35
37
  base_url="https://api.cohere.ai/compatibility/v1",
@@ -39,33 +41,37 @@ class Cohere_Embeddings(VannaBase):
39
41
  def generate_embedding(self, data: str, **kwargs) -> list[float]:
40
42
  if not data:
41
43
  raise ValueError("Cannot generate embedding for empty input data")
42
-
44
+
43
45
  # Use model from kwargs, config, or default
44
46
  model = kwargs.get("model", self.model)
45
47
  if self.config is not None and "model" in self.config and model == self.model:
46
48
  model = self.config["model"]
47
-
48
- try:
49
+
50
+ try:
49
51
  embedding = self.client.embeddings.create(
50
52
  model=model,
51
53
  input=data,
52
54
  encoding_format="float", # Ensure we get float values
53
55
  )
54
-
56
+
55
57
  # Check if response has expected structure
56
- if not embedding or not hasattr(embedding, 'data') or not embedding.data:
57
- raise ValueError("Received empty or malformed embedding response from API")
58
-
59
- if not embedding.data[0] or not hasattr(embedding.data[0], 'embedding'):
60
- raise ValueError("Embedding response is missing expected 'embedding' field")
61
-
58
+ if not embedding or not hasattr(embedding, "data") or not embedding.data:
59
+ raise ValueError(
60
+ "Received empty or malformed embedding response from API"
61
+ )
62
+
63
+ if not embedding.data[0] or not hasattr(embedding.data[0], "embedding"):
64
+ raise ValueError(
65
+ "Embedding response is missing expected 'embedding' field"
66
+ )
67
+
62
68
  if not embedding.data[0].embedding:
63
69
  raise ValueError("Received empty embedding vector")
64
-
70
+
65
71
  return embedding.data[0].embedding
66
-
72
+
67
73
  except Exception as e:
68
74
  # Log the error and raise a more informative exception
69
75
  error_msg = f"Error generating embedding with Cohere: {str(e)}"
70
76
  print(error_msg)
71
- raise Exception(error_msg)
77
+ raise Exception(error_msg)
@@ -5,7 +5,6 @@ from openai import OpenAI
5
5
  from ..base import VannaBase
6
6
 
7
7
 
8
-
9
8
  # from vanna.chromadb import ChromaDB_VectorStore
10
9
 
11
10
  # class DeepSeekVanna(ChromaDB_VectorStore, DeepSeekChat):
@@ -27,12 +26,12 @@ class DeepSeekChat(VannaBase):
27
26
 
28
27
  if "model" not in config:
29
28
  raise ValueError("config must contain a DeepSeek model")
30
-
29
+
31
30
  api_key = config["api_key"]
32
31
  model = config["model"]
33
32
  self.model = model
34
33
  self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")
35
-
34
+
36
35
  def system_message(self, message: str) -> any:
37
36
  return {"role": "system", "content": message}
38
37
 
@@ -45,10 +44,10 @@ class DeepSeekChat(VannaBase):
45
44
  def generate_sql(self, question: str, **kwargs) -> str:
46
45
  # 使用父类的 generate_sql
47
46
  sql = super().generate_sql(question, **kwargs)
48
-
47
+
49
48
  # 替换 "\_" 为 "_"
50
49
  sql = sql.replace("\\_", "_")
51
-
50
+
52
51
  return sql
53
52
 
54
53
  def submit_prompt(self, prompt, **kwargs) -> str:
@@ -56,5 +55,5 @@ class DeepSeekChat(VannaBase):
56
55
  model=self.model,
57
56
  messages=prompt,
58
57
  )
59
-
58
+
60
59
  return chat_response.choices[0].message.content
@@ -0,0 +1 @@
1
+ from .faiss import FAISS
@@ -1,4 +1,4 @@
1
- import os
1
+ import os
2
2
  import json
3
3
  import uuid
4
4
  from typing import List, Dict, Any
@@ -10,13 +10,14 @@ import pandas as pd
10
10
  from ..base import VannaBase
11
11
  from ..exceptions import DependencyError
12
12
 
13
+
13
14
  class FAISS(VannaBase):
14
15
  def __init__(self, config=None):
15
16
  if config is None:
16
17
  config = {}
17
-
18
+
18
19
  VannaBase.__init__(self, config=config)
19
-
20
+
20
21
  try:
21
22
  import faiss
22
23
  except ImportError:
@@ -30,34 +31,48 @@ class FAISS(VannaBase):
30
31
  raise DependencyError(
31
32
  "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
32
33
  )
33
-
34
+
34
35
  self.path = config.get("path", ".")
35
- self.embedding_dim = config.get('embedding_dim', 384)
36
- self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10))
37
- self.n_results_ddl = config.get('n_results_ddl', config.get("n_results", 10))
38
- self.n_results_documentation = config.get('n_results_documentation', config.get("n_results", 10))
36
+ self.embedding_dim = config.get("embedding_dim", 384)
37
+ self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
38
+ self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
39
+ self.n_results_documentation = config.get(
40
+ "n_results_documentation", config.get("n_results", 10)
41
+ )
39
42
  self.curr_client = config.get("client", "persistent")
40
43
 
41
- if self.curr_client == 'persistent':
42
- self.sql_index = self._load_or_create_index('sql_index.faiss')
43
- self.ddl_index = self._load_or_create_index('ddl_index.faiss')
44
- self.doc_index = self._load_or_create_index('doc_index.faiss')
45
- elif self.curr_client == 'in-memory':
44
+ if self.curr_client == "persistent":
45
+ self.sql_index = self._load_or_create_index("sql_index.faiss")
46
+ self.ddl_index = self._load_or_create_index("ddl_index.faiss")
47
+ self.doc_index = self._load_or_create_index("doc_index.faiss")
48
+ elif self.curr_client == "in-memory":
46
49
  self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
47
50
  self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
48
51
  self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
49
- elif isinstance(self.curr_client, list) and len(self.curr_client) == 3 and all(isinstance(idx, faiss.Index) for idx in self.curr_client):
52
+ elif (
53
+ isinstance(self.curr_client, list)
54
+ and len(self.curr_client) == 3
55
+ and all(isinstance(idx, faiss.Index) for idx in self.curr_client)
56
+ ):
50
57
  self.sql_index = self.curr_client[0]
51
58
  self.ddl_index = self.curr_client[1]
52
59
  self.doc_index = self.curr_client[2]
53
60
  else:
54
- raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}")
55
-
56
- self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata('sql_metadata.json')
57
- self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json')
58
- self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json')
61
+ raise ValueError(
62
+ f"Unsupported storage type was set in config: {self.curr_client}"
63
+ )
59
64
 
60
- model_name = config.get('embedding_model', 'all-MiniLM-L6-v2')
65
+ self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata(
66
+ "sql_metadata.json"
67
+ )
68
+ self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata(
69
+ "ddl_metadata.json"
70
+ )
71
+ self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata(
72
+ "doc_metadata.json"
73
+ )
74
+
75
+ model_name = config.get("embedding_model", "all-MiniLM-L6-v2")
61
76
  self.embedding_model = SentenceTransformer(model_name)
62
77
 
63
78
  def _load_or_create_index(self, filename):
@@ -69,25 +84,26 @@ class FAISS(VannaBase):
69
84
  def _load_or_create_metadata(self, filename):
70
85
  filepath = os.path.join(self.path, filename)
71
86
  if os.path.exists(filepath):
72
- with open(filepath, 'r') as f:
87
+ with open(filepath, "r") as f:
73
88
  return json.load(f)
74
89
  return []
75
90
 
76
91
  def _save_index(self, index, filename):
77
- if self.curr_client == 'persistent':
92
+ if self.curr_client == "persistent":
78
93
  filepath = os.path.join(self.path, filename)
79
94
  faiss.write_index(index, filepath)
80
95
 
81
96
  def _save_metadata(self, metadata, filename):
82
- if self.curr_client == 'persistent':
97
+ if self.curr_client == "persistent":
83
98
  filepath = os.path.join(self.path, filename)
84
- with open(filepath, 'w') as f:
99
+ with open(filepath, "w") as f:
85
100
  json.dump(metadata, f)
86
101
 
87
102
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
88
103
  embedding = self.embedding_model.encode(data)
89
- assert embedding.shape[0] == self.embedding_dim, \
104
+ assert embedding.shape[0] == self.embedding_dim, (
90
105
  f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
106
+ )
91
107
  return embedding.tolist()
92
108
 
93
109
  def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
@@ -96,81 +112,119 @@ class FAISS(VannaBase):
96
112
  entry_id = str(uuid.uuid4())
97
113
  metadata_list.append({"id": entry_id, **(extra_metadata or {})})
98
114
  return entry_id
99
-
115
+
100
116
  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
101
- entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql})
102
- self._save_index(self.sql_index, 'sql_index.faiss')
103
- self._save_metadata(self.sql_metadata, 'sql_metadata.json')
117
+ entry_id = self._add_to_index(
118
+ self.sql_index,
119
+ self.sql_metadata,
120
+ question + " " + sql,
121
+ {"question": question, "sql": sql},
122
+ )
123
+ self._save_index(self.sql_index, "sql_index.faiss")
124
+ self._save_metadata(self.sql_metadata, "sql_metadata.json")
104
125
  return entry_id
105
126
 
106
127
  def add_ddl(self, ddl: str, **kwargs) -> str:
107
- entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl})
108
- self._save_index(self.ddl_index, 'ddl_index.faiss')
109
- self._save_metadata(self.ddl_metadata, 'ddl_metadata.json')
128
+ entry_id = self._add_to_index(
129
+ self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl}
130
+ )
131
+ self._save_index(self.ddl_index, "ddl_index.faiss")
132
+ self._save_metadata(self.ddl_metadata, "ddl_metadata.json")
110
133
  return entry_id
111
134
 
112
135
  def add_documentation(self, documentation: str, **kwargs) -> str:
113
- entry_id = self._add_to_index(self.doc_index, self.doc_metadata, documentation, {"documentation": documentation})
114
- self._save_index(self.doc_index, 'doc_index.faiss')
115
- self._save_metadata(self.doc_metadata, 'doc_metadata.json')
136
+ entry_id = self._add_to_index(
137
+ self.doc_index,
138
+ self.doc_metadata,
139
+ documentation,
140
+ {"documentation": documentation},
141
+ )
142
+ self._save_index(self.doc_index, "doc_index.faiss")
143
+ self._save_metadata(self.doc_metadata, "doc_metadata.json")
116
144
  return entry_id
117
145
 
118
146
  def _get_similar(self, index, metadata_list, text, n_results) -> list:
119
147
  embedding = self.generate_embedding(text)
120
148
  D, I = index.search(np.array([embedding], dtype=np.float32), k=n_results)
121
- return [] if len(I[0]) == 0 or I[0][0] == -1 else [metadata_list[i] for i in I[0]]
149
+ return (
150
+ [] if len(I[0]) == 0 or I[0][0] == -1 else [metadata_list[i] for i in I[0]]
151
+ )
122
152
 
123
153
  def get_similar_question_sql(self, question: str, **kwargs) -> list:
124
- return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql)
125
-
154
+ return self._get_similar(
155
+ self.sql_index, self.sql_metadata, question, self.n_results_sql
156
+ )
157
+
126
158
  def get_related_ddl(self, question: str, **kwargs) -> list:
127
- return [metadata["ddl"] for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)]
159
+ return [
160
+ metadata["ddl"]
161
+ for metadata in self._get_similar(
162
+ self.ddl_index, self.ddl_metadata, question, self.n_results_ddl
163
+ )
164
+ ]
128
165
 
129
166
  def get_related_documentation(self, question: str, **kwargs) -> list:
130
- return [metadata["documentation"] for metadata in self._get_similar(self.doc_index, self.doc_metadata, question, self.n_results_documentation)]
167
+ return [
168
+ metadata["documentation"]
169
+ for metadata in self._get_similar(
170
+ self.doc_index,
171
+ self.doc_metadata,
172
+ question,
173
+ self.n_results_documentation,
174
+ )
175
+ ]
131
176
 
132
177
  def get_training_data(self, **kwargs) -> pd.DataFrame:
133
178
  sql_data = pd.DataFrame(self.sql_metadata)
134
- sql_data['training_data_type'] = 'sql'
179
+ sql_data["training_data_type"] = "sql"
135
180
 
136
181
  ddl_data = pd.DataFrame(self.ddl_metadata)
137
- ddl_data['training_data_type'] = 'ddl'
182
+ ddl_data["training_data_type"] = "ddl"
138
183
 
139
184
  doc_data = pd.DataFrame(self.doc_metadata)
140
- doc_data['training_data_type'] = 'documentation'
185
+ doc_data["training_data_type"] = "documentation"
141
186
 
142
187
  return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)
143
188
 
144
189
  def remove_training_data(self, id: str, **kwargs) -> bool:
145
190
  for metadata_list, index, index_name in [
146
- (self.sql_metadata, self.sql_index, 'sql_index.faiss'),
147
- (self.ddl_metadata, self.ddl_index, 'ddl_index.faiss'),
148
- (self.doc_metadata, self.doc_index, 'doc_index.faiss')
191
+ (self.sql_metadata, self.sql_index, "sql_index.faiss"),
192
+ (self.ddl_metadata, self.ddl_index, "ddl_index.faiss"),
193
+ (self.doc_metadata, self.doc_index, "doc_index.faiss"),
149
194
  ]:
150
195
  for i, item in enumerate(metadata_list):
151
- if item['id'] == id:
196
+ if item["id"] == id:
152
197
  del metadata_list[i]
153
198
  new_index = faiss.IndexFlatL2(self.embedding_dim)
154
- embeddings = [self.generate_embedding(json.dumps(m)) for m in metadata_list]
199
+ embeddings = [
200
+ self.generate_embedding(json.dumps(m)) for m in metadata_list
201
+ ]
155
202
  if embeddings:
156
203
  new_index.add(np.array(embeddings, dtype=np.float32))
157
- setattr(self, index_name.split('.')[0], new_index)
158
-
159
- if self.curr_client == 'persistent':
204
+ setattr(self, index_name.split(".")[0], new_index)
205
+
206
+ if self.curr_client == "persistent":
160
207
  self._save_index(new_index, index_name)
161
- self._save_metadata(metadata_list, f"{index_name.split('.')[0]}_metadata.json")
162
-
208
+ self._save_metadata(
209
+ metadata_list, f"{index_name.split('.')[0]}_metadata.json"
210
+ )
211
+
163
212
  return True
164
213
  return False
165
214
 
166
215
  def remove_collection(self, collection_name: str) -> bool:
167
216
  if collection_name in ["sql", "ddl", "documentation"]:
168
- setattr(self, f"{collection_name}_index", faiss.IndexFlatL2(self.embedding_dim))
217
+ setattr(
218
+ self, f"{collection_name}_index", faiss.IndexFlatL2(self.embedding_dim)
219
+ )
169
220
  setattr(self, f"{collection_name}_metadata", [])
170
-
171
- if self.curr_client == 'persistent':
172
- self._save_index(getattr(self, f"{collection_name}_index"), f"{collection_name}_index.faiss")
221
+
222
+ if self.curr_client == "persistent":
223
+ self._save_index(
224
+ getattr(self, f"{collection_name}_index"),
225
+ f"{collection_name}_index.faiss",
226
+ )
173
227
  self._save_metadata([], f"{collection_name}_metadata.json")
174
-
228
+
175
229
  return True
176
- return False
230
+ return False