vanna 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/base/base.py CHANGED
@@ -992,6 +992,7 @@ class VannaBase(ABC):
992
992
  def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
993
993
  if conn:
994
994
  try:
995
+ conn.ping(reconnect=True)
995
996
  cs = conn.cursor()
996
997
  cs.execute(sql)
997
998
  results = cs.fetchall()
@@ -0,0 +1 @@
1
+ from .milvus_vector import Milvus_VectorStore
@@ -0,0 +1,305 @@
1
+ import uuid
2
+ from typing import List
3
+
4
+ import pandas as pd
5
+ from pymilvus import DataType, MilvusClient, model
6
+
7
+ from ..base import VannaBase
8
+
9
+ # Setting the URI as a local file, e.g.`./milvus.db`,
10
+ # is the most convenient method, as it automatically utilizes Milvus Lite
11
+ # to store all data in this file.
12
+ #
13
+ # If you have a large amount of data (e.g. more than a million docs), we
14
+ # recommend setting up a more performant Milvus server on docker or kubernetes.
15
+ # When using this setup, please use the server URI,
16
+ # e.g.`http://localhost:19530`, as your URI.
17
+
18
+ DEFAULT_MILVUS_URI = "./milvus.db"
19
+ # DEFAULT_MILVUS_URI = "http://localhost:19530"
20
+
21
+ MAX_LIMIT_SIZE = 10_000
22
+
23
+
24
class Milvus_VectorStore(VannaBase):
    """
    Vectorstore implementation using Milvus - https://milvus.io/docs/quickstart.md

    Training data is kept in three collections: question/SQL pairs, DDL
    statements and documentation. Every stored entry gets a UUID-based id
    whose suffix (`-sql`, `-ddl`, `-doc`) identifies its collection.

    Args:
        - config (dict, optional): Dictionary of `Milvus_VectorStore config` options. Defaults to `None`.
            - milvus_client: A `pymilvus.MilvusClient` instance. When absent, a
                client is created on `DEFAULT_MILVUS_URI`.
            - embedding_function:
                A `milvus_model.base.BaseEmbeddingFunction` instance. Defaults to `DefaultEmbeddingFunction()`.
                For more models, please refer to:
                https://milvus.io/docs/embeddings.md
            - n_results: Maximum number of hits returned by the similarity
                searches. Defaults to 10.
    """

    # Collection names used for the three kinds of training data.
    _SQL_COLLECTION = "vannasql"
    _DDL_COLLECTION = "vannaddl"
    _DOC_COLLECTION = "vannadoc"

    # Id suffix -> collection; used to route deletions to the right collection.
    _COLLECTION_BY_ID_SUFFIX = {
        "-sql": _SQL_COLLECTION,
        "-ddl": _DDL_COLLECTION,
        "-doc": _DOC_COLLECTION,
    }

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        # `config` defaults to None; fall back to an empty dict so the lookups
        # below do not raise TypeError when no config is supplied.
        if config is None:
            config = {}

        if "milvus_client" in config:
            self.milvus_client = config["milvus_client"]
        else:
            self.milvus_client = MilvusClient(uri=DEFAULT_MILVUS_URI)

        if "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            self.embedding_function = model.DefaultEmbeddingFunction()

        # Probe the embedding function once to learn the vector dimension that
        # the collection schemas need.
        self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
        self._create_collections()
        self.n_results = config.get("n_results", 10)

    def _create_collections(self):
        """Create the three training-data collections if they do not exist yet."""
        self._create_sql_collection(self._SQL_COLLECTION)
        self._create_ddl_collection(self._DDL_COLLECTION)
        self._create_doc_collection(self._DOC_COLLECTION)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """
        Embed a single piece of text.

        The text is wrapped in a list and the first result is taken, matching
        how every other method of this class calls `encode_documents`.

        Args:
            data: The text to embed.

        Returns:
            The embedding converted with `.tolist()`.
            NOTE(review): assumes each embedding is a numpy-like array (the
            constructor already relies on `.shape`) - confirm for custom
            embedding functions.
        """
        embedding = self.embedding_function.encode_documents([data])[0]
        return embedding.tolist()

    def _create_collection(self, name: str, text_fields: List[str]):
        """
        Create collection `name` unless it already exists.

        The schema has a VARCHAR primary key `id`, one VARCHAR field per entry
        of `text_fields`, and a FLOAT_VECTOR field `vector` indexed with
        AUTOINDEX using the L2 metric.
        """
        if self.milvus_client.has_collection(collection_name=name):
            return

        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
        for field_name in text_fields:
            schema.add_field(field_name=field_name, datatype=DataType.VARCHAR, max_length=65535)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)

        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector",
            index_type="AUTOINDEX",
            metric_type="L2",
        )
        self.milvus_client.create_collection(
            collection_name=name,
            schema=schema,
            index_params=index_params,
            consistency_level="Strong"
        )

    def _create_sql_collection(self, name: str):
        """Create the question/SQL collection (fields: id, text, sql, vector)."""
        self._create_collection(name, ["text", "sql"])

    def _create_ddl_collection(self, name: str):
        """Create the DDL collection (fields: id, ddl, vector)."""
        self._create_collection(name, ["ddl"])

    def _create_doc_collection(self, name: str):
        """Create the documentation collection (fields: id, doc, vector)."""
        self._create_collection(name, ["doc"])

    def _insert_entry(self, collection_name: str, id_suffix: str, text_to_embed: str, fields: dict) -> str:
        """
        Embed `text_to_embed`, insert one row into `collection_name`, return its id.

        The id is a fresh UUID4 followed by `id_suffix`; `fields` holds the
        non-vector columns of the row.
        """
        _id = str(uuid.uuid4()) + id_suffix
        embedding = self.embedding_function.encode_documents([text_to_embed])[0]
        self.milvus_client.insert(
            collection_name=collection_name,
            data={"id": _id, **fields, "vector": embedding},
        )
        return _id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Store a question/SQL pair; the question text is what gets embedded.

        Returns:
            The id of the new entry (ends with `-sql`).

        Raises:
            Exception: If `question` or `sql` is empty.
        """
        if not question or not sql:
            raise Exception("pair of question and sql can not be null")
        return self._insert_entry(
            self._SQL_COLLECTION, "-sql", question, {"text": question, "sql": sql}
        )

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """
        Store a DDL statement.

        Returns:
            The id of the new entry (ends with `-ddl`).

        Raises:
            Exception: If `ddl` is empty.
        """
        if not ddl:
            raise Exception("ddl can not be null")
        return self._insert_entry(self._DDL_COLLECTION, "-ddl", ddl, {"ddl": ddl})

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """
        Store a piece of documentation.

        Returns:
            The id of the new entry (ends with `-doc`).

        Raises:
            Exception: If `documentation` is empty.
        """
        if not documentation:
            raise Exception("documentation can not be null")
        return self._insert_entry(
            self._DOC_COLLECTION, "-doc", documentation, {"doc": documentation}
        )

    def _query_all(self, collection_name: str) -> list:
        """Fetch up to `MAX_LIMIT_SIZE` rows, all fields, from `collection_name`."""
        return self.milvus_client.query(
            collection_name=collection_name,
            output_fields=["*"],
            limit=MAX_LIMIT_SIZE,
        )

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """
        Return the stored training data as one DataFrame.

        Returns:
            A DataFrame with columns `id`, `question` and `content`. SQL rows
            come first, then DDL, then documentation; `question` is `None` for
            DDL and documentation rows. At most `MAX_LIMIT_SIZE` rows are read
            from each collection, and the index is not reset between the parts.
        """
        sql_data = self._query_all(self._SQL_COLLECTION)
        df_sql = pd.DataFrame(
            {
                "id": [row["id"] for row in sql_data],
                "question": [row["text"] for row in sql_data],
                "content": [row["sql"] for row in sql_data],
            }
        )

        ddl_data = self._query_all(self._DDL_COLLECTION)
        df_ddl = pd.DataFrame(
            {
                "id": [row["id"] for row in ddl_data],
                "question": [None for _ in ddl_data],
                "content": [row["ddl"] for row in ddl_data],
            }
        )

        doc_data = self._query_all(self._DOC_COLLECTION)
        df_doc = pd.DataFrame(
            {
                "id": [row["id"] for row in doc_data],
                "question": [None for _ in doc_data],
                "content": [row["doc"] for row in doc_data],
            }
        )

        return pd.concat([df_sql, df_ddl, df_doc])

    def _search(self, collection_name: str, question: str, output_fields: List[str]) -> list:
        """
        Run an L2 vector search for `question` in `collection_name`.

        Returns:
            The hits for the single query, limited to `self.n_results`.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 128},
        }
        embeddings = self.embedding_function.encode_queries([question])
        res = self.milvus_client.search(
            collection_name=collection_name,
            anns_field="vector",
            data=embeddings,
            limit=self.n_results,
            output_fields=output_fields,
            search_params=search_params
        )
        # One query vector was sent, so only the first result list is relevant.
        return res[0]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """
        Find stored question/SQL pairs similar to `question`.

        Returns:
            A list of dicts with keys `question` and `sql`.
        """
        hits = self._search(self._SQL_COLLECTION, question, ["text", "sql"])
        return [
            {"question": hit["entity"]["text"], "sql": hit["entity"]["sql"]}
            for hit in hits
        ]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL strings of the stored entries closest to `question`."""
        hits = self._search(self._DDL_COLLECTION, question, ["ddl"])
        return [hit["entity"]["ddl"] for hit in hits]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation strings of the stored entries closest to `question`."""
        hits = self._search(self._DOC_COLLECTION, question, ["doc"])
        return [hit["entity"]["doc"] for hit in hits]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """
        Delete one training entry, choosing the collection from the id suffix.

        Returns:
            `True` when the id ends with `-sql`, `-ddl` or `-doc` and a delete
            was issued; `False` when the suffix is not recognised. The result
            of the delete call itself is not inspected.
        """
        for suffix, collection_name in self._COLLECTION_BY_ID_SUFFIX.items():
            if id.endswith(suffix):
                self.milvus_client.delete(collection_name=collection_name, ids=[id])
                return True
        return False
vanna/qdrant/qdrant.py CHANGED
@@ -39,16 +39,6 @@ class Qdrant_VectorStore(VannaBase):
39
39
  TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
40
40
  """
41
41
 
42
- documentation_collection_name = "documentation"
43
- ddl_collection_name = "ddl"
44
- sql_collection_name = "sql"
45
-
46
- id_suffixes = {
47
- ddl_collection_name: "ddl",
48
- documentation_collection_name: "doc",
49
- sql_collection_name: "sql",
50
- }
51
-
52
42
  def __init__(
53
43
  self,
54
44
  config={},
@@ -80,15 +70,21 @@ class Qdrant_VectorStore(VannaBase):
80
70
  self.collection_params = config.get("collection_params", {})
81
71
  self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
82
72
  self.documentation_collection_name = config.get(
83
- "documentation_collection_name", self.documentation_collection_name
73
+ "documentation_collection_name", "documentation"
84
74
  )
85
75
  self.ddl_collection_name = config.get(
86
- "ddl_collection_name", self.ddl_collection_name
76
+ "ddl_collection_name", "ddl"
87
77
  )
88
78
  self.sql_collection_name = config.get(
89
- "sql_collection_name", self.sql_collection_name
79
+ "sql_collection_name", "sql"
90
80
  )
91
81
 
82
+ self.id_suffixes = {
83
+ self.ddl_collection_name: "ddl",
84
+ self.documentation_collection_name: "doc",
85
+ self.sql_collection_name: "sql",
86
+ }
87
+
92
88
  self._setup_collections()
93
89
 
94
90
  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
@@ -212,7 +208,7 @@ class Qdrant_VectorStore(VannaBase):
212
208
  try:
213
209
  id, collection_name = self._parse_point_id(id)
214
210
  res = self._client.delete(collection_name, points_selector=[id])
215
- res == UpdateStatus.COMPLETED
211
+ return True
216
212
  except ValueError:
217
213
  return False
218
214
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.6.0
3
+ Version: 0.6.1
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -39,6 +39,7 @@ Requires-Dist: opensearch-py ; extra == "all"
39
39
  Requires-Dist: opensearch-dsl ; extra == "all"
40
40
  Requires-Dist: transformers ; extra == "all"
41
41
  Requires-Dist: pinecone-client ; extra == "all"
42
+ Requires-Dist: pymilvus[model] ; extra == "all"
42
43
  Requires-Dist: anthropic ; extra == "anthropic"
43
44
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
44
45
  Requires-Dist: chromadb ; extra == "chromadb"
@@ -49,6 +50,7 @@ Requires-Dist: google-generativeai ; extra == "google"
49
50
  Requires-Dist: google-cloud-aiplatform ; extra == "google"
50
51
  Requires-Dist: transformers ; extra == "hf"
51
52
  Requires-Dist: marqo ; extra == "marqo"
53
+ Requires-Dist: pymilvus[model] ; extra == "milvus"
52
54
  Requires-Dist: mistralai ; extra == "mistralai"
53
55
  Requires-Dist: PyMySQL ; extra == "mysql"
54
56
  Requires-Dist: ollama ; extra == "ollama"
@@ -78,6 +80,7 @@ Provides-Extra: gemini
78
80
  Provides-Extra: google
79
81
  Provides-Extra: hf
80
82
  Provides-Extra: marqo
83
+ Provides-Extra: milvus
81
84
  Provides-Extra: mistralai
82
85
  Provides-Extra: mysql
83
86
  Provides-Extra: ollama
@@ -9,7 +9,7 @@ vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,65
9
9
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
10
10
  vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
11
11
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
12
- vanna/base/base.py,sha256=XiBJU4LOKoFnY1vnkr_jz3soAdVdoL0D1WIZdzZCVyU,70942
12
+ vanna/base/base.py,sha256=l1H0TKsK9DN3n5XgDkUckdLois4dTCAUwrVsRa_6SlQ,70988
13
13
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
14
14
  vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
15
15
  vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
@@ -22,6 +22,8 @@ vanna/hf/__init__.py,sha256=vD0bIhfLkA1UsvVSF4MAz3Da8aQunkQo3wlDztmMuj0,19
22
22
  vanna/hf/hf.py,sha256=v1v6sZnbj5xcrjgmvLP_ytS9NM7E5d0GyMfXXtr6BMU,2703
23
23
  vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
24
24
  vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
25
+ vanna/milvus/__init__.py,sha256=VBasJG2eTKbJI6CEand7kPLNBrqYrn0QCAhSYVz814s,46
26
+ vanna/milvus/milvus_vector.py,sha256=Mq0eaSh0UcTYhgh8mTm0fvS6rbfL6tQONVnDZGemWoM,11268
25
27
  vanna/mistral/__init__.py,sha256=70rTY-69Z2ehkkMj84dNMCukPo6AWdflBGvIB_pztS0,29
26
28
  vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
27
29
  vanna/mock/__init__.py,sha256=nYR2WfcV5NdwpK3V64QGOWHBGc3ESN9uV68JLS76aRw,97
@@ -38,12 +40,12 @@ vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8
38
40
  vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
39
41
  vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
40
42
  vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
41
- vanna/qdrant/qdrant.py,sha256=RyICUvOO_jt8u9MB4oIYhqv3BicZ0d9pQkSFwkIfUjg,12719
43
+ vanna/qdrant/qdrant.py,sha256=qkTWhGrVSAngJZkrcRQ8YFVHcI9j_ZoOGbF6ZVUUdsU,12567
42
44
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
43
45
  vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
44
46
  vanna/vannadb/vannadb_vector.py,sha256=N8poMYvAojoaOF5gI4STD5pZWK9lBKPvyIjbh9dPBa0,14189
45
47
  vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
46
48
  vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
47
- vanna-0.6.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
48
- vanna-0.6.0.dist-info/METADATA,sha256=Nu63PMm3HT3HAkpxKU4fte9CKpIDgnL90Ww2AljWSOs,11506
49
- vanna-0.6.0.dist-info/RECORD,,
49
+ vanna-0.6.1.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
50
+ vanna-0.6.1.dist-info/METADATA,sha256=BApaZrir1-x-Y99Ufh4eoVZ9_poaGuG2uh-YVQfBP3Y,11628
51
+ vanna-0.6.1.dist-info/RECORD,,
File without changes