vanna 0.5.1__tar.gz → 0.5.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {vanna-0.5.1 → vanna-0.5.3}/PKG-INFO +1 -1
  2. {vanna-0.5.1 → vanna-0.5.3}/pyproject.toml +1 -1
  3. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/base/base.py +117 -11
  4. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/ollama/ollama.py +1 -1
  5. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/opensearch/opensearch_vector.py +98 -4
  6. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/qdrant/qdrant.py +73 -59
  7. {vanna-0.5.1 → vanna-0.5.3}/README.md +0 -0
  8. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
  9. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  10. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/ZhipuAI/__init__.py +0 -0
  11. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/__init__.py +0 -0
  12. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/anthropic/__init__.py +0 -0
  13. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/anthropic/anthropic_chat.py +0 -0
  14. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/base/__init__.py +0 -0
  15. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/chromadb/__init__.py +0 -0
  16. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/chromadb/chromadb_vector.py +0 -0
  17. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/exceptions/__init__.py +0 -0
  18. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/flask/__init__.py +0 -0
  19. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/flask/assets.py +0 -0
  20. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/flask/auth.py +0 -0
  21. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/google/__init__.py +0 -0
  22. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/google/gemini_chat.py +0 -0
  23. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/hf/__init__.py +0 -0
  24. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/hf/hf.py +0 -0
  25. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/local.py +0 -0
  26. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/marqo/__init__.py +0 -0
  27. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/marqo/marqo.py +0 -0
  28. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/mistral/__init__.py +0 -0
  29. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/mistral/mistral.py +0 -0
  30. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/ollama/__init__.py +0 -0
  31. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/openai/__init__.py +0 -0
  32. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/openai/openai_chat.py +0 -0
  33. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/openai/openai_embeddings.py +0 -0
  34. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/opensearch/__init__.py +0 -0
  35. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/qdrant/__init__.py +0 -0
  36. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/remote.py +0 -0
  37. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/types/__init__.py +0 -0
  38. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/utils.py +0 -0
  39. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/vannadb/__init__.py +0 -0
  40. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/vannadb/vannadb_vector.py +0 -0
  41. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/vllm/__init__.py +0 -0
  42. {vanna-0.5.1 → vanna-0.5.3}/src/vanna/vllm/vllm.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.5.1
3
+ Version: 0.5.3
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
4
4
 
5
5
  [project]
6
6
  name = "vanna"
7
- version = "0.5.1"
7
+ version = "0.5.3"
8
8
  authors = [
9
9
  { name="Zain Hoda", email="zain@vanna.ai" },
10
10
  ]
@@ -1146,19 +1146,14 @@ class VannaBase(ABC):
1146
1146
 
1147
1147
  conn = None
1148
1148
 
1149
- try:
1150
- conn = bigquery.Client(project=project_id)
1151
- except:
1152
- print("Could not found any google cloud implicit credentials")
1153
-
1154
- if cred_file_path:
1149
+ if not cred_file_path:
1150
+ try:
1151
+ conn = bigquery.Client(project=project_id)
1152
+ except:
1153
+ print("Could not found any google cloud implicit credentials")
1154
+ else:
1155
1155
  # Validate file path and pemissions
1156
1156
  validate_config_path(cred_file_path)
1157
- else:
1158
- if not conn:
1159
- raise ValidationError(
1160
- "Pleae provide a service account credentials json file"
1161
- )
1162
1157
 
1163
1158
  if not conn:
1164
1159
  with open(cred_file_path, "r") as f:
@@ -1279,6 +1274,7 @@ class VannaBase(ABC):
1279
1274
  # Execute the SQL statement and return the result as a pandas DataFrame
1280
1275
  with engine.begin() as conn:
1281
1276
  df = pd.read_sql_query(sa.text(sql), conn)
1277
+ conn.close()
1282
1278
  return df
1283
1279
 
1284
1280
  raise Exception("Couldn't run sql")
@@ -1286,6 +1282,116 @@ class VannaBase(ABC):
1286
1282
  self.dialect = "T-SQL / Microsoft SQL Server"
1287
1283
  self.run_sql = run_sql_mssql
1288
1284
  self.run_sql_is_set = True
1285
+ def connect_to_presto(
1286
+ self,
1287
+ host: str,
1288
+ catalog: str = 'hive',
1289
+ schema: str = 'default',
1290
+ user: str = None,
1291
+ password: str = None,
1292
+ port: int = None,
1293
+ combined_pem_path: str = None,
1294
+ protocol: str = 'https',
1295
+ requests_kwargs: dict = None
1296
+ ):
1297
+ """
1298
+ Connect to a Presto database using the specified parameters.
1299
+
1300
+ Args:
1301
+ host (str): The host address of the Presto database.
1302
+ catalog (str): The catalog to use in the Presto environment.
1303
+ schema (str): The schema to use in the Presto environment.
1304
+ user (str): The username for authentication.
1305
+ password (str): The password for authentication.
1306
+ port (int): The port number for the Presto connection.
1307
+ combined_pem_path (str): The path to the combined pem file for SSL connection.
1308
+ protocol (str): The protocol to use for the connection (default is 'https').
1309
+ requests_kwargs (dict): Additional keyword arguments for requests.
1310
+
1311
+ Raises:
1312
+ DependencyError: If required dependencies are not installed.
1313
+ ImproperlyConfigured: If essential configuration settings are missing.
1314
+
1315
+ Returns:
1316
+ None
1317
+ """
1318
+ try:
1319
+ from pyhive import presto
1320
+ except ImportError:
1321
+ raise DependencyError(
1322
+ "You need to install required dependencies to execute this method,"
1323
+ " run command: \npip install pyhive"
1324
+ )
1325
+
1326
+ if not host:
1327
+ host = os.getenv("PRESTO_HOST")
1328
+
1329
+ if not host:
1330
+ raise ImproperlyConfigured("Please set your presto host")
1331
+
1332
+ if not catalog:
1333
+ catalog = os.getenv("PRESTO_CATALOG")
1334
+
1335
+ if not catalog:
1336
+ raise ImproperlyConfigured("Please set your presto catalog")
1337
+
1338
+ if not user:
1339
+ user = os.getenv("PRESTO_USER")
1340
+
1341
+ if not user:
1342
+ raise ImproperlyConfigured("Please set your presto user")
1343
+
1344
+ if not password:
1345
+ password = os.getenv("PRESTO_PASSWORD")
1346
+
1347
+ if not port:
1348
+ port = os.getenv("PRESTO_PORT")
1349
+
1350
+ if not port:
1351
+ raise ImproperlyConfigured("Please set your presto port")
1352
+
1353
+ conn = None
1354
+
1355
+ try:
1356
+ if requests_kwargs is None and combined_pem_path is not None:
1357
+ # use the combined pem file to verify the SSL connection
1358
+ requests_kwargs = {
1359
+ 'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
1360
+ }
1361
+ conn = presto.Connection(host=host,
1362
+ username=user,
1363
+ password=password,
1364
+ catalog=catalog,
1365
+ schema=schema,
1366
+ port=port,
1367
+ protocol=protocol,
1368
+ requests_kwargs=requests_kwargs)
1369
+ except presto.Error as e:
1370
+ raise ValidationError(e)
1371
+
1372
+ def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
1373
+ if conn:
1374
+ try:
1375
+ cs = conn.cursor()
1376
+ cs.execute(sql)
1377
+ results = cs.fetchall()
1378
+
1379
+ # Create a pandas dataframe from the results
1380
+ df = pd.DataFrame(
1381
+ results, columns=[desc[0] for desc in cs.description]
1382
+ )
1383
+ return df
1384
+
1385
+ except presto.Error as e:
1386
+ print(e)
1387
+ raise ValidationError(e)
1388
+
1389
+ except Exception as e:
1390
+ print(e)
1391
+ raise e
1392
+
1393
+ self.run_sql_is_set = True
1394
+ self.run_sql = run_sql_presto
1289
1395
 
1290
1396
  def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
1291
1397
  """
@@ -24,7 +24,7 @@ class Ollama(VannaBase):
24
24
  raise ValueError("config must contain at least Ollama model")
25
25
  self.host = config.get("ollama_host", "http://localhost:11434")
26
26
  self.model = config["model"]
27
- if ":" in self.model:
27
+ if ":" not in self.model:
28
28
  self.model += ":latest"
29
29
 
30
30
  self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0))
@@ -24,7 +24,77 @@ class OpenSearch_VectorStore(VannaBase):
24
24
  self.document_index = document_index
25
25
  self.ddl_index = ddl_index
26
26
  self.question_sql_index = question_sql_index
27
- print("OpenSearch_VectorStore initialized with document_index: ", document_index, " ddl_index: ", ddl_index, " question_sql_index: ", question_sql_index)
27
+ print("OpenSearch_VectorStore initialized with document_index: ",
28
+ document_index, " ddl_index: ", ddl_index, " question_sql_index: ",
29
+ question_sql_index)
30
+
31
+ document_index_settings = {
32
+ "settings": {
33
+ "index": {
34
+ "number_of_shards": 6,
35
+ "number_of_replicas": 2
36
+ }
37
+ },
38
+ "mappings": {
39
+ "properties": {
40
+ "question": {
41
+ "type": "text",
42
+ },
43
+ "doc": {
44
+ "type": "text",
45
+ }
46
+ }
47
+ }
48
+ }
49
+
50
+ ddl_index_settings = {
51
+ "settings": {
52
+ "index": {
53
+ "number_of_shards": 6,
54
+ "number_of_replicas": 2
55
+ }
56
+ },
57
+ "mappings": {
58
+ "properties": {
59
+ "ddl": {
60
+ "type": "text",
61
+ },
62
+ "doc": {
63
+ "type": "text",
64
+ }
65
+ }
66
+ }
67
+ }
68
+
69
+ question_sql_index_settings = {
70
+ "settings": {
71
+ "index": {
72
+ "number_of_shards": 6,
73
+ "number_of_replicas": 2
74
+ }
75
+ },
76
+ "mappings": {
77
+ "properties": {
78
+ "question": {
79
+ "type": "text",
80
+ },
81
+ "sql": {
82
+ "type": "text",
83
+ }
84
+ }
85
+ }
86
+ }
87
+
88
+ if config is not None and "es_document_index_settings" in config:
89
+ document_index_settings = config["es_document_index_settings"]
90
+ if config is not None and "es_ddl_index_settings" in config:
91
+ ddl_index_settings = config["es_ddl_index_settings"]
92
+ if config is not None and "es_question_sql_index_settings" in config:
93
+ question_sql_index_settings = config["es_question_sql_index_settings"]
94
+
95
+ self.document_index_settings = document_index_settings
96
+ self.ddl_index_settings = ddl_index_settings
97
+ self.question_sql_index_settings = question_sql_index_settings
28
98
 
29
99
  es_urls = None
30
100
  if config is not None and "es_urls" in config:
@@ -85,6 +155,9 @@ class OpenSearch_VectorStore(VannaBase):
85
155
  else:
86
156
  max_retries = 10
87
157
 
158
+ print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
159
+ " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
160
+ verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
88
161
  if es_urls is not None:
89
162
  # Initialize the OpenSearch client by passing a list of URLs
90
163
  self.client = OpenSearch(
@@ -112,18 +185,26 @@ class OpenSearch_VectorStore(VannaBase):
112
185
  headers=headers
113
186
  )
114
187
 
188
+ print("OpenSearch_VectorStore initialized with client over ")
189
+
115
190
  # 执行一个简单的查询来检查连接
116
191
  try:
192
+ print('Connected to OpenSearch cluster:')
117
193
  info = self.client.info()
118
194
  print('OpenSearch cluster info:', info)
119
195
  except Exception as e:
120
196
  print('Error connecting to OpenSearch cluster:', e)
121
197
 
122
198
  # Create the indices if they don't exist
123
- # self.create_index()
199
+ self.create_index_if_not_exists(self.document_index,
200
+ self.document_index_settings)
201
+ self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings)
202
+ self.create_index_if_not_exists(self.question_sql_index,
203
+ self.question_sql_index_settings)
124
204
 
125
205
  def create_index(self):
126
- for index in [self.document_index, self.ddl_index, self.question_sql_index]:
206
+ for index in [self.document_index, self.ddl_index,
207
+ self.question_sql_index]:
127
208
  try:
128
209
  self.client.indices.create(index)
129
210
  except Exception as e:
@@ -131,6 +212,20 @@ class OpenSearch_VectorStore(VannaBase):
131
212
  print(f"opensearch index {index} already exists")
132
213
  pass
133
214
 
215
+ def create_index_if_not_exists(self, index_name: str,
216
+ index_settings: dict) -> bool:
217
+ try:
218
+ if not self.client.indices.exists(index_name):
219
+ print(f"Index {index_name} does not exist. Creating...")
220
+ self.client.indices.create(index=index_name, body=index_settings)
221
+ return True
222
+ else:
223
+ print(f"Index {index_name} already exists.")
224
+ return False
225
+ except Exception as e:
226
+ print(f"Error creating index: {index_name} ", e)
227
+ return False
228
+
134
229
  def add_ddl(self, ddl: str, **kwargs) -> str:
135
230
  # Assuming that you have a DDL index in your OpenSearch
136
231
  id = str(uuid.uuid4()) + "-ddl"
@@ -278,7 +373,6 @@ class OpenSearch_VectorStore(VannaBase):
278
373
  # opensearch doesn't need to generate embeddings
279
374
  pass
280
375
 
281
-
282
376
  # OpenSearch_VectorStore.__init__(self, config={'es_urls':
283
377
  # "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
284
378
  # "admin", 'es_password': "admin", 'es_verify_certs': True})
@@ -7,48 +7,51 @@ from qdrant_client import QdrantClient, grpc, models
7
7
  from ..base import VannaBase
8
8
  from ..utils import deterministic_uuid
9
9
 
10
- DOCUMENTATION_COLLECTION_NAME = "documentation"
11
- DDL_COLLECTION_NAME = "ddl"
12
- SQL_COLLECTION_NAME = "sql"
13
10
  SCROLL_SIZE = 1000
14
11
 
15
- ID_SUFFIXES = {
16
- DDL_COLLECTION_NAME: "ddl",
17
- DOCUMENTATION_COLLECTION_NAME: "doc",
18
- SQL_COLLECTION_NAME: "sql",
19
- }
20
-
21
12
 
22
13
  class Qdrant_VectorStore(VannaBase):
23
- """Vectorstore implementation using Qdrant - https://qdrant.tech/"""
14
+ """
15
+ Vectorstore implementation using Qdrant - https://qdrant.tech/
16
+
17
+ Args:
18
+ - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
19
+ - client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
20
+ - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
21
+ - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
22
+ - prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
23
+ - https: If `true` - use HTTPS(SSL) protocol. Default: `None`
24
+ - api_key: API key for authentication in Qdrant Cloud. Default: `None`
25
+ - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
26
+ - path: Persistence path for QdrantLocal. Default: `None`.
27
+ - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
28
+ - n_results: Number of results to return from similarity search. Defaults to 10.
29
+ - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
30
+ Defaults to `"BAAI/bge-small-en-v1.5"`.
31
+ - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
32
+ - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
33
+ - documentation_collection_name: Name of the collection to store documentation. Defaults to `"documentation"`.
34
+ - ddl_collection_name: Name of the collection to store DDL. Defaults to `"ddl"`.
35
+ - sql_collection_name: Name of the collection to store SQL. Defaults to `"sql"`.
36
+
37
+ Raises:
38
+ TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
39
+ """
40
+
41
+ documentation_collection_name = "documentation"
42
+ ddl_collection_name = "ddl"
43
+ sql_collection_name = "sql"
44
+
45
+ id_suffixes = {
46
+ ddl_collection_name: "ddl",
47
+ documentation_collection_name: "doc",
48
+ sql_collection_name: "sql",
49
+ }
24
50
 
25
51
  def __init__(
26
52
  self,
27
53
  config={},
28
54
  ):
29
- """
30
- Vectorstore implementation using Qdrant - https://qdrant.tech/
31
-
32
- Args:
33
- - config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
34
- - client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
35
- - location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
36
- - url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
37
- - prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
38
- - https: If `true` - use HTTPS(SSL) protocol. Default: `None`
39
- - api_key: API key for authentication in Qdrant Cloud. Default: `None`
40
- - timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
41
- - path: Persistence path for QdrantLocal. Default: `None`.
42
- - prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
43
- - n_results: Number of results to return from similarity search. Defaults to 10.
44
- - fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
45
- Defaults to `"BAAI/bge-small-en-v1.5"`.
46
- - collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
47
- - distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
48
-
49
- Raises:
50
- TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
51
- """
52
55
  VannaBase.__init__(self, config=config)
53
56
  client = config.get("client")
54
57
 
@@ -75,6 +78,15 @@ class Qdrant_VectorStore(VannaBase):
75
78
  self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
76
79
  self.collection_params = config.get("collection_params", {})
77
80
  self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
81
+ self.documentation_collection_name = config.get(
82
+ "documentation_collection_name", self.documentation_collection_name
83
+ )
84
+ self.ddl_collection_name = config.get(
85
+ "ddl_collection_name", self.ddl_collection_name
86
+ )
87
+ self.sql_collection_name = config.get(
88
+ "sql_collection_name", self.sql_collection_name
89
+ )
78
90
 
79
91
  self._setup_collections()
80
92
 
@@ -83,7 +95,7 @@ class Qdrant_VectorStore(VannaBase):
83
95
  id = deterministic_uuid(question_answer)
84
96
 
85
97
  self._client.upsert(
86
- SQL_COLLECTION_NAME,
98
+ self.sql_collection_name,
87
99
  points=[
88
100
  models.PointStruct(
89
101
  id=id,
@@ -96,12 +108,12 @@ class Qdrant_VectorStore(VannaBase):
96
108
  ],
97
109
  )
98
110
 
99
- return self._format_point_id(id, SQL_COLLECTION_NAME)
111
+ return self._format_point_id(id, self.sql_collection_name)
100
112
 
101
113
  def add_ddl(self, ddl: str, **kwargs) -> str:
102
114
  id = deterministic_uuid(ddl)
103
115
  self._client.upsert(
104
- DDL_COLLECTION_NAME,
116
+ self.ddl_collection_name,
105
117
  points=[
106
118
  models.PointStruct(
107
119
  id=id,
@@ -112,13 +124,13 @@ class Qdrant_VectorStore(VannaBase):
112
124
  )
113
125
  ],
114
126
  )
115
- return self._format_point_id(id, DDL_COLLECTION_NAME)
127
+ return self._format_point_id(id, self.ddl_collection_name)
116
128
 
117
129
  def add_documentation(self, documentation: str, **kwargs) -> str:
118
130
  id = deterministic_uuid(documentation)
119
131
 
120
132
  self._client.upsert(
121
- DOCUMENTATION_COLLECTION_NAME,
133
+ self.documentation_collection_name,
122
134
  points=[
123
135
  models.PointStruct(
124
136
  id=id,
@@ -130,16 +142,17 @@ class Qdrant_VectorStore(VannaBase):
130
142
  ],
131
143
  )
132
144
 
133
- return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME)
145
+ return self._format_point_id(id, self.documentation_collection_name)
134
146
 
135
147
  def get_training_data(self, **kwargs) -> pd.DataFrame:
136
148
  df = pd.DataFrame()
137
149
 
138
- if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
150
+ if sql_data := self._get_all_points(self.sql_collection_name):
139
151
  question_list = [data.payload["question"] for data in sql_data]
140
152
  sql_list = [data.payload["sql"] for data in sql_data]
141
153
  id_list = [
142
- self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data
154
+ self._format_point_id(data.id, self.sql_collection_name)
155
+ for data in sql_data
143
156
  ]
144
157
 
145
158
  df_sql = pd.DataFrame(
@@ -154,10 +167,11 @@ class Qdrant_VectorStore(VannaBase):
154
167
 
155
168
  df = pd.concat([df, df_sql])
156
169
 
157
- if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
170
+ if ddl_data := self._get_all_points(self.ddl_collection_name):
158
171
  ddl_list = [data.payload["ddl"] for data in ddl_data]
159
172
  id_list = [
160
- self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in ddl_data
173
+ self._format_point_id(data.id, self.ddl_collection_name)
174
+ for data in ddl_data
161
175
  ]
162
176
 
163
177
  df_ddl = pd.DataFrame(
@@ -172,10 +186,10 @@ class Qdrant_VectorStore(VannaBase):
172
186
 
173
187
  df = pd.concat([df, df_ddl])
174
188
 
175
- if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
189
+ if doc_data := self._get_all_points(self.documentation_collection_name):
176
190
  document_list = [data.payload["documentation"] for data in doc_data]
177
191
  id_list = [
178
- self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
192
+ self._format_point_id(data.id, self.documentation_collection_name)
179
193
  for data in doc_data
180
194
  ]
181
195
 
@@ -210,7 +224,7 @@ class Qdrant_VectorStore(VannaBase):
210
224
  Returns:
211
225
  bool: True if collection is deleted, False otherwise
212
226
  """
213
- if collection_name in ID_SUFFIXES.keys():
227
+ if collection_name in self.id_suffixes.keys():
214
228
  self._client.delete_collection(collection_name)
215
229
  self._setup_collections()
216
230
  return True
@@ -223,7 +237,7 @@ class Qdrant_VectorStore(VannaBase):
223
237
 
224
238
  def get_similar_question_sql(self, question: str, **kwargs) -> list:
225
239
  results = self._client.search(
226
- SQL_COLLECTION_NAME,
240
+ self.sql_collection_name,
227
241
  query_vector=self.generate_embedding(question),
228
242
  limit=self.n_results,
229
243
  with_payload=True,
@@ -233,7 +247,7 @@ class Qdrant_VectorStore(VannaBase):
233
247
 
234
248
  def get_related_ddl(self, question: str, **kwargs) -> list:
235
249
  results = self._client.search(
236
- DDL_COLLECTION_NAME,
250
+ self.ddl_collection_name,
237
251
  query_vector=self.generate_embedding(question),
238
252
  limit=self.n_results,
239
253
  with_payload=True,
@@ -243,7 +257,7 @@ class Qdrant_VectorStore(VannaBase):
243
257
 
244
258
  def get_related_documentation(self, question: str, **kwargs) -> list:
245
259
  results = self._client.search(
246
- DOCUMENTATION_COLLECTION_NAME,
260
+ self.documentation_collection_name,
247
261
  query_vector=self.generate_embedding(question),
248
262
  limit=self.n_results,
249
263
  with_payload=True,
@@ -282,9 +296,9 @@ class Qdrant_VectorStore(VannaBase):
282
296
  return results
283
297
 
284
298
  def _setup_collections(self):
285
- if not self._client.collection_exists(SQL_COLLECTION_NAME):
299
+ if not self._client.collection_exists(self.sql_collection_name):
286
300
  self._client.create_collection(
287
- collection_name=SQL_COLLECTION_NAME,
301
+ collection_name=self.sql_collection_name,
288
302
  vectors_config=models.VectorParams(
289
303
  size=self.embeddings_dimension,
290
304
  distance=self.distance_metric,
@@ -292,18 +306,18 @@ class Qdrant_VectorStore(VannaBase):
292
306
  **self.collection_params,
293
307
  )
294
308
 
295
- if not self._client.collection_exists(DDL_COLLECTION_NAME):
309
+ if not self._client.collection_exists(self.ddl_collection_name):
296
310
  self._client.create_collection(
297
- collection_name=DDL_COLLECTION_NAME,
311
+ collection_name=self.ddl_collection_name,
298
312
  vectors_config=models.VectorParams(
299
313
  size=self.embeddings_dimension,
300
314
  distance=self.distance_metric,
301
315
  ),
302
316
  **self.collection_params,
303
317
  )
304
- if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME):
318
+ if not self._client.collection_exists(self.documentation_collection_name):
305
319
  self._client.create_collection(
306
- collection_name=DOCUMENTATION_COLLECTION_NAME,
320
+ collection_name=self.documentation_collection_name,
307
321
  vectors_config=models.VectorParams(
308
322
  size=self.embeddings_dimension,
309
323
  distance=self.distance_metric,
@@ -312,11 +326,11 @@ class Qdrant_VectorStore(VannaBase):
312
326
  )
313
327
 
314
328
  def _format_point_id(self, id: str, collection_name: str) -> str:
315
- return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])
329
+ return "{0}-{1}".format(id, self.id_suffixes[collection_name])
316
330
 
317
331
  def _parse_point_id(self, id: str) -> Tuple[str, str]:
318
- id, suffix = id.rsplit("-", 1)
319
- for collection_name, suffix in ID_SUFFIXES.items():
320
- if type == suffix:
332
+ id, curr_suffix = id.rsplit("-", 1)
333
+ for collection_name, suffix in self.id_suffixes.items():
334
+ if curr_suffix == suffix:
321
335
  return id, collection_name
322
336
  raise ValueError(f"Invalid id {id}")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes