vanna 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/flask/auth.py ADDED
@@ -0,0 +1,55 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import flask
4
+
5
+
6
class AuthInterface(ABC):
    """Abstract contract for pluggable authentication back-ends.

    Every method below is abstract, so a concrete subclass must override all
    seven of them before it can be instantiated.  The base implementations
    do nothing and return ``None``.
    """

    @abstractmethod
    def get_user(self, flask_request) -> any:
        """Return the user associated with ``flask_request``.

        The representation of a "user" is left to the implementation.
        """

    @abstractmethod
    def is_logged_in(self, user: any) -> bool:
        """Return whether ``user`` counts as logged in."""

    @abstractmethod
    def override_config_for_user(self, user: any, config: dict) -> dict:
        """Return the configuration dict to use for ``user``, given ``config``."""

    @abstractmethod
    def login_form(self) -> str:
        """Return the login form as a string."""

    @abstractmethod
    def login_handler(self, flask_request) -> str:
        """Handle a login request and return the response as a string."""

    @abstractmethod
    def callback_handler(self, flask_request) -> str:
        """Handle a callback request and return the response as a string.

        NOTE(review): presumably an OAuth-style redirect callback -- confirm
        against the Flask app that calls it.
        """

    @abstractmethod
    def logout_handler(self, flask_request) -> str:
        """Handle a logout request and return the response as a string."""
34
+
35
class NoAuth(AuthInterface):
    """Authentication back-end that performs no authentication at all.

    Every caller is treated as logged in, the configuration is passed
    through untouched, and the login/callback/logout handlers all answer
    with the same fixed message.
    """

    # Shared response body for the three request handlers.
    _NO_LOGIN_MESSAGE = 'No login required'

    def get_user(self, flask_request) -> any:
        """Return an empty dict; the request is not inspected."""
        return dict()

    def is_logged_in(self, user: any) -> bool:
        """Always report the user as logged in."""
        return True

    def override_config_for_user(self, user: any, config: dict) -> dict:
        """Return ``config`` unchanged (the very same object)."""
        return config

    def login_form(self) -> str:
        """Return an empty string, since there is no form to show."""
        return ''

    def login_handler(self, flask_request) -> str:
        """Return the fixed "no login required" message."""
        return self._NO_LOGIN_MESSAGE

    def callback_handler(self, flask_request) -> str:
        """Return the fixed "no login required" message."""
        return self._NO_LOGIN_MESSAGE

    def logout_handler(self, flask_request) -> str:
        """Return the fixed "no login required" message."""
        return self._NO_LOGIN_MESSAGE
@@ -0,0 +1 @@
1
+ from .gemini_chat import GoogleGeminiChat
@@ -0,0 +1,52 @@
1
+ import os
2
+ from ..base import VannaBase
3
+
4
+
5
class GoogleGeminiChat(VannaBase):
    """Chat backend that submits prompts to a Google Gemini model.

    When an API key is available (``config["api_key"]`` or the
    ``GOOGLE_API_KEY`` environment variable) the ``google.generativeai``
    client is used; otherwise the model is created through Vertex AI.
    """

    def __init__(self, config=None):
        """Create the chat model.

        Args:
            config (dict, optional): Recognised keys:
                - temperature: Sampling temperature. Defaults to ``0.7``.
                - model_name: Model to use. Defaults to ``"gemini-1.0-pro"``
                  with an API key and ``"gemini-pro"`` on Vertex AI.
                - api_key: Google API key. Falls back to the
                  ``GOOGLE_API_KEY`` environment variable.
        """
        # ``config`` defaults to None; normalise it so the lookups below do
        # not fail with "argument of type 'NoneType' is not iterable".
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)

        # Default temperature - can be overridden using config.
        self.temperature = config.get("temperature", 0.7)

        # Take the key from config first, then from the environment.  The
        # previous code entered the API-key branch when only the environment
        # variable was set and then raised KeyError on config["api_key"].
        self.google_api_key = config.get("api_key") or os.getenv("GOOGLE_API_KEY")

        if self.google_api_key:
            # Imported lazily so the dependency is only needed on this path.
            import google.generativeai as genai

            genai.configure(api_key=self.google_api_key)
            self.chat_model = genai.GenerativeModel(
                config.get("model_name", "gemini-1.0-pro")
            )
        else:
            # Authenticate using VertexAI.
            from vertexai.preview.generative_models import GenerativeModel

            # Honour an explicitly configured model; keep the historical
            # "gemini-pro" default when none is given.
            self.chat_model = GenerativeModel(config.get("model_name", "gemini-pro"))

    def system_message(self, message: str) -> any:
        """Return ``message`` unchanged."""
        return message

    def user_message(self, message: str) -> any:
        """Return ``message`` unchanged."""
        return message

    def assistant_message(self, message: str) -> any:
        """Return ``message`` unchanged."""
        return message

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the model and return the response text.

        Args:
            prompt: Content passed as-is to ``generate_content``.
            **kwargs: Accepted but not used.

        Returns:
            str: The ``text`` attribute of the model response.
        """
        response = self.chat_model.generate_content(
            prompt,
            generation_config={
                "temperature": self.temperature,
            },
        )
        return response.text
@@ -0,0 +1,3 @@
1
+ from .qdrant import Qdrant_VectorStore
2
+
3
+ __all__ = ["Qdrant_VectorStore"]
vanna/qdrant/qdrant.py ADDED
@@ -0,0 +1,324 @@
1
+ from functools import cached_property
2
+ from typing import List, Tuple
3
+
4
+ import pandas as pd
5
+ from qdrant_client import QdrantClient, grpc, models
6
+
7
+ from ..base import VannaBase
8
+ from ..utils import deterministic_uuid
9
+
10
# Names of the Qdrant collections, one per kind of training data.
SQL_COLLECTION_NAME = "sql"
DDL_COLLECTION_NAME = "ddl"
DOCUMENTATION_COLLECTION_NAME = "documentation"

# Number of points requested per page when scrolling through a collection.
SCROLL_SIZE = 1000

# Suffix appended to a point id ("<id>-<suffix>") so that a formatted id
# also identifies the collection the point belongs to.
ID_SUFFIXES = {
    DDL_COLLECTION_NAME: "ddl",
    DOCUMENTATION_COLLECTION_NAME: "doc",
    SQL_COLLECTION_NAME: "sql",
}
20
+
21
+
22
class Qdrant_VectorStore(VannaBase):
    """Vectorstore implementation using Qdrant - https://qdrant.tech/"""

    def __init__(
        self,
        config=None,
    ):
        """
        Vectorstore implementation using Qdrant - https://qdrant.tech/

        Args:
            - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
                - client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
                - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
                - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
                - prefer_grpc: If `true` - use gRPC interface whenever possible in custom methods.
                - https: If `true` - use HTTPS(SSL) protocol. Default: `None`
                - api_key: API key for authentication in Qdrant Cloud. Default: `None`
                - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
                - path: Persistence path for QdrantLocal. Default: `None`.
                - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
                - n_results: Number of results to return from similarity search. Defaults to 10.
                - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
                  Defaults to `"BAAI/bge-small-en-v1.5"`.
                - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
                - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.

        Raises:
            TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
        """
        # Avoid a mutable default argument shared between instances.
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)
        client = config.get("client")

        if client is None:
            self._client = QdrantClient(
                location=config.get("location", None),
                url=config.get("url", None),
                prefer_grpc=config.get("prefer_grpc", False),
                https=config.get("https", None),
                api_key=config.get("api_key", None),
                timeout=config.get("timeout", None),
                path=config.get("path", None),
                prefix=config.get("prefix", None),
            )
        elif not isinstance(client, QdrantClient):
            raise TypeError(
                f"Unsupported client of type {client.__class__} was set in config"
            )
        else:
            self._client = client

        self.n_results = config.get("n_results", 10)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.collection_params = config.get("collection_params", {})
        self.distance_metric = config.get("distance_metric", models.Distance.COSINE)

        self._setup_collections()

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair and return its formatted point id.

        The point id is derived deterministically from the combined text, so
        adding the same pair twice upserts the same point.
        """
        # str.format, not the builtin format(): the builtin accepts at most
        # (value, format_spec) and raised TypeError with three arguments.
        question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql)
        point_id = deterministic_uuid(question_answer)

        self._client.upsert(
            SQL_COLLECTION_NAME,
            points=[
                models.PointStruct(
                    id=point_id,
                    vector=self.generate_embedding(question_answer),
                    payload={
                        "question": question,
                        "sql": sql,
                    },
                )
            ],
        )

        return self._format_point_id(point_id, SQL_COLLECTION_NAME)

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement and return its formatted point id."""
        point_id = deterministic_uuid(ddl)
        self._client.upsert(
            DDL_COLLECTION_NAME,
            points=[
                models.PointStruct(
                    id=point_id,
                    vector=self.generate_embedding(ddl),
                    payload={
                        "ddl": ddl,
                    },
                )
            ],
        )
        return self._format_point_id(point_id, DDL_COLLECTION_NAME)

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation string and return its formatted point id."""
        point_id = deterministic_uuid(documentation)

        self._client.upsert(
            DOCUMENTATION_COLLECTION_NAME,
            points=[
                models.PointStruct(
                    id=point_id,
                    vector=self.generate_embedding(documentation),
                    payload={
                        "documentation": documentation,
                    },
                )
            ],
        )

        return self._format_point_id(point_id, DOCUMENTATION_COLLECTION_NAME)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return every stored training item as one DataFrame.

        Returns:
            pd.DataFrame: Columns ``id``, ``question``, ``content`` and
            ``training_data_type`` ("sql", "ddl" or "documentation").
            ``question`` is None for DDL and documentation rows.  The frame
            is empty when no collection holds any point.
        """
        df = pd.DataFrame()

        if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
            df_sql = pd.DataFrame(
                {
                    "id": [
                        self._format_point_id(data.id, SQL_COLLECTION_NAME)
                        for data in sql_data
                    ],
                    "question": [data.payload["question"] for data in sql_data],
                    "content": [data.payload["sql"] for data in sql_data],
                }
            )
            df_sql["training_data_type"] = "sql"
            df = pd.concat([df, df_sql])

        if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
            df_ddl = pd.DataFrame(
                {
                    # Ids must come from the DDL points themselves; they were
                    # previously built from the SQL points by mistake.
                    "id": [
                        self._format_point_id(data.id, DDL_COLLECTION_NAME)
                        for data in ddl_data
                    ],
                    "question": [None for _ in ddl_data],
                    "content": [data.payload["ddl"] for data in ddl_data],
                }
            )
            df_ddl["training_data_type"] = "ddl"
            df = pd.concat([df, df_ddl])

        # A stray `self.documentation_collection.get()` call was removed
        # here: this class defines no such attribute.
        if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
            df_doc = pd.DataFrame(
                {
                    "id": [
                        self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
                        for data in doc_data
                    ],
                    "question": [None for _ in doc_data],
                    "content": [data.payload["documentation"] for data in doc_data],
                }
            )
            df_doc["training_data_type"] = "documentation"
            df = pd.concat([df, df_doc])

        return df

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the training item identified by a formatted id.

        Args:
            id (str): Id as returned by the ``add_*`` methods
                (``"<point id>-<suffix>"``).

        Returns:
            bool: True when the delete request was issued, False when a
            ValueError was raised (for example because ``id`` cannot be
            parsed).
        """
        try:
            point_id, collection_name = self._parse_point_id(id)
            self._client.delete(collection_name, points_selector=[point_id])
        except ValueError:
            return False
        # Previously fell through and returned None on success.
        return True

    def remove_collection(self, collection_name: str) -> bool:
        """
        This function can reset the collection to empty state.

        Args:
            collection_name (str): sql or ddl or documentation

        Returns:
            bool: True if collection is deleted, False otherwise
        """
        if collection_name not in ID_SUFFIXES:
            return False

        self._client.delete_collection(collection_name)
        # Recreate the dropped collection so it exists again, empty.
        self._setup_collections()
        return True

    @cached_property
    def embeddings_dimension(self):
        """Length of the vectors produced by `generate_embedding`.

        Measured once by embedding a fixed probe string, then cached.
        """
        return len(self.generate_embedding("ABCDEF"))

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return the payloads of the stored question/SQL pairs closest to ``question``."""
        results = self._client.search(
            SQL_COLLECTION_NAME,
            query_vector=self.generate_embedding(question),
            limit=self.n_results,
            with_payload=True,
        )

        return [dict(result.payload) for result in results]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL statements closest to ``question``."""
        results = self._client.search(
            DDL_COLLECTION_NAME,
            query_vector=self.generate_embedding(question),
            limit=self.n_results,
            with_payload=True,
        )

        return [result.payload["ddl"] for result in results]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation strings closest to ``question``."""
        results = self._client.search(
            DOCUMENTATION_COLLECTION_NAME,
            query_vector=self.generate_embedding(question),
            limit=self.n_results,
            with_payload=True,
        )

        return [result.payload["documentation"] for result in results]

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed ``data`` with the configured fastembed model.

        Returns:
            List[float]: The first embedding yielded by the model, as a list.
        """
        # NOTE(review): _get_or_init_model is a private qdrant_client method;
        # confirm it is still available when upgrading the client.
        embedding_model = self._client._get_or_init_model(
            model_name=self.fastembed_model
        )
        embedding = next(embedding_model.embed(data))

        return embedding.tolist()

    def _get_all_points(self, collection_name: str):
        """Return every record of ``collection_name``, without vectors.

        Pages through the collection ``SCROLL_SIZE`` records at a time.
        """
        results: List[models.Record] = []
        next_offset = None
        stop_scrolling = False
        while not stop_scrolling:
            records, next_offset = self._client.scroll(
                collection_name,
                limit=SCROLL_SIZE,
                offset=next_offset,
                with_payload=True,
                with_vectors=False,
            )
            # Stop on a missing offset, or on a grpc.PointId whose num and
            # uuid are both unset.
            stop_scrolling = next_offset is None or (
                isinstance(next_offset, grpc.PointId)
                and next_offset.num == 0
                and next_offset.uuid == ""
            )

            results.extend(records)

        return results

    def _setup_collections(self):
        """Create the sql, ddl and documentation collections if missing."""
        for collection_name in (
            SQL_COLLECTION_NAME,
            DDL_COLLECTION_NAME,
            DOCUMENTATION_COLLECTION_NAME,
        ):
            if self._client.collection_exists(collection_name):
                continue
            self._client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=self.embeddings_dimension,
                    distance=self.distance_metric,
                ),
                **self.collection_params,
            )

    def _format_point_id(self, id: str, collection_name: str) -> str:
        """Return ``"<id>-<suffix>"`` using the suffix of ``collection_name``."""
        return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])

    def _parse_point_id(self, id: str) -> Tuple[str, str]:
        """Split a formatted id into ``(point id, collection name)``.

        Inverse of `_format_point_id`.

        Raises:
            ValueError: If ``id`` has no ``-`` separator or its suffix does
                not match any known collection.
        """
        point_id, separator, suffix = id.rpartition("-")
        if separator:
            # Compare the parsed suffix itself.  The old loop reused the
            # name `suffix` as its loop variable and compared against the
            # builtin `type`, so it could never match.
            for collection_name, collection_suffix in ID_SUFFIXES.items():
                if suffix == collection_suffix:
                    return point_id, collection_name
        raise ValueError(f"Invalid id {id}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.3.4
3
+ Version: 0.4.1
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -28,17 +28,24 @@ Requires-Dist: chromadb ; extra == "all"
28
28
  Requires-Dist: anthropic ; extra == "all"
29
29
  Requires-Dist: zhipuai ; extra == "all"
30
30
  Requires-Dist: marqo ; extra == "all"
31
+ Requires-Dist: google-generativeai ; extra == "all"
32
+ Requires-Dist: google-cloud-aiplatform ; extra == "all"
33
+ Requires-Dist: qdrant-client ; extra == "all"
34
+ Requires-Dist: fastembed ; extra == "all"
31
35
  Requires-Dist: anthropic ; extra == "anthropic"
32
36
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
33
37
  Requires-Dist: chromadb ; extra == "chromadb"
34
38
  Requires-Dist: duckdb ; extra == "duckdb"
35
39
  Requires-Dist: google-generativeai ; extra == "gemini"
40
+ Requires-Dist: google-generativeai ; extra == "google"
41
+ Requires-Dist: google-cloud-aiplatform ; extra == "google"
36
42
  Requires-Dist: marqo ; extra == "marqo"
37
43
  Requires-Dist: mistralai ; extra == "mistralai"
38
44
  Requires-Dist: PyMySQL ; extra == "mysql"
39
45
  Requires-Dist: openai ; extra == "openai"
40
46
  Requires-Dist: psycopg2-binary ; extra == "postgres"
41
47
  Requires-Dist: db-dtypes ; extra == "postgres"
48
+ Requires-Dist: qdrant-client ; extra == "qdrant"
42
49
  Requires-Dist: snowflake-connector-python ; extra == "snowflake"
43
50
  Requires-Dist: tox ; extra == "test"
44
51
  Requires-Dist: zhipuai ; extra == "zhipuai"
@@ -50,11 +57,13 @@ Provides-Extra: bigquery
50
57
  Provides-Extra: chromadb
51
58
  Provides-Extra: duckdb
52
59
  Provides-Extra: gemini
60
+ Provides-Extra: google
53
61
  Provides-Extra: marqo
54
62
  Provides-Extra: mistralai
55
63
  Provides-Extra: mysql
56
64
  Provides-Extra: openai
57
65
  Provides-Extra: postgres
66
+ Provides-Extra: qdrant
58
67
  Provides-Extra: snowflake
59
68
  Provides-Extra: test
60
69
  Provides-Extra: zhipuai
@@ -8,12 +8,15 @@ vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
8
8
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
9
9
  vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
10
10
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
11
- vanna/base/base.py,sha256=89XPWy97YVx6090mNmu1zvn4k8X1pusCuAIypHHexNc,58100
11
+ vanna/base/base.py,sha256=_2vANGAcUe6IrsEhZyFnE6FdO8NMLylfyGlRI4XujWE,58143
12
12
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
13
- vanna/chromadb/chromadb_vector.py,sha256=1n4U4XpXThCFqyJf0zAYVA7mQu9rUkjOFtYn9e04JAo,8461
13
+ vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
14
14
  vanna/exceptions/__init__.py,sha256=N76unE7sjbGGBz6LmCrPQAugFWr9cUFv8ErJxBrCTts,717
15
- vanna/flask/__init__.py,sha256=tpwpA8596Uyn60FAy7I5oJ81c7kgCB2JG9X044P0_SA,21211
16
- vanna/flask/assets.py,sha256=pOOtPV8aWtFsTuxJneFHcfrXhXh6cOSvS-Y8JO2HYrY,180924
15
+ vanna/flask/__init__.py,sha256=5Du2oK5s-VSLicRPvxMAL1-Gh_jdX847FjJOf5AVapo,23721
16
+ vanna/flask/assets.py,sha256=ZESgn0-XrJl4_YV69Lu7Dr7-Y9Eql7xqb7PsoMzrofw,183889
17
+ vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
18
+ vanna/google/__init__.py,sha256=M-dCxCZcKL4bTQyMLj6r6VRs65YNX9Tl2aoPCuqGm-8,41
19
+ vanna/google/gemini_chat.py,sha256=ps3A-afFbCo3HeFTLL_nMoQO1PsGvRUUPRUppbMcDew,1584
17
20
  vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
18
21
  vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
19
22
  vanna/mistral/__init__.py,sha256=70rTY-69Z2ehkkMj84dNMCukPo6AWdflBGvIB_pztS0,29
@@ -23,9 +26,11 @@ vanna/ollama/ollama.py,sha256=jfW9VQHAcmzDeo4jF3HJjOMYwAWmptknKqEJaQ0MTno,2418
23
26
  vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
24
27
  vanna/openai/openai_chat.py,sha256=lm-hUsQxu6Q1t06A2csC037zI4VkMk0wFbQ-_Lj74Wg,4764
25
28
  vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
29
+ vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
30
+ vanna/qdrant/qdrant.py,sha256=XlesB0UniR0LmiXvQ8Ct7o19EdCZPuvUgvyzYSlp94c,11940
26
31
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
27
32
  vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
28
33
  vanna/vannadb/vannadb_vector.py,sha256=9YwTO3Lh5owWQE7KPMBqLp2EkiGV0RC1sEYhslzJzgI,6168
29
- vanna-0.3.4.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
30
- vanna-0.3.4.dist-info/METADATA,sha256=FEg4vs5ZiSAvd5YkF5oEfFqod9n3UoNfi51Q_2WKotA,10107
31
- vanna-0.3.4.dist-info/RECORD,,
34
+ vanna-0.4.1.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
35
+ vanna-0.4.1.dist-info/METADATA,sha256=SwU235jbylf7gJ4VwDXtz7bLhUF5uTV0IdPCnK1HH68,10512
36
+ vanna-0.4.1.dist-info/RECORD,,
File without changes