vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +439 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +224 -217
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0rc1.dist-info/METADATA +868 -0
  251. vanna-2.0.0rc1.dist-info/RECORD +289 -0
  252. vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.9.dist-info/METADATA +0 -408
  265. vanna-0.7.9.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -8,15 +8,15 @@ import requests
8
8
  from ..advanced import VannaAdvanced
9
9
  from ..base import VannaBase
10
10
  from ..types import (
11
- DataFrameJSON,
12
- NewOrganization,
13
- OrganizationList,
14
- Question,
15
- QuestionSQLPair,
16
- Status,
17
- StatusWithId,
18
- StringData,
19
- TrainingData,
11
+ DataFrameJSON,
12
+ NewOrganization,
13
+ OrganizationList,
14
+ Question,
15
+ QuestionSQLPair,
16
+ Status,
17
+ StatusWithId,
18
+ StringData,
19
+ TrainingData,
20
20
  )
21
21
  from ..utils import sanitize_model_name
22
22
 
@@ -85,18 +85,25 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
85
85
  }
86
86
  """
87
87
 
88
- response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})
88
+ response = requests.post(
89
+ self._graphql_endpoint, headers=self._graphql_headers, json={"query": query}
90
+ )
89
91
  response_json = response.json()
90
- if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
91
- self.log(response_json['data']['get_all_sql_functions'])
92
- resp = response_json['data']['get_all_sql_functions']
92
+ if (
93
+ response.status_code == 200
94
+ and "data" in response_json
95
+ and "get_all_sql_functions" in response_json["data"]
96
+ ):
97
+ self.log(response_json["data"]["get_all_sql_functions"])
98
+ resp = response_json["data"]["get_all_sql_functions"]
93
99
 
94
100
  print(resp)
95
101
 
96
102
  return resp
97
103
  else:
98
- raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
99
-
104
+ raise Exception(
105
+ f"Query failed to run by returning code of {response.status_code}. {response.text}"
106
+ )
100
107
 
101
108
  def get_function(self, question: str, additional_data: dict = {}) -> dict:
102
109
  query = """
@@ -121,21 +128,38 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
121
128
  }
122
129
  }
123
130
  """
124
- static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
125
- variables = {"question": question, "staticFunctionArguments": static_function_arguments}
126
- response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
131
+ static_function_arguments = [
132
+ {"name": key, "value": str(value)} for key, value in additional_data.items()
133
+ ]
134
+ variables = {
135
+ "question": question,
136
+ "staticFunctionArguments": static_function_arguments,
137
+ }
138
+ response = requests.post(
139
+ self._graphql_endpoint,
140
+ headers=self._graphql_headers,
141
+ json={"query": query, "variables": variables},
142
+ )
127
143
  response_json = response.json()
128
- if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
129
- self.log(response_json['data']['get_and_instantiate_function'])
130
- resp = response_json['data']['get_and_instantiate_function']
144
+ if (
145
+ response.status_code == 200
146
+ and "data" in response_json
147
+ and "get_and_instantiate_function" in response_json["data"]
148
+ ):
149
+ self.log(response_json["data"]["get_and_instantiate_function"])
150
+ resp = response_json["data"]["get_and_instantiate_function"]
131
151
 
132
152
  print(resp)
133
153
 
134
154
  return resp
135
155
  else:
136
- raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
156
+ raise Exception(
157
+ f"Query failed to run by returning code of {response.status_code}. {response.text}"
158
+ )
137
159
 
138
- def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
160
+ def create_function(
161
+ self, question: str, sql: str, plotly_code: str, **kwargs
162
+ ) -> dict:
139
163
  query = """
140
164
  mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
141
165
  generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
@@ -153,16 +177,27 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
153
177
  }
154
178
  """
155
179
  variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
156
- response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
180
+ response = requests.post(
181
+ self._graphql_endpoint,
182
+ headers=self._graphql_headers,
183
+ json={"query": query, "variables": variables},
184
+ )
157
185
  response_json = response.json()
158
- if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
159
- resp = response_json['data']['generate_and_create_sql_function']
186
+ if (
187
+ response.status_code == 200
188
+ and "data" in response_json
189
+ and response_json["data"] is not None
190
+ and "generate_and_create_sql_function" in response_json["data"]
191
+ ):
192
+ resp = response_json["data"]["generate_and_create_sql_function"]
160
193
 
161
194
  print(resp)
162
195
 
163
196
  return resp
164
197
  else:
165
- raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
198
+ raise Exception(
199
+ f"Query failed to run by returning code of {response.status_code}. {response.text}"
200
+ )
166
201
 
167
202
  def update_function(self, old_function_name: str, updated_function: dict) -> bool:
168
203
  """
@@ -187,41 +222,64 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
187
222
  """
188
223
 
189
224
  SQLFunctionUpdate = {
190
- 'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
225
+ "function_name",
226
+ "description",
227
+ "arguments",
228
+ "sql_template",
229
+ "post_processing_code_template",
191
230
  }
192
231
 
193
232
  # Define the expected keys for each argument in the arguments list
194
- ArgumentKeys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}
233
+ ArgumentKeys = {
234
+ "name",
235
+ "general_type",
236
+ "description",
237
+ "is_user_editable",
238
+ "available_values",
239
+ }
195
240
 
196
241
  # Function to validate and transform arguments
197
242
  def validate_arguments(args):
198
243
  return [
199
- {key: arg[key] for key in arg if key in ArgumentKeys}
200
- for arg in args
244
+ {key: arg[key] for key in arg if key in ArgumentKeys} for arg in args
201
245
  ]
202
246
 
203
247
  # Keep only the keys that conform to the SQLFunctionUpdate GraphQL input type
204
- updated_function = {key: value for key, value in updated_function.items() if key in SQLFunctionUpdate}
248
+ updated_function = {
249
+ key: value
250
+ for key, value in updated_function.items()
251
+ if key in SQLFunctionUpdate
252
+ }
205
253
 
206
254
  # Special handling for 'arguments' to ensure they conform to the spec
207
- if 'arguments' in updated_function:
208
- updated_function['arguments'] = validate_arguments(updated_function['arguments'])
255
+ if "arguments" in updated_function:
256
+ updated_function["arguments"] = validate_arguments(
257
+ updated_function["arguments"]
258
+ )
209
259
 
210
260
  variables = {
211
- "input": {
212
- "old_function_name": old_function_name,
213
- **updated_function
214
- }
261
+ "input": {"old_function_name": old_function_name, **updated_function}
215
262
  }
216
263
 
217
264
  print("variables", variables)
218
265
 
219
- response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
266
+ response = requests.post(
267
+ self._graphql_endpoint,
268
+ headers=self._graphql_headers,
269
+ json={"query": mutation, "variables": variables},
270
+ )
220
271
  response_json = response.json()
221
- if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
222
- return response_json['data']['update_sql_function']
272
+ if (
273
+ response.status_code == 200
274
+ and "data" in response_json
275
+ and response_json["data"] is not None
276
+ and "update_sql_function" in response_json["data"]
277
+ ):
278
+ return response_json["data"]["update_sql_function"]
223
279
  else:
224
- raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
280
+ raise Exception(
281
+ f"Mutation failed to run by returning code of {response.status_code}. {response.text}"
282
+ )
225
283
 
226
284
  def delete_function(self, function_name: str) -> bool:
227
285
  mutation = """
@@ -230,12 +288,23 @@ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
230
288
  }
231
289
  """
232
290
  variables = {"function_name": function_name}
233
- response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
291
+ response = requests.post(
292
+ self._graphql_endpoint,
293
+ headers=self._graphql_headers,
294
+ json={"query": mutation, "variables": variables},
295
+ )
234
296
  response_json = response.json()
235
- if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
236
- return response_json['data']['delete_sql_function']
297
+ if (
298
+ response.status_code == 200
299
+ and "data" in response_json
300
+ and response_json["data"] is not None
301
+ and "delete_sql_function" in response_json["data"]
302
+ ):
303
+ return response_json["data"]["delete_sql_function"]
237
304
  else:
238
- raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
305
+ raise Exception(
306
+ f"Mutation failed to run by returning code of {response.status_code}. {response.text}"
307
+ )
239
308
 
240
309
  def create_model(self, model: str, **kwargs) -> bool:
241
310
  """
@@ -80,13 +80,12 @@ class Vllm(VannaBase):
80
80
  }
81
81
 
82
82
  if self.auth_key is not None:
83
- headers = {
84
- 'Content-Type': 'application/json',
85
- 'Authorization': f'Bearer {self.auth_key}'
83
+ headers = {
84
+ "Content-Type": "application/json",
85
+ "Authorization": f"Bearer {self.auth_key}",
86
86
  }
87
87
 
88
- response = requests.post(url, headers=headers,json=data)
89
-
88
+ response = requests.post(url, headers=headers, json=data)
90
89
 
91
90
  else:
92
91
  response = requests.post(url, json=data)
@@ -95,4 +94,4 @@ class Vllm(VannaBase):
95
94
 
96
95
  self.log(response.text)
97
96
 
98
- return response_dict['choices'][0]['message']['content']
97
+ return response_dict["choices"][0]["message"]["content"]
@@ -6,7 +6,6 @@ from vanna.base import VannaBase
6
6
 
7
7
 
8
8
  class WeaviateDatabase(VannaBase):
9
-
10
9
  def __init__(self, config=None):
11
10
  """
12
11
  Initialize the VannaEnhanced class with the provided configuration.
@@ -42,30 +41,35 @@ class WeaviateDatabase(VannaBase):
42
41
  self.training_data_cluster = {
43
42
  "sql": "SQLTrainingDataEntry",
44
43
  "ddl": "DDLEntry",
45
- "doc": "DocumentationEntry"
44
+ "doc": "DocumentationEntry",
46
45
  }
47
46
 
48
47
  self._create_collections_if_not_exist()
49
48
 
50
49
  def _create_collections_if_not_exist(self):
51
50
  properties_dict = {
52
- self.training_data_cluster['ddl']: [
53
- wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
51
+ self.training_data_cluster["ddl"]: [
52
+ wvc.config.Property(
53
+ name="description", data_type=wvc.config.DataType.TEXT
54
+ ),
54
55
  ],
55
- self.training_data_cluster['doc']: [
56
- wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
56
+ self.training_data_cluster["doc"]: [
57
+ wvc.config.Property(
58
+ name="description", data_type=wvc.config.DataType.TEXT
59
+ ),
57
60
  ],
58
- self.training_data_cluster['sql']: [
61
+ self.training_data_cluster["sql"]: [
59
62
  wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT),
60
- wvc.config.Property(name="natural_language_question", data_type=wvc.config.DataType.TEXT),
61
- ]
63
+ wvc.config.Property(
64
+ name="natural_language_question", data_type=wvc.config.DataType.TEXT
65
+ ),
66
+ ],
62
67
  }
63
68
 
64
69
  for cluster, properties in properties_dict.items():
65
70
  if not self.weaviate_client.collections.exists(cluster):
66
71
  self.weaviate_client.collections.create(
67
- name=cluster,
68
- properties=properties
72
+ name=cluster, properties=properties
69
73
  )
70
74
 
71
75
  def _initialize_weaviate_client(self):
@@ -74,28 +78,26 @@ class WeaviateDatabase(VannaBase):
74
78
  cluster_url=self.weaviate_url,
75
79
  auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key),
76
80
  additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
77
- skip_init_checks=True
81
+ skip_init_checks=True,
78
82
  )
79
83
  else:
80
84
  return weaviate.connect_to_local(
81
85
  port=self.weaviate_port,
82
86
  grpc_port=self.weaviate_grpc_port,
83
87
  additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
84
- skip_init_checks=True
88
+ skip_init_checks=True,
85
89
  )
86
90
 
87
91
  def generate_embedding(self, data: str, **kwargs):
88
- embedding_model = TextEmbedding(model_name=self.fastembed_model)
89
- embedding = next(embedding_model.embed(data))
90
- return embedding.tolist()
91
-
92
+ embedding_model = TextEmbedding(model_name=self.fastembed_model)
93
+ embedding = next(embedding_model.embed(data))
94
+ return embedding.tolist()
92
95
 
93
96
  def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str:
94
97
  self.weaviate_client.connect()
95
- response = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]).data.insert(
96
- properties=data_object,
97
- vector=vector
98
- )
98
+ response = self.weaviate_client.collections.get(
99
+ self.training_data_cluster[cluster_key]
100
+ ).data.insert(properties=data_object, vector=vector)
99
101
  self.weaviate_client.close()
100
102
  return response
101
103
 
@@ -103,31 +105,37 @@ class WeaviateDatabase(VannaBase):
103
105
  data_object = {
104
106
  "description": ddl,
105
107
  }
106
- response = self._insert_data('ddl', data_object, self.generate_embedding(ddl))
107
- return f'{response}-ddl'
108
+ response = self._insert_data("ddl", data_object, self.generate_embedding(ddl))
109
+ return f"{response}-ddl"
108
110
 
109
111
  def add_documentation(self, doc: str, **kwargs) -> str:
110
112
  data_object = {
111
113
  "description": doc,
112
114
  }
113
- response = self._insert_data('doc', data_object, self.generate_embedding(doc))
114
- return f'{response}-doc'
115
+ response = self._insert_data("doc", data_object, self.generate_embedding(doc))
116
+ return f"{response}-doc"
115
117
 
116
118
  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
117
119
  data_object = {
118
120
  "sql": sql,
119
121
  "natural_language_question": question,
120
122
  }
121
- response = self._insert_data('sql', data_object, self.generate_embedding(question))
122
- return f'{response}-sql'
123
+ response = self._insert_data(
124
+ "sql", data_object, self.generate_embedding(question)
125
+ )
126
+ return f"{response}-sql"
123
127
 
124
- def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list) -> list:
128
+ def _query_collection(
129
+ self, cluster_key: str, vector_input: list, return_properties: list
130
+ ) -> list:
125
131
  self.weaviate_client.connect()
126
- collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key])
132
+ collection = self.weaviate_client.collections.get(
133
+ self.training_data_cluster[cluster_key]
134
+ )
127
135
  response = collection.query.near_vector(
128
136
  near_vector=vector_input,
129
137
  limit=self.n_results,
130
- return_properties=return_properties
138
+ return_properties=return_properties,
131
139
  )
132
140
  response_list = [item.properties for item in response.objects]
133
141
  self.weaviate_client.close()
@@ -135,18 +143,23 @@ class WeaviateDatabase(VannaBase):
135
143
 
136
144
  def get_related_ddl(self, question: str, **kwargs) -> list:
137
145
  vector_input = self.generate_embedding(question)
138
- response_list = self._query_collection('ddl', vector_input, ["description"])
146
+ response_list = self._query_collection("ddl", vector_input, ["description"])
139
147
  return [item["description"] for item in response_list]
140
148
 
141
149
  def get_related_documentation(self, question: str, **kwargs) -> list:
142
150
  vector_input = self.generate_embedding(question)
143
- response_list = self._query_collection('doc', vector_input, ["description"])
151
+ response_list = self._query_collection("doc", vector_input, ["description"])
144
152
  return [item["description"] for item in response_list]
145
153
 
146
154
  def get_similar_question_sql(self, question: str, **kwargs) -> list:
147
155
  vector_input = self.generate_embedding(question)
148
- response_list = self._query_collection('sql', vector_input, ["sql", "natural_language_question"])
149
- return [{"question": item["natural_language_question"], "sql": item["sql"]} for item in response_list]
156
+ response_list = self._query_collection(
157
+ "sql", vector_input, ["sql", "natural_language_question"]
158
+ )
159
+ return [
160
+ {"question": item["natural_language_question"], "sql": item["sql"]}
161
+ for item in response_list
162
+ ]
150
163
 
151
164
  def get_training_data(self, **kwargs) -> list:
152
165
  self.weaviate_client.connect()
@@ -163,13 +176,19 @@ class WeaviateDatabase(VannaBase):
163
176
  self.weaviate_client.connect()
164
177
  success = False
165
178
  if id.endswith("-sql"):
166
- id = id.replace('-sql', '')
167
- success = self.weaviate_client.collections.get(self.training_data_cluster['sql']).data.delete_by_id(id)
179
+ id = id.replace("-sql", "")
180
+ success = self.weaviate_client.collections.get(
181
+ self.training_data_cluster["sql"]
182
+ ).data.delete_by_id(id)
168
183
  elif id.endswith("-ddl"):
169
- id = id.replace('-ddl', '')
170
- success = self.weaviate_client.collections.get(self.training_data_cluster['ddl']).data.delete_by_id(id)
184
+ id = id.replace("-ddl", "")
185
+ success = self.weaviate_client.collections.get(
186
+ self.training_data_cluster["ddl"]
187
+ ).data.delete_by_id(id)
171
188
  elif id.endswith("-doc"):
172
- id = id.replace('-doc', '')
173
- success = self.weaviate_client.collections.get(self.training_data_cluster['doc']).data.delete_by_id(id)
189
+ id = id.replace("-doc", "")
190
+ success = self.weaviate_client.collections.get(
191
+ self.training_data_cluster["doc"]
192
+ ).data.delete_by_id(id)
174
193
  self.weaviate_client.close()
175
194
  return success
@@ -1,6 +1,6 @@
1
1
  from xinference_client.client.restful.restful_client import (
2
- Client,
3
- RESTfulChatModelHandle,
2
+ Client,
3
+ RESTfulChatModelHandle,
4
4
  )
5
5
 
6
6
  from ..base import VannaBase
@@ -43,11 +43,11 @@ class Xinference(VannaBase):
43
43
 
44
44
  xinference_model = self.xinference_client.get_model(model_uid)
45
45
  if isinstance(xinference_model, RESTfulChatModelHandle):
46
- print(
47
- f"Using model_uid {model_uid} for {num_tokens} tokens (approx)"
48
- )
46
+ print(f"Using model_uid {model_uid} for {num_tokens} tokens (approx)")
49
47
 
50
48
  response = xinference_model.chat(prompt)
51
49
  return response["choices"][0]["message"]["content"]
52
50
  else:
53
- raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle")
51
+ raise NotImplementedError(
52
+ f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle"
53
+ )
vanna/py.typed ADDED
File without changes
@@ -0,0 +1,16 @@
1
+ """
2
+ Server implementations for the Vanna Agents framework.
3
+
4
+ This module provides Flask and FastAPI server factories for serving
5
+ Vanna agents over HTTP with SSE, WebSocket, and polling endpoints.
6
+ """
7
+
8
+ from .base import ChatHandler, ChatRequest, ChatStreamChunk
9
+ from .cli.server_runner import ExampleAgentLoader
10
+
11
+ __all__ = [
12
+ "ChatHandler",
13
+ "ChatRequest",
14
+ "ChatStreamChunk",
15
+ "ExampleAgentLoader",
16
+ ]
@@ -0,0 +1,8 @@
1
+ """
2
+ Entry point for running Vanna Agents servers.
3
+ """
4
+
5
+ from .cli.server_runner import main
6
+
7
+ if __name__ == "__main__":
8
+ main()
@@ -0,0 +1,18 @@
1
+ """
2
+ Base server components for the Vanna Agents framework.
3
+
4
+ This module provides framework-agnostic components for handling chat
5
+ requests and responses.
6
+ """
7
+
8
+ from .chat_handler import ChatHandler
9
+ from .models import ChatRequest, ChatStreamChunk, ChatResponse
10
+ from .templates import INDEX_HTML
11
+
12
+ __all__ = [
13
+ "ChatHandler",
14
+ "ChatRequest",
15
+ "ChatStreamChunk",
16
+ "ChatResponse",
17
+ "INDEX_HTML",
18
+ ]
@@ -0,0 +1,65 @@
1
+ """
2
+ Framework-agnostic chat handling logic.
3
+ """
4
+
5
+ import uuid
6
+ from typing import AsyncGenerator, List
7
+
8
+ from ...core import Agent
9
+ from .models import ChatRequest, ChatResponse, ChatStreamChunk
10
+
11
+
12
+ class ChatHandler:
13
+ """Core chat handling logic - framework agnostic."""
14
+
15
+ def __init__(
16
+ self,
17
+ agent: Agent,
18
+ ):
19
+ """Initialize chat handler.
20
+
21
+ Args:
22
+ agent: The agent to handle chat requests
23
+ """
24
+ self.agent = agent
25
+
26
+ async def handle_stream(
27
+ self, request: ChatRequest
28
+ ) -> AsyncGenerator[ChatStreamChunk, None]:
29
+ """Stream chat responses.
30
+
31
+ Args:
32
+ request: Chat request
33
+
34
+ Yields:
35
+ Chat stream chunks
36
+ """
37
+ conversation_id = request.conversation_id or self._generate_conversation_id()
38
+ # Use request_id from client for tracking, or use the one generated internally
39
+ request_id = request.request_id or str(uuid.uuid4())
40
+
41
+ async for component in self.agent.send_message(
42
+ request_context=request.request_context,
43
+ message=request.message,
44
+ conversation_id=conversation_id,
45
+ ):
46
+ yield ChatStreamChunk.from_component(component, conversation_id, request_id)
47
+
48
+ async def handle_poll(self, request: ChatRequest) -> ChatResponse:
49
+ """Handle polling-based chat.
50
+
51
+ Args:
52
+ request: Chat request
53
+
54
+ Returns:
55
+ Complete chat response
56
+ """
57
+ chunks = []
58
+ async for chunk in self.handle_stream(request):
59
+ chunks.append(chunk)
60
+
61
+ return ChatResponse.from_chunks(chunks)
62
+
63
+ def _generate_conversation_id(self) -> str:
64
+ """Generate new conversation ID."""
65
+ return f"conv_{uuid.uuid4().hex[:8]}"