vanna 0.5.0__tar.gz → 0.5.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {vanna-0.5.0 → vanna-0.5.2}/PKG-INFO +4 -3
  2. {vanna-0.5.0 → vanna-0.5.2}/README.md +2 -2
  3. {vanna-0.5.0 → vanna-0.5.2}/pyproject.toml +2 -2
  4. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/base/base.py +117 -11
  5. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/opensearch/opensearch_vector.py +98 -4
  6. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/qdrant/qdrant.py +74 -62
  7. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
  8. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  9. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/ZhipuAI/__init__.py +0 -0
  10. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/__init__.py +0 -0
  11. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/anthropic/__init__.py +0 -0
  12. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/anthropic/anthropic_chat.py +0 -0
  13. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/base/__init__.py +0 -0
  14. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/chromadb/__init__.py +0 -0
  15. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/chromadb/chromadb_vector.py +0 -0
  16. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/exceptions/__init__.py +0 -0
  17. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/flask/__init__.py +0 -0
  18. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/flask/assets.py +0 -0
  19. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/flask/auth.py +0 -0
  20. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/google/__init__.py +0 -0
  21. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/google/gemini_chat.py +0 -0
  22. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/hf/__init__.py +0 -0
  23. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/hf/hf.py +0 -0
  24. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/local.py +0 -0
  25. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/marqo/__init__.py +0 -0
  26. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/marqo/marqo.py +0 -0
  27. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/mistral/__init__.py +0 -0
  28. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/mistral/mistral.py +0 -0
  29. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/ollama/__init__.py +0 -0
  30. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/ollama/ollama.py +0 -0
  31. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/openai/__init__.py +0 -0
  32. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/openai/openai_chat.py +0 -0
  33. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/openai/openai_embeddings.py +0 -0
  34. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/opensearch/__init__.py +0 -0
  35. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/qdrant/__init__.py +0 -0
  36. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/remote.py +0 -0
  37. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/types/__init__.py +0 -0
  38. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/utils.py +0 -0
  39. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/vannadb/__init__.py +0 -0
  40. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/vannadb/vannadb_vector.py +0 -0
  41. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/vllm/__init__.py +0 -0
  42. {vanna-0.5.0 → vanna-0.5.2}/src/vanna/vllm/vllm.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vanna
- Version: 0.5.0
+ Version: 0.5.2
  Summary: Generate SQL queries from natural language
  Author-email: Zain Hoda <zain@vanna.ai>
  Requires-Python: >=3.9
@@ -57,6 +57,7 @@ Requires-Dist: opensearch-dsl ; extra == "opensearch"
  Requires-Dist: psycopg2-binary ; extra == "postgres"
  Requires-Dist: db-dtypes ; extra == "postgres"
  Requires-Dist: qdrant-client ; extra == "qdrant"
+ Requires-Dist: fastembed ; extra == "qdrant"
  Requires-Dist: snowflake-connector-python ; extra == "snowflake"
  Requires-Dist: tox ; extra == "test"
  Requires-Dist: vllm ; extra == "vllm"
@@ -111,7 +112,7 @@ Vanna works in two easy steps - train a RAG "model" on your data, and then ask q

  If you don't know what RAG is, don't worry -- you don't need to know how this works under the hood to use it. You just need to know that you "train" a model, which stores some metadata and then use it to "ask" questions.

- See the [base class](src/vanna/base/base.py) for more details on how this works under the hood.
+ See the [base class](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) for more details on how this works under the hood.

  ## User Interfaces
  These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface.
@@ -304,7 +305,7 @@ Fine-Tuning
  - Expose to your end users via Slackbot, web app, Streamlit app, or a custom front end.

  ## Extending Vanna
- Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
+ Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.

  ## Vanna in 100 Seconds

@@ -25,7 +25,7 @@ Vanna works in two easy steps - train a RAG "model" on your data, and then ask q

  If you don't know what RAG is, don't worry -- you don't need to know how this works under the hood to use it. You just need to know that you "train" a model, which stores some metadata and then use it to "ask" questions.

- See the [base class](src/vanna/base/base.py) for more details on how this works under the hood.
+ See the [base class](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) for more details on how this works under the hood.

  ## User Interfaces
  These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface.
@@ -218,7 +218,7 @@ Fine-Tuning
  - Expose to your end users via Slackbot, web app, Streamlit app, or a custom front end.

  ## Extending Vanna
- Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
+ Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.

  ## Vanna in 100 Seconds

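For reference, the mixin pattern that the "Extending Vanna" paragraph describes looks roughly like the sketch below, using the OpenAI and ChromaDB implementations shipped in this package; the config keys and training call are illustrative placeholders.

```python
# Illustrative sketch of combining a vector store with an LLM implementation.
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.openai.openai_chat import OpenAI_Chat

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={"api_key": "sk-...", "model": "gpt-4"})
vn.train(ddl="CREATE TABLE customers (id INT, name TEXT)")  # store schema as training data
print(vn.ask("How many customers are there?"))
```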
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

  [project]
  name = "vanna"
- version = "0.5.0"
+ version = "0.5.2"
  authors = [
      { name="Zain Hoda", email="zain@vanna.ai" },
  ]
@@ -42,7 +42,7 @@ gemini = ["google-generativeai"]
  marqo = ["marqo"]
  zhipuai = ["zhipuai"]
  ollama = ["ollama", "httpx"]
- qdrant = ["qdrant-client"]
+ qdrant = ["qdrant-client", "fastembed"]
  vllm = ["vllm"]
  opensearch = ["opensearch-py", "opensearch-dsl"]
  hf = ["transformers"]
@@ -1146,19 +1146,14 @@ class VannaBase(ABC):

  conn = None

- try:
-     conn = bigquery.Client(project=project_id)
- except:
-     print("Could not found any google cloud implicit credentials")
-
- if cred_file_path:
+ if not cred_file_path:
+     try:
+         conn = bigquery.Client(project=project_id)
+     except:
+         print("Could not found any google cloud implicit credentials")
+ else:
      # Validate file path and pemissions
      validate_config_path(cred_file_path)
- else:
-     if not conn:
-         raise ValidationError(
-             "Pleae provide a service account credentials json file"
-         )

  if not conn:
      with open(cred_file_path, "r") as f:
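The net effect: implicit Google Cloud credentials are only tried when no cred_file_path is passed, and a provided file is validated before use. A hedged sketch of the two calling modes, assuming the surrounding method is the package's connect_to_bigquery and that vn is a configured Vanna instance:

```python
# Sketch only: method name and keyword arguments are assumed from this hunk.
# 1) Rely on application-default (implicit) credentials:
vn.connect_to_bigquery(project_id="my-gcp-project")

# 2) Use an explicit service-account key file instead:
vn.connect_to_bigquery(
    project_id="my-gcp-project",
    cred_file_path="/path/to/service-account.json",
)
```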
@@ -1279,6 +1274,7 @@
  # Execute the SQL statement and return the result as a pandas DataFrame
  with engine.begin() as conn:
      df = pd.read_sql_query(sa.text(sql), conn)
+     conn.close()
  return df

  raise Exception("Couldn't run sql")
@@ -1286,6 +1282,116 @@
  self.dialect = "T-SQL / Microsoft SQL Server"
  self.run_sql = run_sql_mssql
  self.run_sql_is_set = True
+ def connect_to_presto(
+     self,
+     host: str,
+     catalog: str = 'hive',
+     schema: str = 'default',
+     user: str = None,
+     password: str = None,
+     port: int = None,
+     combined_pem_path: str = None,
+     protocol: str = 'https',
+     requests_kwargs: dict = None
+ ):
+     """
+     Connect to a Presto database using the specified parameters.
+
+     Args:
+         host (str): The host address of the Presto database.
+         catalog (str): The catalog to use in the Presto environment.
+         schema (str): The schema to use in the Presto environment.
+         user (str): The username for authentication.
+         password (str): The password for authentication.
+         port (int): The port number for the Presto connection.
+         combined_pem_path (str): The path to the combined pem file for SSL connection.
+         protocol (str): The protocol to use for the connection (default is 'https').
+         requests_kwargs (dict): Additional keyword arguments for requests.
+
+     Raises:
+         DependencyError: If required dependencies are not installed.
+         ImproperlyConfigured: If essential configuration settings are missing.
+
+     Returns:
+         None
+     """
+     try:
+         from pyhive import presto
+     except ImportError:
+         raise DependencyError(
+             "You need to install required dependencies to execute this method,"
+             " run command: \npip install pyhive"
+         )
+
+     if not host:
+         host = os.getenv("PRESTO_HOST")
+
+     if not host:
+         raise ImproperlyConfigured("Please set your presto host")
+
+     if not catalog:
+         catalog = os.getenv("PRESTO_CATALOG")
+
+     if not catalog:
+         raise ImproperlyConfigured("Please set your presto catalog")
+
+     if not user:
+         user = os.getenv("PRESTO_USER")
+
+     if not user:
+         raise ImproperlyConfigured("Please set your presto user")
+
+     if not password:
+         password = os.getenv("PRESTO_PASSWORD")
+
+     if not port:
+         port = os.getenv("PRESTO_PORT")
+
+     if not port:
+         raise ImproperlyConfigured("Please set your presto port")
+
+     conn = None
+
+     try:
+         if requests_kwargs is None and combined_pem_path is not None:
+             # use the combined pem file to verify the SSL connection
+             requests_kwargs = {
+                 'verify': combined_pem_path,  # use the converted PEM file for SSL verification
+             }
+         conn = presto.Connection(host=host,
+                                  username=user,
+                                  password=password,
+                                  catalog=catalog,
+                                  schema=schema,
+                                  port=port,
+                                  protocol=protocol,
+                                  requests_kwargs=requests_kwargs)
+     except presto.Error as e:
+         raise ValidationError(e)
+
+     def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
+         if conn:
+             try:
+                 cs = conn.cursor()
+                 cs.execute(sql)
+                 results = cs.fetchall()
+
+                 # Create a pandas dataframe from the results
+                 df = pd.DataFrame(
+                     results, columns=[desc[0] for desc in cs.description]
+                 )
+                 return df
+
+             except presto.Error as e:
+                 print(e)
+                 raise ValidationError(e)
+
+             except Exception as e:
+                 print(e)
+                 raise e
+
+     self.run_sql_is_set = True
+     self.run_sql = run_sql_presto

  def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
      """
@@ -24,7 +24,77 @@ class OpenSearch_VectorStore(VannaBase):
  self.document_index = document_index
  self.ddl_index = ddl_index
  self.question_sql_index = question_sql_index
- print("OpenSearch_VectorStore initialized with document_index: ", document_index, " ddl_index: ", ddl_index, " question_sql_index: ", question_sql_index)
+ print("OpenSearch_VectorStore initialized with document_index: ",
+       document_index, " ddl_index: ", ddl_index, " question_sql_index: ",
+       question_sql_index)
+
+ document_index_settings = {
+     "settings": {
+         "index": {
+             "number_of_shards": 6,
+             "number_of_replicas": 2
+         }
+     },
+     "mappings": {
+         "properties": {
+             "question": {
+                 "type": "text",
+             },
+             "doc": {
+                 "type": "text",
+             }
+         }
+     }
+ }
+
+ ddl_index_settings = {
+     "settings": {
+         "index": {
+             "number_of_shards": 6,
+             "number_of_replicas": 2
+         }
+     },
+     "mappings": {
+         "properties": {
+             "ddl": {
+                 "type": "text",
+             },
+             "doc": {
+                 "type": "text",
+             }
+         }
+     }
+ }
+
+ question_sql_index_settings = {
+     "settings": {
+         "index": {
+             "number_of_shards": 6,
+             "number_of_replicas": 2
+         }
+     },
+     "mappings": {
+         "properties": {
+             "question": {
+                 "type": "text",
+             },
+             "sql": {
+                 "type": "text",
+             }
+         }
+     }
+ }
+
+ if config is not None and "es_document_index_settings" in config:
+     document_index_settings = config["es_document_index_settings"]
+ if config is not None and "es_ddl_index_settings" in config:
+     ddl_index_settings = config["es_ddl_index_settings"]
+ if config is not None and "es_question_sql_index_settings" in config:
+     question_sql_index_settings = config["es_question_sql_index_settings"]
+
+ self.document_index_settings = document_index_settings
+ self.ddl_index_settings = ddl_index_settings
+ self.question_sql_index_settings = question_sql_index_settings

  es_urls = None
  if config is not None and "es_urls" in config:
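These defaults (6 shards, 2 replicas, text mappings) can now be swapped out per index through config. A minimal sketch of overriding just the document index, assuming the store is combined with an LLM class in the usual mixin style and the remaining keys keep their defaults:

```python
# Sketch only: config keys mirror the ones read in this hunk; the shard/replica values
# are illustrative, and MyOpenSearchVanna is a hypothetical subclass that mixes
# OpenSearch_VectorStore with an LLM implementation.
custom_doc_settings = {
    "settings": {"index": {"number_of_shards": 1, "number_of_replicas": 0}},
    "mappings": {"properties": {"question": {"type": "text"}, "doc": {"type": "text"}}},
}

vn = MyOpenSearchVanna(config={
    "es_urls": "https://opensearch-node.test.com:9200",
    "es_document_index_settings": custom_doc_settings,
    # "es_ddl_index_settings" / "es_question_sql_index_settings" follow the same shape
})
```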
@@ -85,6 +155,9 @@ class OpenSearch_VectorStore(VannaBase):
  else:
      max_retries = 10

+ print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
+       " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
+       verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
  if es_urls is not None:
      # Initialize the OpenSearch client by passing a list of URLs
      self.client = OpenSearch(
@@ -112,18 +185,26 @@
      headers=headers
  )

+ print("OpenSearch_VectorStore initialized with client over ")
+
  # Run a simple query to check the connection
  try:
+     print('Connected to OpenSearch cluster:')
      info = self.client.info()
      print('OpenSearch cluster info:', info)
  except Exception as e:
      print('Error connecting to OpenSearch cluster:', e)

  # Create the indices if they don't exist
- # self.create_index()
+ self.create_index_if_not_exists(self.document_index,
+                                 self.document_index_settings)
+ self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings)
+ self.create_index_if_not_exists(self.question_sql_index,
+                                 self.question_sql_index_settings)

  def create_index(self):
-     for index in [self.document_index, self.ddl_index, self.question_sql_index]:
+     for index in [self.document_index, self.ddl_index,
+                   self.question_sql_index]:
          try:
              self.client.indices.create(index)
          except Exception as e:
@@ -131,6 +212,20 @@
      print(f"opensearch index {index} already exists")
      pass

+ def create_index_if_not_exists(self, index_name: str,
+                                index_settings: dict) -> bool:
+     try:
+         if not self.client.indices.exists(index_name):
+             print(f"Index {index_name} does not exist. Creating...")
+             self.client.indices.create(index=index_name, body=index_settings)
+             return True
+         else:
+             print(f"Index {index_name} already exists.")
+             return False
+     except Exception as e:
+         print(f"Error creating index: {index_name} ", e)
+         return False
+
  def add_ddl(self, ddl: str, **kwargs) -> str:
      # Assuming that you have a DDL index in your OpenSearch
      id = str(uuid.uuid4()) + "-ddl"
@@ -278,7 +373,6 @@
  # opensearch doesn't need to generate embeddings
  pass

-
  # OpenSearch_VectorStore.__init__(self, config={'es_urls':
  # "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
  # "admin", 'es_password': "admin", 'es_verify_certs': True})
@@ -7,48 +7,51 @@ from qdrant_client import QdrantClient, grpc, models
  from ..base import VannaBase
  from ..utils import deterministic_uuid

- DOCUMENTATION_COLLECTION_NAME = "documentation"
- DDL_COLLECTION_NAME = "ddl"
- SQL_COLLECTION_NAME = "sql"
  SCROLL_SIZE = 1000

- ID_SUFFIXES = {
-     DDL_COLLECTION_NAME: "ddl",
-     DOCUMENTATION_COLLECTION_NAME: "doc",
-     SQL_COLLECTION_NAME: "sql",
- }
-

  class Qdrant_VectorStore(VannaBase):
-     """Vectorstore implementation using Qdrant - https://qdrant.tech/"""
+     """
+     Vectorstore implementation using Qdrant - https://qdrant.tech/
+
+     Args:
+     - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
+     - client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
+     - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
+     - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
+     - prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
+     - https: If `true` - use HTTPS(SSL) protocol. Default: `None`
+     - api_key: API key for authentication in Qdrant Cloud. Default: `None`
+     - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
+     - path: Persistence path for QdrantLocal. Default: `None`.
+     - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
+     - n_results: Number of results to return from similarity search. Defaults to 10.
+     - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
+       Defaults to `"BAAI/bge-small-en-v1.5"`.
+     - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
+     - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
+     - documentation_collection_name: Name of the collection to store documentation. Defaults to `"documentation"`.
+     - ddl_collection_name: Name of the collection to store DDL. Defaults to `"ddl"`.
+     - sql_collection_name: Name of the collection to store SQL. Defaults to `"sql"`.
+
+     Raises:
+         TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
+     """
+
+     documentation_collection_name = "documentation"
+     ddl_collection_name = "ddl"
+     sql_collection_name = "sql"
+
+     id_suffixes = {
+         ddl_collection_name: "ddl",
+         documentation_collection_name: "doc",
+         sql_collection_name: "sql",
+     }

      def __init__(
          self,
          config={},
      ):
-         """
-         Vectorstore implementation using Qdrant - https://qdrant.tech/
-
-         Args:
-         - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
-         - client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
-         - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
-         - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
-         - prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
-         - https: If `true` - use HTTPS(SSL) protocol. Default: `None`
-         - api_key: API key for authentication in Qdrant Cloud. Default: `None`
-         - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
-         - path: Persistence path for QdrantLocal. Default: `None`.
-         - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
-         - n_results: Number of results to return from similarity search. Defaults to 10.
-         - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
-           Defaults to `"BAAI/bge-small-en-v1.5"`.
-         - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
-         - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
-
-         Raises:
-             TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
-         """
          VannaBase.__init__(self, config=config)
          client = config.get("client")

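With the module-level constants promoted to overridable class attributes, collection names (and the other options listed in the docstring) can be set per instance. A short sketch, assuming MyQdrantVanna is a subclass that combines Qdrant_VectorStore with an LLM class in the same mixin style as the earlier ChromaDB example:

```python
# Sketch only: MyQdrantVanna is hypothetical; collection names are illustrative.
from qdrant_client import QdrantClient

vn = MyQdrantVanna(config={
    "client": QdrantClient(location=":memory:"),  # in-memory instance for local testing
    "documentation_collection_name": "my_docs",
    "ddl_collection_name": "my_ddl",
    "sql_collection_name": "my_sql",
    "n_results": 5,
})
```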
@@ -75,15 +78,24 @@ class Qdrant_VectorStore(VannaBase):
  self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
  self.collection_params = config.get("collection_params", {})
  self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
+ self.documentation_collection_name = config.get(
+     "documentation_collection_name", self.documentation_collection_name
+ )
+ self.ddl_collection_name = config.get(
+     "ddl_collection_name", self.ddl_collection_name
+ )
+ self.sql_collection_name = config.get(
+     "sql_collection_name", self.sql_collection_name
+ )

  self._setup_collections()

  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
-     question_answer = format("Question: {0}\n\nSQL: {1}", question, sql)
+     question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql)
      id = deterministic_uuid(question_answer)

      self._client.upsert(
-         SQL_COLLECTION_NAME,
+         self.sql_collection_name,
          points=[
              models.PointStruct(
                  id=id,
@@ -96,12 +108,12 @@
          ],
      )

-     return self._format_point_id(id, SQL_COLLECTION_NAME)
+     return self._format_point_id(id, self.sql_collection_name)

  def add_ddl(self, ddl: str, **kwargs) -> str:
      id = deterministic_uuid(ddl)
      self._client.upsert(
-         DDL_COLLECTION_NAME,
+         self.ddl_collection_name,
          points=[
              models.PointStruct(
                  id=id,
@@ -112,13 +124,13 @@
              )
          ],
      )
-     return self._format_point_id(id, DDL_COLLECTION_NAME)
+     return self._format_point_id(id, self.ddl_collection_name)

  def add_documentation(self, documentation: str, **kwargs) -> str:
      id = deterministic_uuid(documentation)

      self._client.upsert(
-         DOCUMENTATION_COLLECTION_NAME,
+         self.documentation_collection_name,
          points=[
              models.PointStruct(
                  id=id,
@@ -130,16 +142,17 @@
          ],
      )

-     return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME)
+     return self._format_point_id(id, self.documentation_collection_name)

  def get_training_data(self, **kwargs) -> pd.DataFrame:
      df = pd.DataFrame()

-     if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
+     if sql_data := self._get_all_points(self.sql_collection_name):
          question_list = [data.payload["question"] for data in sql_data]
          sql_list = [data.payload["sql"] for data in sql_data]
          id_list = [
-             self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data
+             self._format_point_id(data.id, self.sql_collection_name)
+             for data in sql_data
          ]

          df_sql = pd.DataFrame(
@@ -154,10 +167,11 @@

      df = pd.concat([df, df_sql])

-     if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
+     if ddl_data := self._get_all_points(self.ddl_collection_name):
          ddl_list = [data.payload["ddl"] for data in ddl_data]
          id_list = [
-             self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in sql_data
+             self._format_point_id(data.id, self.ddl_collection_name)
+             for data in ddl_data
          ]

          df_ddl = pd.DataFrame(
@@ -172,12 +186,10 @@

      df = pd.concat([df, df_ddl])

-     doc_data = self.documentation_collection.get()
-
-     if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
+     if doc_data := self._get_all_points(self.documentation_collection_name):
          document_list = [data.payload["documentation"] for data in doc_data]
          id_list = [
-             self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
+             self._format_point_id(data.id, self.documentation_collection_name)
              for data in doc_data
          ]

@@ -212,7 +224,7 @@
      Returns:
          bool: True if collection is deleted, False otherwise
      """
-     if collection_name in ID_SUFFIXES.keys():
+     if collection_name in self.id_suffixes.keys():
          self._client.delete_collection(collection_name)
          self._setup_collections()
          return True
@@ -225,7 +237,7 @@

  def get_similar_question_sql(self, question: str, **kwargs) -> list:
      results = self._client.search(
-         SQL_COLLECTION_NAME,
+         self.sql_collection_name,
          query_vector=self.generate_embedding(question),
          limit=self.n_results,
          with_payload=True,
@@ -235,7 +247,7 @@

  def get_related_ddl(self, question: str, **kwargs) -> list:
      results = self._client.search(
-         DDL_COLLECTION_NAME,
+         self.ddl_collection_name,
          query_vector=self.generate_embedding(question),
          limit=self.n_results,
          with_payload=True,
@@ -245,7 +257,7 @@

  def get_related_documentation(self, question: str, **kwargs) -> list:
      results = self._client.search(
-         DOCUMENTATION_COLLECTION_NAME,
+         self.documentation_collection_name,
          query_vector=self.generate_embedding(question),
          limit=self.n_results,
          with_payload=True,
@@ -284,9 +296,9 @@
      return results

  def _setup_collections(self):
-     if not self._client.collection_exists(SQL_COLLECTION_NAME):
+     if not self._client.collection_exists(self.sql_collection_name):
          self._client.create_collection(
-             collection_name=SQL_COLLECTION_NAME,
+             collection_name=self.sql_collection_name,
              vectors_config=models.VectorParams(
                  size=self.embeddings_dimension,
                  distance=self.distance_metric,
@@ -294,18 +306,18 @@
              **self.collection_params,
          )

-     if not self._client.collection_exists(DDL_COLLECTION_NAME):
+     if not self._client.collection_exists(self.ddl_collection_name):
          self._client.create_collection(
-             collection_name=DDL_COLLECTION_NAME,
+             collection_name=self.ddl_collection_name,
              vectors_config=models.VectorParams(
                  size=self.embeddings_dimension,
                  distance=self.distance_metric,
              ),
              **self.collection_params,
          )
-     if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME):
+     if not self._client.collection_exists(self.documentation_collection_name):
          self._client.create_collection(
-             collection_name=DOCUMENTATION_COLLECTION_NAME,
+             collection_name=self.documentation_collection_name,
              vectors_config=models.VectorParams(
                  size=self.embeddings_dimension,
                  distance=self.distance_metric,
@@ -314,11 +326,11 @@
          )

  def _format_point_id(self, id: str, collection_name: str) -> str:
-     return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])
+     return "{0}-{1}".format(id, self.id_suffixes[collection_name])

  def _parse_point_id(self, id: str) -> Tuple[str, str]:
-     id, suffix = id.rsplit("-", 1)
-     for collection_name, suffix in ID_SUFFIXES.items():
-         if type == suffix:
+     id, curr_suffix = id.rsplit("-", 1)
+     for collection_name, suffix in self.id_suffixes.items():
+         if curr_suffix == suffix:
              return id, collection_name
      raise ValueError(f"Invalid id {id}")
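The last hunk is also a bug fix: the 0.5.0 loop compared the builtin `type` against each suffix, so `_parse_point_id` never matched and always raised. A standalone sketch of the corrected round trip (plain functions mirroring the private helpers, not the package API):

```python
# Sketch only: standalone mirror of the fixed _format_point_id / _parse_point_id logic.
ID_SUFFIXES = {"ddl": "ddl", "documentation": "doc", "sql": "sql"}

def format_point_id(point_id: str, collection_name: str) -> str:
    return "{0}-{1}".format(point_id, ID_SUFFIXES[collection_name])

def parse_point_id(point_id: str):
    raw_id, curr_suffix = point_id.rsplit("-", 1)
    for collection_name, suffix in ID_SUFFIXES.items():
        if curr_suffix == suffix:  # 0.5.0 compared `type == suffix`, which never matched
            return raw_id, collection_name
    raise ValueError(f"Invalid id {point_id}")

assert parse_point_id(format_point_id("1234", "sql")) == ("1234", "sql")
```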