vanna 0.4.0__tar.gz → 0.4.2__tar.gz

This diff reflects the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
Files changed (36)
  1. {vanna-0.4.0 → vanna-0.4.2}/PKG-INFO +5 -1
  2. {vanna-0.4.0 → vanna-0.4.2}/pyproject.toml +3 -2
  3. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/exceptions/__init__.py +8 -8
  4. vanna-0.4.2/src/vanna/qdrant/__init__.py +3 -0
  5. vanna-0.4.2/src/vanna/qdrant/qdrant.py +324 -0
  6. {vanna-0.4.0 → vanna-0.4.2}/README.md +0 -0
  7. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
  8. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  9. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/ZhipuAI/__init__.py +0 -0
  10. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/__init__.py +0 -0
  11. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/anthropic/__init__.py +0 -0
  12. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/anthropic/anthropic_chat.py +0 -0
  13. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/base/__init__.py +0 -0
  14. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/base/base.py +0 -0
  15. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/chromadb/__init__.py +0 -0
  16. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/chromadb/chromadb_vector.py +0 -0
  17. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/flask/__init__.py +0 -0
  18. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/flask/assets.py +0 -0
  19. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/flask/auth.py +0 -0
  20. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/google/__init__.py +0 -0
  21. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/google/gemini_chat.py +0 -0
  22. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/local.py +0 -0
  23. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/marqo/__init__.py +0 -0
  24. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/marqo/marqo.py +0 -0
  25. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/mistral/__init__.py +0 -0
  26. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/mistral/mistral.py +0 -0
  27. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/ollama/__init__.py +0 -0
  28. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/ollama/ollama.py +0 -0
  29. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/openai/__init__.py +0 -0
  30. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/openai/openai_chat.py +0 -0
  31. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/openai/openai_embeddings.py +0 -0
  32. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/remote.py +0 -0
  33. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/types/__init__.py +0 -0
  34. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/utils.py +0 -0
  35. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/vannadb/__init__.py +0 -0
  36. {vanna-0.4.0 → vanna-0.4.2}/src/vanna/vannadb/vannadb_vector.py +0 -0
{vanna-0.4.0 → vanna-0.4.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vanna
-Version: 0.4.0
+Version: 0.4.2
 Summary: Generate SQL queries from natural language
 Author-email: Zain Hoda <zain@vanna.ai>
 Requires-Python: >=3.9
@@ -30,6 +30,8 @@ Requires-Dist: zhipuai ; extra == "all"
 Requires-Dist: marqo ; extra == "all"
 Requires-Dist: google-generativeai ; extra == "all"
 Requires-Dist: google-cloud-aiplatform ; extra == "all"
+Requires-Dist: qdrant-client ; extra == "all"
+Requires-Dist: fastembed ; extra == "all"
 Requires-Dist: anthropic ; extra == "anthropic"
 Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
 Requires-Dist: chromadb ; extra == "chromadb"
@@ -43,6 +45,7 @@ Requires-Dist: PyMySQL ; extra == "mysql"
 Requires-Dist: openai ; extra == "openai"
 Requires-Dist: psycopg2-binary ; extra == "postgres"
 Requires-Dist: db-dtypes ; extra == "postgres"
+Requires-Dist: qdrant-client ; extra == "qdrant"
 Requires-Dist: snowflake-connector-python ; extra == "snowflake"
 Requires-Dist: tox ; extra == "test"
 Requires-Dist: zhipuai ; extra == "zhipuai"
@@ -60,6 +63,7 @@ Provides-Extra: mistralai
 Provides-Extra: mysql
 Provides-Extra: openai
 Provides-Extra: postgres
+Provides-Extra: qdrant
 Provides-Extra: snowflake
 Provides-Extra: test
 Provides-Extra: zhipuai
{vanna-0.4.0 → vanna-0.4.2}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "vanna"
-version = "0.4.0"
+version = "0.4.2"
 authors = [
   { name="Zain Hoda", email="zain@vanna.ai" },
 ]
@@ -32,7 +32,7 @@ bigquery = ["google-cloud-bigquery"]
 snowflake = ["snowflake-connector-python"]
 duckdb = ["duckdb"]
 google = ["google-generativeai", "google-cloud-aiplatform"]
-all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"]
+all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed"]
 test = ["tox"]
 chromadb = ["chromadb"]
 openai = ["openai"]
@@ -41,3 +41,4 @@ anthropic = ["anthropic"]
 gemini = ["google-generativeai"]
 marqo = ["marqo"]
 zhipuai = ["zhipuai"]
+qdrant = ["qdrant-client"]
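
Packaging note: the new `qdrant` extra declares only `qdrant-client`, while the `all` extra now also adds `fastembed`. Assuming the standard extras syntax for the published distribution, the backend should be installable with `pip install "vanna[qdrant]"`; install `fastembed` alongside it (or use `vanna[all]`) if you rely on the default FastEmbed embedding path described in the new module below.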
{vanna-0.4.0 → vanna-0.4.2}/src/vanna/exceptions/__init__.py
@@ -1,46 +1,46 @@
-class ImproperlyConfigured(BaseException):
+class ImproperlyConfigured(Exception):
     """Raise for incorrect configuration."""

     pass


-class DependencyError(BaseException):
+class DependencyError(Exception):
     """Raise for missing dependencies."""

     pass


-class ConnectionError(BaseException):
+class ConnectionError(Exception):
     """Raise for connection"""

     pass


-class OTPCodeError(BaseException):
+class OTPCodeError(Exception):
     """Raise for invalid otp or not able to send it"""

     pass


-class SQLRemoveError(BaseException):
+class SQLRemoveError(Exception):
     """Raise when not able to remove SQL"""

     pass


-class ExecutionError(BaseException):
+class ExecutionError(Exception):
     """Raise when not able to execute Code"""

     pass


-class ValidationError(BaseException):
+class ValidationError(Exception):
     """Raise for validations"""

     pass


-class APIError(BaseException):
+class APIError(Exception):
     """Raise for API errors"""

     pass
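
The practical effect of the `BaseException` → `Exception` change above: broad `except Exception:` handlers (and frameworks that install them) now catch Vanna's custom errors instead of letting them escape. A minimal sketch, using only the `ImproperlyConfigured` class from this diff:

    from vanna.exceptions import ImproperlyConfigured

    try:
        raise ImproperlyConfigured("missing 'api_key' in config")
    except Exception as exc:
        # With 0.4.0 these classes derived from BaseException, so this
        # handler would not fire; with 0.4.2 the error is caught like any
        # ordinary exception.
        print(f"configuration problem: {exc}")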
vanna-0.4.2/src/vanna/qdrant/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .qdrant import Qdrant_VectorStore
+
+__all__ = ["Qdrant_VectorStore"]
vanna-0.4.2/src/vanna/qdrant/qdrant.py (new file)
@@ -0,0 +1,324 @@
+from functools import cached_property
+from typing import List, Tuple
+
+import pandas as pd
+from qdrant_client import QdrantClient, grpc, models
+
+from ..base import VannaBase
+from ..utils import deterministic_uuid
+
+DOCUMENTATION_COLLECTION_NAME = "documentation"
+DDL_COLLECTION_NAME = "ddl"
+SQL_COLLECTION_NAME = "sql"
+SCROLL_SIZE = 1000
+
+ID_SUFFIXES = {
+    DDL_COLLECTION_NAME: "ddl",
+    DOCUMENTATION_COLLECTION_NAME: "doc",
+    SQL_COLLECTION_NAME: "sql",
+}
+
+
+class Qdrant_VectorStore(VannaBase):
+    """Vectorstore implementation using Qdrant - https://qdrant.tech/"""
+
+    def __init__(
+        self,
+        config={},
+    ):
+        """
+        Vectorstore implementation using Qdrant - https://qdrant.tech/
+
+        Args:
+            - config (dict, optional): Dictionary of `Qdrant_VectorStore` config options. Defaults to `{}`.
+            - client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
+            - location: If `":memory:"` - use an in-memory Qdrant instance. If `str` - use it as a `url` parameter.
+            - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". E.g. `"http://localhost:6333"`.
+            - prefer_grpc: If `true` - use the gRPC interface whenever possible in custom methods.
+            - https: If `true` - use HTTPS (SSL) protocol. Default: `None`
+            - api_key: API key for authentication in Qdrant Cloud. Default: `None`
+            - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
+            - path: Persistence path for QdrantLocal. Default: `None`.
+            - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
+            - n_results: Number of results to return from similarity search. Defaults to 10.
+            - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
+              Defaults to `"BAAI/bge-small-en-v1.5"`.
+            - collection_params: Additional parameters to pass to the `qdrant_client.QdrantClient#create_collection()` method.
+            - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
+
+        Raises:
+            TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
+        """
+        VannaBase.__init__(self, config=config)
+        client = config.get("client")
+
+        if client is None:
+            self._client = QdrantClient(
+                location=config.get("location", None),
+                url=config.get("url", None),
+                prefer_grpc=config.get("prefer_grpc", False),
+                https=config.get("https", None),
+                api_key=config.get("api_key", None),
+                timeout=config.get("timeout", None),
+                path=config.get("path", None),
+                prefix=config.get("prefix", None),
+            )
+        elif not isinstance(client, QdrantClient):
+            raise TypeError(
+                f"Unsupported client of type {client.__class__} was set in config"
+            )
+
+        else:
+            self._client = client
+
+        self.n_results = config.get("n_results", 10)
+        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
+        self.collection_params = config.get("collection_params", {})
+        self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
+
+        self._setup_collections()
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql)
+        id = deterministic_uuid(question_answer)
+
+        self._client.upsert(
+            SQL_COLLECTION_NAME,
+            points=[
+                models.PointStruct(
+                    id=id,
+                    vector=self.generate_embedding(question_answer),
+                    payload={
+                        "question": question,
+                        "sql": sql,
+                    },
+                )
+            ],
+        )
+
+        return self._format_point_id(id, SQL_COLLECTION_NAME)
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        id = deterministic_uuid(ddl)
+        self._client.upsert(
+            DDL_COLLECTION_NAME,
+            points=[
+                models.PointStruct(
+                    id=id,
+                    vector=self.generate_embedding(ddl),
+                    payload={
+                        "ddl": ddl,
+                    },
+                )
+            ],
+        )
+        return self._format_point_id(id, DDL_COLLECTION_NAME)
+
+    def add_documentation(self, documentation: str, **kwargs) -> str:
+        id = deterministic_uuid(documentation)
+
+        self._client.upsert(
+            DOCUMENTATION_COLLECTION_NAME,
+            points=[
+                models.PointStruct(
+                    id=id,
+                    vector=self.generate_embedding(documentation),
+                    payload={
+                        "documentation": documentation,
+                    },
+                )
+            ],
+        )
+
+        return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME)
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        df = pd.DataFrame()
+
+        if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
+            question_list = [data.payload["question"] for data in sql_data]
+            sql_list = [data.payload["sql"] for data in sql_data]
+            id_list = [
+                self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data
+            ]
+
+            df_sql = pd.DataFrame(
+                {
+                    "id": id_list,
+                    "question": question_list,
+                    "content": sql_list,
+                }
+            )
+
+            df_sql["training_data_type"] = "sql"
+
+            df = pd.concat([df, df_sql])
+
+        if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
+            ddl_list = [data.payload["ddl"] for data in ddl_data]
+            id_list = [
+                self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in ddl_data
+            ]
+
+            df_ddl = pd.DataFrame(
+                {
+                    "id": id_list,
+                    "question": [None for _ in ddl_list],
+                    "content": ddl_list,
+                }
+            )
+
+            df_ddl["training_data_type"] = "ddl"
+
+            df = pd.concat([df, df_ddl])
+
+
+        if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
+            document_list = [data.payload["documentation"] for data in doc_data]
+            id_list = [
+                self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
+                for data in doc_data
+            ]
+
+            df_doc = pd.DataFrame(
+                {
+                    "id": id_list,
+                    "question": [None for _ in document_list],
+                    "content": document_list,
+                }
+            )
+
+            df_doc["training_data_type"] = "documentation"
+
+            df = pd.concat([df, df_doc])
+
+        return df
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        try:
+            id, collection_name = self._parse_point_id(id)
+            self._client.delete(collection_name, points_selector=[id])
+            return True
+        except ValueError:
+            return False
+
+    def remove_collection(self, collection_name: str) -> bool:
+        """
+        Resets the given collection to an empty state by deleting and recreating it.
+
+        Args:
+            collection_name (str): sql or ddl or documentation
+
+        Returns:
+            bool: True if collection is deleted, False otherwise
+        """
+        if collection_name in ID_SUFFIXES.keys():
+            self._client.delete_collection(collection_name)
+            self._setup_collections()
+            return True
+        else:
+            return False
+
+    @cached_property
+    def embeddings_dimension(self):
+        return len(self.generate_embedding("ABCDEF"))
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        results = self._client.search(
+            SQL_COLLECTION_NAME,
+            query_vector=self.generate_embedding(question),
+            limit=self.n_results,
+            with_payload=True,
+        )
+
+        return [dict(result.payload) for result in results]
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        results = self._client.search(
+            DDL_COLLECTION_NAME,
+            query_vector=self.generate_embedding(question),
+            limit=self.n_results,
+            with_payload=True,
+        )
+
+        return [result.payload["ddl"] for result in results]
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        results = self._client.search(
+            DOCUMENTATION_COLLECTION_NAME,
+            query_vector=self.generate_embedding(question),
+            limit=self.n_results,
+            with_payload=True,
+        )
+
+        return [result.payload["documentation"] for result in results]
+
+    def generate_embedding(self, data: str, **kwargs) -> List[float]:
+        embedding_model = self._client._get_or_init_model(
+            model_name=self.fastembed_model
+        )
+        embedding = next(embedding_model.embed(data))
+
+        return embedding.tolist()
+
+    def _get_all_points(self, collection_name: str):
+        results: List[models.Record] = []
+        next_offset = None
+        stop_scrolling = False
+        while not stop_scrolling:
+            records, next_offset = self._client.scroll(
+                collection_name,
+                limit=SCROLL_SIZE,
+                offset=next_offset,
+                with_payload=True,
+                with_vectors=False,
+            )
+            stop_scrolling = next_offset is None or (
+                isinstance(next_offset, grpc.PointId)
+                and next_offset.num == 0
+                and next_offset.uuid == ""
+            )
+
+            results.extend(records)
+
+        return results
+
+    def _setup_collections(self):
+        if not self._client.collection_exists(SQL_COLLECTION_NAME):
+            self._client.create_collection(
+                collection_name=SQL_COLLECTION_NAME,
+                vectors_config=models.VectorParams(
+                    size=self.embeddings_dimension,
+                    distance=self.distance_metric,
+                ),
+                **self.collection_params,
+            )
+
+        if not self._client.collection_exists(DDL_COLLECTION_NAME):
+            self._client.create_collection(
+                collection_name=DDL_COLLECTION_NAME,
+                vectors_config=models.VectorParams(
+                    size=self.embeddings_dimension,
+                    distance=self.distance_metric,
+                ),
+                **self.collection_params,
+            )
+        if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME):
+            self._client.create_collection(
+                collection_name=DOCUMENTATION_COLLECTION_NAME,
+                vectors_config=models.VectorParams(
+                    size=self.embeddings_dimension,
+                    distance=self.distance_metric,
+                ),
+                **self.collection_params,
+            )
+
+    def _format_point_id(self, id: str, collection_name: str) -> str:
+        return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])
+
+    def _parse_point_id(self, id: str) -> Tuple[str, str]:
+        id, suffix = id.rsplit("-", 1)
+        for collection_name, collection_suffix in ID_SUFFIXES.items():
+            if suffix == collection_suffix:
+                return id, collection_name
+        raise ValueError(f"Invalid id {id}")
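
For orientation, here is a minimal sketch of how the new store is typically combined with a chat backend, following the mixin pattern Vanna uses for its other vector stores. The `OpenAI_Chat` pairing, the `MyVanna` class name, and the `model` config key are illustrative assumptions, not part of this diff; the Qdrant-specific keys come from the `Qdrant_VectorStore` docstring above.

    from vanna.openai import OpenAI_Chat
    from vanna.qdrant import Qdrant_VectorStore


    # VannaBase is abstract, so the vector store is mixed with an LLM class,
    # as with the other Vanna vector stores (ChromaDB, Marqo, ...).
    class MyVanna(Qdrant_VectorStore, OpenAI_Chat):
        def __init__(self, config):
            Qdrant_VectorStore.__init__(self, config=config)
            OpenAI_Chat.__init__(self, config=config)


    vn = MyVanna(
        config={
            "url": "http://localhost:6333",  # or "location": ":memory:" for a throwaway store
            "n_results": 10,
            "fastembed_model": "BAAI/bge-small-en-v1.5",
            "model": "gpt-4",  # chat-mixin setting (illustrative); LLM credentials usually come from the environment
        }
    )

    # Training data lands in the "sql", "ddl" and "documentation" collections;
    # returned ids carry the collection suffix, e.g. "<uuid>-sql".
    training_id = vn.add_question_sql(
        question="How many customers are there?",
        sql="SELECT COUNT(*) FROM customers;",
    )
    print(vn.get_training_data())
    vn.remove_training_data(training_id)

Note that the default embedding path goes through qdrant-client's FastEmbed integration, so the `fastembed` package must be available even though the `qdrant` extra does not declare it.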