vanna 0.7.9__py3-none-any.whl → 2.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. vanna/__init__.py +167 -395
  2. vanna/agents/__init__.py +7 -0
  3. vanna/capabilities/__init__.py +17 -0
  4. vanna/capabilities/agent_memory/__init__.py +21 -0
  5. vanna/capabilities/agent_memory/base.py +103 -0
  6. vanna/capabilities/agent_memory/models.py +53 -0
  7. vanna/capabilities/file_system/__init__.py +14 -0
  8. vanna/capabilities/file_system/base.py +71 -0
  9. vanna/capabilities/file_system/models.py +25 -0
  10. vanna/capabilities/sql_runner/__init__.py +13 -0
  11. vanna/capabilities/sql_runner/base.py +37 -0
  12. vanna/capabilities/sql_runner/models.py +13 -0
  13. vanna/components/__init__.py +92 -0
  14. vanna/components/base.py +11 -0
  15. vanna/components/rich/__init__.py +83 -0
  16. vanna/components/rich/containers/__init__.py +7 -0
  17. vanna/components/rich/containers/card.py +20 -0
  18. vanna/components/rich/data/__init__.py +9 -0
  19. vanna/components/rich/data/chart.py +17 -0
  20. vanna/components/rich/data/dataframe.py +93 -0
  21. vanna/components/rich/feedback/__init__.py +21 -0
  22. vanna/components/rich/feedback/badge.py +16 -0
  23. vanna/components/rich/feedback/icon_text.py +14 -0
  24. vanna/components/rich/feedback/log_viewer.py +41 -0
  25. vanna/components/rich/feedback/notification.py +19 -0
  26. vanna/components/rich/feedback/progress.py +37 -0
  27. vanna/components/rich/feedback/status_card.py +28 -0
  28. vanna/components/rich/feedback/status_indicator.py +14 -0
  29. vanna/components/rich/interactive/__init__.py +21 -0
  30. vanna/components/rich/interactive/button.py +95 -0
  31. vanna/components/rich/interactive/task_list.py +58 -0
  32. vanna/components/rich/interactive/ui_state.py +93 -0
  33. vanna/components/rich/specialized/__init__.py +7 -0
  34. vanna/components/rich/specialized/artifact.py +20 -0
  35. vanna/components/rich/text.py +16 -0
  36. vanna/components/simple/__init__.py +15 -0
  37. vanna/components/simple/image.py +15 -0
  38. vanna/components/simple/link.py +15 -0
  39. vanna/components/simple/text.py +11 -0
  40. vanna/core/__init__.py +193 -0
  41. vanna/core/_compat.py +19 -0
  42. vanna/core/agent/__init__.py +10 -0
  43. vanna/core/agent/agent.py +1407 -0
  44. vanna/core/agent/config.py +123 -0
  45. vanna/core/audit/__init__.py +28 -0
  46. vanna/core/audit/base.py +299 -0
  47. vanna/core/audit/models.py +131 -0
  48. vanna/core/component_manager.py +329 -0
  49. vanna/core/components.py +53 -0
  50. vanna/core/enhancer/__init__.py +11 -0
  51. vanna/core/enhancer/base.py +94 -0
  52. vanna/core/enhancer/default.py +118 -0
  53. vanna/core/enricher/__init__.py +10 -0
  54. vanna/core/enricher/base.py +59 -0
  55. vanna/core/errors.py +47 -0
  56. vanna/core/evaluation/__init__.py +81 -0
  57. vanna/core/evaluation/base.py +186 -0
  58. vanna/core/evaluation/dataset.py +254 -0
  59. vanna/core/evaluation/evaluators.py +376 -0
  60. vanna/core/evaluation/report.py +289 -0
  61. vanna/core/evaluation/runner.py +313 -0
  62. vanna/core/filter/__init__.py +10 -0
  63. vanna/core/filter/base.py +67 -0
  64. vanna/core/lifecycle/__init__.py +10 -0
  65. vanna/core/lifecycle/base.py +83 -0
  66. vanna/core/llm/__init__.py +16 -0
  67. vanna/core/llm/base.py +40 -0
  68. vanna/core/llm/models.py +61 -0
  69. vanna/core/middleware/__init__.py +10 -0
  70. vanna/core/middleware/base.py +69 -0
  71. vanna/core/observability/__init__.py +11 -0
  72. vanna/core/observability/base.py +88 -0
  73. vanna/core/observability/models.py +47 -0
  74. vanna/core/recovery/__init__.py +11 -0
  75. vanna/core/recovery/base.py +84 -0
  76. vanna/core/recovery/models.py +32 -0
  77. vanna/core/registry.py +278 -0
  78. vanna/core/rich_component.py +156 -0
  79. vanna/core/simple_component.py +27 -0
  80. vanna/core/storage/__init__.py +14 -0
  81. vanna/core/storage/base.py +46 -0
  82. vanna/core/storage/models.py +46 -0
  83. vanna/core/system_prompt/__init__.py +13 -0
  84. vanna/core/system_prompt/base.py +36 -0
  85. vanna/core/system_prompt/default.py +157 -0
  86. vanna/core/tool/__init__.py +18 -0
  87. vanna/core/tool/base.py +70 -0
  88. vanna/core/tool/models.py +84 -0
  89. vanna/core/user/__init__.py +17 -0
  90. vanna/core/user/base.py +29 -0
  91. vanna/core/user/models.py +25 -0
  92. vanna/core/user/request_context.py +70 -0
  93. vanna/core/user/resolver.py +42 -0
  94. vanna/core/validation.py +164 -0
  95. vanna/core/workflow/__init__.py +12 -0
  96. vanna/core/workflow/base.py +254 -0
  97. vanna/core/workflow/default.py +789 -0
  98. vanna/examples/__init__.py +1 -0
  99. vanna/examples/__main__.py +44 -0
  100. vanna/examples/anthropic_quickstart.py +80 -0
  101. vanna/examples/artifact_example.py +293 -0
  102. vanna/examples/claude_sqlite_example.py +236 -0
  103. vanna/examples/coding_agent_example.py +300 -0
  104. vanna/examples/custom_system_prompt_example.py +174 -0
  105. vanna/examples/default_workflow_handler_example.py +208 -0
  106. vanna/examples/email_auth_example.py +340 -0
  107. vanna/examples/evaluation_example.py +269 -0
  108. vanna/examples/extensibility_example.py +262 -0
  109. vanna/examples/minimal_example.py +67 -0
  110. vanna/examples/mock_auth_example.py +227 -0
  111. vanna/examples/mock_custom_tool.py +311 -0
  112. vanna/examples/mock_quickstart.py +79 -0
  113. vanna/examples/mock_quota_example.py +145 -0
  114. vanna/examples/mock_rich_components_demo.py +396 -0
  115. vanna/examples/mock_sqlite_example.py +223 -0
  116. vanna/examples/openai_quickstart.py +83 -0
  117. vanna/examples/primitive_components_demo.py +305 -0
  118. vanna/examples/quota_lifecycle_example.py +139 -0
  119. vanna/examples/visualization_example.py +251 -0
  120. vanna/integrations/__init__.py +17 -0
  121. vanna/integrations/anthropic/__init__.py +9 -0
  122. vanna/integrations/anthropic/llm.py +270 -0
  123. vanna/integrations/azureopenai/__init__.py +9 -0
  124. vanna/integrations/azureopenai/llm.py +329 -0
  125. vanna/integrations/azuresearch/__init__.py +7 -0
  126. vanna/integrations/azuresearch/agent_memory.py +413 -0
  127. vanna/integrations/bigquery/__init__.py +5 -0
  128. vanna/integrations/bigquery/sql_runner.py +81 -0
  129. vanna/integrations/chromadb/__init__.py +104 -0
  130. vanna/integrations/chromadb/agent_memory.py +416 -0
  131. vanna/integrations/clickhouse/__init__.py +5 -0
  132. vanna/integrations/clickhouse/sql_runner.py +82 -0
  133. vanna/integrations/duckdb/__init__.py +5 -0
  134. vanna/integrations/duckdb/sql_runner.py +65 -0
  135. vanna/integrations/faiss/__init__.py +7 -0
  136. vanna/integrations/faiss/agent_memory.py +431 -0
  137. vanna/integrations/google/__init__.py +9 -0
  138. vanna/integrations/google/gemini.py +370 -0
  139. vanna/integrations/hive/__init__.py +5 -0
  140. vanna/integrations/hive/sql_runner.py +87 -0
  141. vanna/integrations/local/__init__.py +17 -0
  142. vanna/integrations/local/agent_memory/__init__.py +7 -0
  143. vanna/integrations/local/agent_memory/in_memory.py +285 -0
  144. vanna/integrations/local/audit.py +59 -0
  145. vanna/integrations/local/file_system.py +242 -0
  146. vanna/integrations/local/file_system_conversation_store.py +255 -0
  147. vanna/integrations/local/storage.py +62 -0
  148. vanna/integrations/marqo/__init__.py +7 -0
  149. vanna/integrations/marqo/agent_memory.py +354 -0
  150. vanna/integrations/milvus/__init__.py +7 -0
  151. vanna/integrations/milvus/agent_memory.py +458 -0
  152. vanna/integrations/mock/__init__.py +9 -0
  153. vanna/integrations/mock/llm.py +65 -0
  154. vanna/integrations/mssql/__init__.py +5 -0
  155. vanna/integrations/mssql/sql_runner.py +66 -0
  156. vanna/integrations/mysql/__init__.py +5 -0
  157. vanna/integrations/mysql/sql_runner.py +92 -0
  158. vanna/integrations/ollama/__init__.py +7 -0
  159. vanna/integrations/ollama/llm.py +252 -0
  160. vanna/integrations/openai/__init__.py +10 -0
  161. vanna/integrations/openai/llm.py +267 -0
  162. vanna/integrations/openai/responses.py +163 -0
  163. vanna/integrations/opensearch/__init__.py +7 -0
  164. vanna/integrations/opensearch/agent_memory.py +411 -0
  165. vanna/integrations/oracle/__init__.py +5 -0
  166. vanna/integrations/oracle/sql_runner.py +75 -0
  167. vanna/integrations/pinecone/__init__.py +7 -0
  168. vanna/integrations/pinecone/agent_memory.py +329 -0
  169. vanna/integrations/plotly/__init__.py +5 -0
  170. vanna/integrations/plotly/chart_generator.py +313 -0
  171. vanna/integrations/postgres/__init__.py +9 -0
  172. vanna/integrations/postgres/sql_runner.py +112 -0
  173. vanna/integrations/premium/agent_memory/__init__.py +7 -0
  174. vanna/integrations/premium/agent_memory/premium.py +186 -0
  175. vanna/integrations/presto/__init__.py +5 -0
  176. vanna/integrations/presto/sql_runner.py +107 -0
  177. vanna/integrations/qdrant/__init__.py +7 -0
  178. vanna/integrations/qdrant/agent_memory.py +439 -0
  179. vanna/integrations/snowflake/__init__.py +5 -0
  180. vanna/integrations/snowflake/sql_runner.py +147 -0
  181. vanna/integrations/sqlite/__init__.py +9 -0
  182. vanna/integrations/sqlite/sql_runner.py +65 -0
  183. vanna/integrations/weaviate/__init__.py +7 -0
  184. vanna/integrations/weaviate/agent_memory.py +428 -0
  185. vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_embeddings.py +11 -11
  186. vanna/legacy/__init__.py +403 -0
  187. vanna/legacy/adapter.py +463 -0
  188. vanna/{advanced → legacy/advanced}/__init__.py +3 -1
  189. vanna/{anthropic → legacy/anthropic}/anthropic_chat.py +9 -7
  190. vanna/{azuresearch → legacy/azuresearch}/azuresearch_vector.py +79 -41
  191. vanna/{base → legacy/base}/base.py +224 -217
  192. vanna/legacy/bedrock/__init__.py +1 -0
  193. vanna/{bedrock → legacy/bedrock}/bedrock_converse.py +13 -12
  194. vanna/{chromadb → legacy/chromadb}/chromadb_vector.py +3 -1
  195. vanna/legacy/cohere/__init__.py +2 -0
  196. vanna/{cohere → legacy/cohere}/cohere_chat.py +19 -14
  197. vanna/{cohere → legacy/cohere}/cohere_embeddings.py +25 -19
  198. vanna/{deepseek → legacy/deepseek}/deepseek_chat.py +5 -6
  199. vanna/legacy/faiss/__init__.py +1 -0
  200. vanna/{faiss → legacy/faiss}/faiss.py +113 -59
  201. vanna/{flask → legacy/flask}/__init__.py +84 -43
  202. vanna/{flask → legacy/flask}/assets.py +5 -5
  203. vanna/{flask → legacy/flask}/auth.py +5 -4
  204. vanna/{google → legacy/google}/bigquery_vector.py +75 -42
  205. vanna/{google → legacy/google}/gemini_chat.py +7 -3
  206. vanna/{hf → legacy/hf}/hf.py +0 -1
  207. vanna/{milvus → legacy/milvus}/milvus_vector.py +58 -35
  208. vanna/{mock → legacy/mock}/llm.py +0 -1
  209. vanna/legacy/mock/vectordb.py +67 -0
  210. vanna/legacy/ollama/ollama.py +110 -0
  211. vanna/{openai → legacy/openai}/openai_chat.py +2 -6
  212. vanna/legacy/opensearch/opensearch_vector.py +369 -0
  213. vanna/legacy/opensearch/opensearch_vector_semantic.py +200 -0
  214. vanna/legacy/oracle/oracle_vector.py +584 -0
  215. vanna/{pgvector → legacy/pgvector}/pgvector.py +42 -13
  216. vanna/{qdrant → legacy/qdrant}/qdrant.py +2 -6
  217. vanna/legacy/qianfan/Qianfan_Chat.py +170 -0
  218. vanna/legacy/qianfan/Qianfan_embeddings.py +36 -0
  219. vanna/legacy/qianwen/QianwenAI_chat.py +132 -0
  220. vanna/{remote.py → legacy/remote.py} +28 -26
  221. vanna/{utils.py → legacy/utils.py} +6 -11
  222. vanna/{vannadb → legacy/vannadb}/vannadb_vector.py +115 -46
  223. vanna/{vllm → legacy/vllm}/vllm.py +5 -6
  224. vanna/{weaviate → legacy/weaviate}/weaviate_vector.py +59 -40
  225. vanna/{xinference → legacy/xinference}/xinference.py +6 -6
  226. vanna/py.typed +0 -0
  227. vanna/servers/__init__.py +16 -0
  228. vanna/servers/__main__.py +8 -0
  229. vanna/servers/base/__init__.py +18 -0
  230. vanna/servers/base/chat_handler.py +65 -0
  231. vanna/servers/base/models.py +111 -0
  232. vanna/servers/base/rich_chat_handler.py +141 -0
  233. vanna/servers/base/templates.py +331 -0
  234. vanna/servers/cli/__init__.py +7 -0
  235. vanna/servers/cli/server_runner.py +204 -0
  236. vanna/servers/fastapi/__init__.py +7 -0
  237. vanna/servers/fastapi/app.py +163 -0
  238. vanna/servers/fastapi/routes.py +183 -0
  239. vanna/servers/flask/__init__.py +7 -0
  240. vanna/servers/flask/app.py +132 -0
  241. vanna/servers/flask/routes.py +137 -0
  242. vanna/tools/__init__.py +41 -0
  243. vanna/tools/agent_memory.py +322 -0
  244. vanna/tools/file_system.py +879 -0
  245. vanna/tools/python.py +222 -0
  246. vanna/tools/run_sql.py +165 -0
  247. vanna/tools/visualize_data.py +195 -0
  248. vanna/utils/__init__.py +0 -0
  249. vanna/web_components/__init__.py +44 -0
  250. vanna-2.0.0rc1.dist-info/METADATA +868 -0
  251. vanna-2.0.0rc1.dist-info/RECORD +289 -0
  252. vanna-2.0.0rc1.dist-info/entry_points.txt +3 -0
  253. vanna/bedrock/__init__.py +0 -1
  254. vanna/cohere/__init__.py +0 -2
  255. vanna/faiss/__init__.py +0 -1
  256. vanna/mock/vectordb.py +0 -55
  257. vanna/ollama/ollama.py +0 -103
  258. vanna/opensearch/opensearch_vector.py +0 -392
  259. vanna/opensearch/opensearch_vector_semantic.py +0 -175
  260. vanna/oracle/oracle_vector.py +0 -585
  261. vanna/qianfan/Qianfan_Chat.py +0 -165
  262. vanna/qianfan/Qianfan_embeddings.py +0 -36
  263. vanna/qianwen/QianwenAI_chat.py +0 -133
  264. vanna-0.7.9.dist-info/METADATA +0 -408
  265. vanna-0.7.9.dist-info/RECORD +0 -79
  266. /vanna/{ZhipuAI → legacy/ZhipuAI}/ZhipuAI_Chat.py +0 -0
  267. /vanna/{ZhipuAI → legacy/ZhipuAI}/__init__.py +0 -0
  268. /vanna/{anthropic → legacy/anthropic}/__init__.py +0 -0
  269. /vanna/{azuresearch → legacy/azuresearch}/__init__.py +0 -0
  270. /vanna/{base → legacy/base}/__init__.py +0 -0
  271. /vanna/{chromadb → legacy/chromadb}/__init__.py +0 -0
  272. /vanna/{deepseek → legacy/deepseek}/__init__.py +0 -0
  273. /vanna/{exceptions → legacy/exceptions}/__init__.py +0 -0
  274. /vanna/{google → legacy/google}/__init__.py +0 -0
  275. /vanna/{hf → legacy/hf}/__init__.py +0 -0
  276. /vanna/{local.py → legacy/local.py} +0 -0
  277. /vanna/{marqo → legacy/marqo}/__init__.py +0 -0
  278. /vanna/{marqo → legacy/marqo}/marqo.py +0 -0
  279. /vanna/{milvus → legacy/milvus}/__init__.py +0 -0
  280. /vanna/{mistral → legacy/mistral}/__init__.py +0 -0
  281. /vanna/{mistral → legacy/mistral}/mistral.py +0 -0
  282. /vanna/{mock → legacy/mock}/__init__.py +0 -0
  283. /vanna/{mock → legacy/mock}/embedding.py +0 -0
  284. /vanna/{ollama → legacy/ollama}/__init__.py +0 -0
  285. /vanna/{openai → legacy/openai}/__init__.py +0 -0
  286. /vanna/{openai → legacy/openai}/openai_embeddings.py +0 -0
  287. /vanna/{opensearch → legacy/opensearch}/__init__.py +0 -0
  288. /vanna/{oracle → legacy/oracle}/__init__.py +0 -0
  289. /vanna/{pgvector → legacy/pgvector}/__init__.py +0 -0
  290. /vanna/{pinecone → legacy/pinecone}/__init__.py +0 -0
  291. /vanna/{pinecone → legacy/pinecone}/pinecone_vector.py +0 -0
  292. /vanna/{qdrant → legacy/qdrant}/__init__.py +0 -0
  293. /vanna/{qianfan → legacy/qianfan}/__init__.py +0 -0
  294. /vanna/{qianwen → legacy/qianwen}/QianwenAI_embeddings.py +0 -0
  295. /vanna/{qianwen → legacy/qianwen}/__init__.py +0 -0
  296. /vanna/{types → legacy/types}/__init__.py +0 -0
  297. /vanna/{vannadb → legacy/vannadb}/__init__.py +0 -0
  298. /vanna/{vllm → legacy/vllm}/__init__.py +0 -0
  299. /vanna/{weaviate → legacy/weaviate}/__init__.py +0 -0
  300. /vanna/{xinference → legacy/xinference}/__init__.py +0 -0
  301. {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/WHEEL +0 -0
  302. {vanna-0.7.9.dist-info → vanna-2.0.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,584 @@
1
+ import json
2
+ import uuid
3
+ from typing import List, Optional, Tuple
4
+
5
+ import oracledb
6
+ import pandas as pd
7
+ from chromadb.utils import embedding_functions
8
+
9
+ from ..base import VannaBase
10
+
11
+ default_ef = embedding_functions.DefaultEmbeddingFunction()  # fallback used when config has no "embedding_function"; built once at import time
12
+
13
+
14
class Oracle_VectorStore(VannaBase):
    """Vanna vector store that keeps training data in an Oracle database.

    Two tables are used (created on demand by
    ``create_tables_if_not_exists``):

    * ``oracle_collection`` - one row per logical collection
      ("documentation", "ddl" and "sql").
    * ``oracle_embedding`` - one row per training item, holding the JSON
      document, its embedding vector and the owning collection's uuid.

    Recognised ``config`` keys: ``dsn``, ``embedding_function``,
    ``pre_delete_collection``, ``cmetadata``, ``n_results``,
    ``n_results_sql``, ``n_results_ddl`` and ``n_results_documentation``.
    """

    # Distance metrics that may be interpolated into the similarity query.
    _DISTANCE_METRICS = ("COSINE", "DOT")

    def __init__(self, config=None):
        """Connect to Oracle and make sure tables and collections exist.

        Args:
            config: Optional settings dict (see the class docstring). ``dsn``
                is handed to ``oracledb.connect`` unchanged.
        """
        VannaBase.__init__(self, config=config)

        # Normalise once so every lookup below is safe when config is None
        # (previously ``config.get("dsn")`` raised AttributeError).
        settings = {} if config is None else config

        self.embedding_function = settings.get("embedding_function", default_ef)
        self.pre_delete_collection = settings.get("pre_delete_collection", False)
        self.cmetadata = settings.get("cmetadata", {"created_by": "oracle"})

        self.oracle_conn = oracledb.connect(dsn=settings.get("dsn"))
        self.oracle_conn.call_timeout = 30000

        self.documentation_collection = "documentation"
        self.ddl_collection = "ddl"
        self.sql_collection = "sql"

        # Per-collection result limits fall back to the shared n_results.
        self.n_results = settings.get("n_results", 10)
        self.n_results_ddl = settings.get("n_results_ddl", self.n_results)
        self.n_results_sql = settings.get("n_results_sql", self.n_results)
        self.n_results_documentation = settings.get(
            "n_results_documentation", self.n_results
        )

        self.create_tables_if_not_exists()
        for collection_name in (
            self.documentation_collection,
            self.ddl_collection,
            self.sql_collection,
        ):
            self.create_collections_if_not_exists(collection_name)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed ``data`` and return the vector as a list of floats.

        The embedding function is called with a one-element batch. When it
        returns exactly one row that row is used; otherwise the result itself
        is treated as the vector. Rows may be numpy arrays or plain
        sequences of numbers.
        """
        embeddings = self.embedding_function([data])
        vector = embeddings[0] if len(embeddings) == 1 else embeddings
        return [float(component) for component in vector]

    def _insert_embedding(
        self,
        collection_name: str,
        document: dict,
        embedding_source: str,
        id_suffix: str,
    ) -> str:
        """Insert one training item into ``oracle_embedding`` and commit.

        Args:
            collection_name: Name of the collection the item belongs to.
            document: JSON-serialisable payload stored in ``document``.
            embedding_source: Text whose embedding is stored with the item.
            id_suffix: Suffix appended to the uuid to form ``custom_id``.

        Returns:
            The uuid generated for the new row.
        """
        collection = self.get_collection(collection_name)
        item_id = str(uuid.uuid4())
        embedding = self.generate_embedding(embedding_source)

        with self.oracle_conn.cursor() as cursor:
            # Only the second bind (the embedding) needs an explicit type.
            cursor.setinputsizes(None, oracledb.DB_TYPE_VECTOR)
            cursor.execute(
                """
                INSERT INTO oracle_embedding (
                    collection_id,
                    embedding,
                    document,
                    cmetadata,
                    custom_id,
                    uuid
                ) VALUES (
                    :1,
                    TO_VECTOR(:2),
                    :3,
                    :4,
                    :5,
                    :6
                )
                """,
                [
                    collection["uuid"],
                    embedding,
                    json.dumps(document, ensure_ascii=False),
                    json.dumps(self.cmetadata),
                    item_id + id_suffix,
                    item_id,
                ],
            )

        self.oracle_conn.commit()
        return item_id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair, embedded by the question text.

        Returns:
            The uuid of the stored item.
        """
        return self._insert_embedding(
            self.sql_collection,
            {"question": question, "sql": sql},
            question,
            "-sql",
        )

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement, embedded by its own text.

        Returns:
            The uuid of the stored item.
        """
        return self._insert_embedding(
            self.ddl_collection,
            {"question": None, "ddl": ddl},
            ddl,
            "-ddl",
        )

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet, embedded by its own text.

        Returns:
            The uuid of the stored item.
        """
        return self._insert_embedding(
            self.documentation_collection,
            {"question": None, "documentation": documentation},
            documentation,
            "-doc",
        )

    def _collection_frame(
        self,
        cursor,
        collection_name: str,
        content_key: str,
        training_data_type: str,
        with_question: bool = False,
    ) -> pd.DataFrame:
        """Load every item of one collection into a training-data frame.

        Args:
            cursor: Open cursor used to run the query.
            collection_name: Collection to read.
            content_key: Key of the stored document copied to ``content``.
            training_data_type: Value written to ``training_data_type``.
            with_question: Copy the document's ``question`` when True;
                otherwise the ``question`` column is filled with None.

        Returns:
            Frame with columns ``id``, ``question``, ``content`` and
            ``training_data_type``.
        """
        collection = self.get_collection(collection_name)
        cursor.execute(
            """
            SELECT
                document,
                uuid
            FROM oracle_embedding
            WHERE
                collection_id = :1
            """,
            [collection["uuid"]],
        )
        rows = cursor.fetchall()
        # Documents may come back as JSON text or as already-decoded objects.
        documents = self._extract_documents(rows)

        frame = pd.DataFrame(
            {
                "id": [row[1] for row in rows],
                "question": [
                    doc["question"] if with_question else None for doc in documents
                ],
                "content": [doc[content_key] for doc in documents],
            }
        )
        frame["training_data_type"] = training_data_type
        return frame

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training data as a single DataFrame.

        Rows appear in the order sql, ddl, documentation; each per-collection
        frame keeps its own 0-based index.
        """
        with self.oracle_conn.cursor() as cursor:
            frames = [
                self._collection_frame(
                    cursor, self.sql_collection, "sql", "sql", with_question=True
                ),
                self._collection_frame(cursor, self.ddl_collection, "ddl", "ddl"),
                self._collection_frame(
                    cursor,
                    self.documentation_collection,
                    "documentation",
                    "documentation",
                ),
            ]
        return pd.concat(frames)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training item whose uuid equals ``id``.

        Returns:
            True when a row was deleted, False when no row matched.
        """
        with self.oracle_conn.cursor() as cursor:
            cursor.execute(
                """
                DELETE
                FROM
                    oracle_embedding
                WHERE
                    uuid = :1
                """,
                [id],
            )
            deleted = cursor.rowcount > 0

        self.oracle_conn.commit()
        return deleted

    def update_training_data(
        self, id: str, train_type: str, question: str, **kwargs
    ) -> bool:
        """Replace the content (and embedding) of an existing training item.

        Args:
            id: uuid of the item to update.
            train_type: "sql", "ddl" or "documentation"; any other value is
                handled like "sql".
            question: Question text; stored only for sql-type items.
            **kwargs: Must contain ``content``, the new SQL/DDL/documentation.

        Returns:
            True when a row was updated, False when no row matched.

        Raises:
            KeyError: If ``content`` is not supplied.
        """
        update_content = kwargs["content"]

        if train_type in ("ddl", "documentation"):
            patch = {"question": None, train_type: update_content}
        else:
            patch = {"question": question, "sql": update_content}
        update_json = json.dumps(patch)

        embedding = self.generate_embedding(update_content)
        with self.oracle_conn.cursor() as cursor:
            cursor.setinputsizes(oracledb.DB_TYPE_VECTOR, oracledb.DB_TYPE_JSON)
            cursor.execute(
                """
                UPDATE
                    oracle_embedding
                SET
                    embedding = TO_VECTOR(:1),
                    document = JSON_MERGEPATCH(document, :2)
                WHERE
                    uuid = :3
                """,
                [embedding, update_json, id],
            )
            updated = cursor.rowcount > 0

        self.oracle_conn.commit()
        return updated

    @staticmethod
    def _extract_documents(query_results) -> list:
        """Return the first column of each row, decoding JSON text.

        Args:
            query_results: Sequence of result rows, or None.

        Returns:
            One entry per row: ``json.loads`` of the first column when it is
            a string, otherwise the column value as-is. Empty list when
            ``query_results`` is None or empty.
        """
        if query_results is None or len(query_results) == 0:
            return []

        return [
            json.loads(row[0]) if isinstance(row[0], str) else row[0]
            for row in query_results
        ]

    def _search_documents(
        self,
        collection_name: str,
        question: str,
        limit: int,
        metric: str = "COSINE",
    ) -> list:
        """Return up to ``limit`` documents of a collection nearest to ``question``.

        Args:
            collection_name: Collection to search.
            question: Text that is embedded and compared to stored vectors.
            limit: Maximum number of rows fetched.
            metric: VECTOR_DISTANCE metric; must be in ``_DISTANCE_METRICS``.

        Raises:
            ValueError: If ``metric`` is not an allowed metric.
        """
        # The metric is interpolated into the SQL text, so restrict it to
        # known keywords rather than trusting arbitrary input.
        if metric not in self._DISTANCE_METRICS:
            raise ValueError(f"Unsupported distance metric: {metric!r}")

        collection = self.get_collection(collection_name)
        embedding = self.generate_embedding(question)

        with self.oracle_conn.cursor() as cursor:
            # Only the embedding bind is typed; the row limit is a plain number.
            cursor.setinputsizes(None, oracledb.DB_TYPE_VECTOR)
            cursor.execute(
                f"""
                SELECT document
                FROM oracle_embedding
                WHERE collection_id = :1
                ORDER BY VECTOR_DISTANCE(embedding, TO_VECTOR(:2), {metric})
                FETCH FIRST :3 ROWS ONLY
                """,
                [collection["uuid"], embedding, limit],
            )
            rows = cursor.fetchall()

        return self._extract_documents(rows)

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return the stored question/SQL documents closest to ``question``.

        At most ``n_results_sql`` documents are returned, ordered by cosine
        distance.
        """
        return self._search_documents(
            self.sql_collection, question, self.n_results_sql
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the stored DDL documents closest to ``question``.

        At most ``n_results_ddl`` documents are returned, ordered by cosine
        distance.
        """
        return self._search_documents(
            self.ddl_collection, question, self.n_results_ddl
        )

    def search_tables_metadata(
        self,
        engine: Optional[str] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        table_name: Optional[str] = None,
        ddl: Optional[str] = None,
        size: int = 10,
        **kwargs,
    ) -> list:
        """Not implemented for this store; currently returns None."""
        return None  # type: ignore

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the stored documentation closest to ``question``.

        At most ``n_results_documentation`` documents are returned. Unlike
        the SQL and DDL lookups this one orders by the DOT metric.
        """
        return self._search_documents(
            self.documentation_collection,
            question,
            self.n_results_documentation,
            metric="DOT",
        )

    def create_tables_if_not_exists(self) -> None:
        """Create the collection and embedding tables when they are missing."""
        with self.oracle_conn.cursor() as cursor:
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS oracle_collection (
                    name      VARCHAR2(200) NOT NULL,
                    cmetadata json NOT NULL,
                    uuid      VARCHAR2(200) NOT NULL,
                    CONSTRAINT oc_key_uuid PRIMARY KEY ( uuid )
                )
                """
            )
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS oracle_embedding (
                    collection_id VARCHAR2(200) NOT NULL,
                    embedding     vector NOT NULL,
                    document      json NOT NULL,
                    cmetadata     json NOT NULL,
                    custom_id     VARCHAR2(200) NOT NULL,
                    uuid          VARCHAR2(200) NOT NULL,
                    CONSTRAINT oe_key_uuid PRIMARY KEY ( uuid )
                )
                """
            )

        self.oracle_conn.commit()

    def create_collections_if_not_exists(
        self,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple[dict, bool]:
        """Get or create the collection called ``name``.

        When ``pre_delete_collection`` is set, any existing collection of
        that name (and its embeddings) is deleted first.

        Args:
            name: Collection name.
            cmetadata: Metadata stored with a newly created collection;
                defaults to the store-wide ``cmetadata``.

        Returns:
            ``(collection, created)`` where ``created`` is True only when a
            new row was inserted. For a new collection the ``cmetadata``
            entry is the JSON text that was stored.
        """
        if self.pre_delete_collection:
            self.delete_collection(name)

        existing = self.get_collection(name)
        if existing:
            return existing, False

        metadata_json = json.dumps(self.cmetadata if cmetadata is None else cmetadata)
        collection_id = str(uuid.uuid4())

        with self.oracle_conn.cursor() as cursor:
            cursor.execute(
                """
                INSERT INTO oracle_collection(name, cmetadata, uuid)
                VALUES (:1, :2, :3)
                """,
                [name, metadata_json, collection_id],
            )

        self.oracle_conn.commit()

        collection = {"name": name, "cmetadata": metadata_json, "uuid": collection_id}
        return collection, True

    def get_collection(self, name) -> Optional[dict]:
        """Return the collection called ``name``, or None when it is absent."""
        return self.get_by_name(name)

    def get_by_name(self, name: str) -> Optional[dict]:
        """Look up a collection row by name.

        Returns:
            Dict with ``name``, ``cmetadata`` and ``uuid``, or None when no
            collection has that name.
        """
        with self.oracle_conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    name,
                    cmetadata,
                    uuid
                FROM
                    oracle_collection
                WHERE
                    name = :1
                FETCH FIRST 1 ROWS ONLY
                """,
                [name],
            )
            row = cursor.fetchone()

        if row is None:
            return None
        return {"name": row[0], "cmetadata": row[1], "uuid": row[2]}

    def delete_collection(self, name) -> None:
        """Delete the collection called ``name`` together with its embeddings.

        Does nothing when no such collection exists.
        """
        if not self.get_collection(name):
            return

        with self.oracle_conn.cursor() as cursor:
            # Remove the embeddings first, while the collection row that
            # the subquery resolves still exists.
            cursor.execute(
                """
                DELETE
                FROM
                    oracle_embedding
                WHERE
                    collection_id = ( SELECT uuid FROM oracle_collection WHERE name = :1 )
                """,
                [name],
            )
            cursor.execute(
                """
                DELETE
                FROM
                    oracle_collection
                WHERE
                    name = :1
                """,
                [name],
            )

        self.oracle_conn.commit()