vanna 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/base/base.py +117 -11
- vanna/ollama/ollama.py +1 -1
- vanna/opensearch/opensearch_vector.py +98 -4
- vanna/qdrant/qdrant.py +73 -59
- {vanna-0.5.1.dist-info → vanna-0.5.3.dist-info}/METADATA +1 -1
- {vanna-0.5.1.dist-info → vanna-0.5.3.dist-info}/RECORD +7 -7
- {vanna-0.5.1.dist-info → vanna-0.5.3.dist-info}/WHEEL +0 -0
vanna/base/base.py
CHANGED
|
@@ -1146,19 +1146,14 @@ class VannaBase(ABC):
|
|
|
1146
1146
|
|
|
1147
1147
|
conn = None
|
|
1148
1148
|
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1149
|
+
if not cred_file_path:
|
|
1150
|
+
try:
|
|
1151
|
+
conn = bigquery.Client(project=project_id)
|
|
1152
|
+
except:
|
|
1153
|
+
print("Could not found any google cloud implicit credentials")
|
|
1154
|
+
else:
|
|
1155
1155
|
# Validate file path and pemissions
|
|
1156
1156
|
validate_config_path(cred_file_path)
|
|
1157
|
-
else:
|
|
1158
|
-
if not conn:
|
|
1159
|
-
raise ValidationError(
|
|
1160
|
-
"Pleae provide a service account credentials json file"
|
|
1161
|
-
)
|
|
1162
1157
|
|
|
1163
1158
|
if not conn:
|
|
1164
1159
|
with open(cred_file_path, "r") as f:
|
|
@@ -1279,6 +1274,7 @@ class VannaBase(ABC):
|
|
|
1279
1274
|
# Execute the SQL statement and return the result as a pandas DataFrame
|
|
1280
1275
|
with engine.begin() as conn:
|
|
1281
1276
|
df = pd.read_sql_query(sa.text(sql), conn)
|
|
1277
|
+
conn.close()
|
|
1282
1278
|
return df
|
|
1283
1279
|
|
|
1284
1280
|
raise Exception("Couldn't run sql")
|
|
@@ -1286,6 +1282,116 @@ class VannaBase(ABC):
|
|
|
1286
1282
|
self.dialect = "T-SQL / Microsoft SQL Server"
|
|
1287
1283
|
self.run_sql = run_sql_mssql
|
|
1288
1284
|
self.run_sql_is_set = True
|
|
1285
|
+
def connect_to_presto(
|
|
1286
|
+
self,
|
|
1287
|
+
host: str,
|
|
1288
|
+
catalog: str = 'hive',
|
|
1289
|
+
schema: str = 'default',
|
|
1290
|
+
user: str = None,
|
|
1291
|
+
password: str = None,
|
|
1292
|
+
port: int = None,
|
|
1293
|
+
combined_pem_path: str = None,
|
|
1294
|
+
protocol: str = 'https',
|
|
1295
|
+
requests_kwargs: dict = None
|
|
1296
|
+
):
|
|
1297
|
+
"""
|
|
1298
|
+
Connect to a Presto database using the specified parameters.
|
|
1299
|
+
|
|
1300
|
+
Args:
|
|
1301
|
+
host (str): The host address of the Presto database.
|
|
1302
|
+
catalog (str): The catalog to use in the Presto environment.
|
|
1303
|
+
schema (str): The schema to use in the Presto environment.
|
|
1304
|
+
user (str): The username for authentication.
|
|
1305
|
+
password (str): The password for authentication.
|
|
1306
|
+
port (int): The port number for the Presto connection.
|
|
1307
|
+
combined_pem_path (str): The path to the combined pem file for SSL connection.
|
|
1308
|
+
protocol (str): The protocol to use for the connection (default is 'https').
|
|
1309
|
+
requests_kwargs (dict): Additional keyword arguments for requests.
|
|
1310
|
+
|
|
1311
|
+
Raises:
|
|
1312
|
+
DependencyError: If required dependencies are not installed.
|
|
1313
|
+
ImproperlyConfigured: If essential configuration settings are missing.
|
|
1314
|
+
|
|
1315
|
+
Returns:
|
|
1316
|
+
None
|
|
1317
|
+
"""
|
|
1318
|
+
try:
|
|
1319
|
+
from pyhive import presto
|
|
1320
|
+
except ImportError:
|
|
1321
|
+
raise DependencyError(
|
|
1322
|
+
"You need to install required dependencies to execute this method,"
|
|
1323
|
+
" run command: \npip install pyhive"
|
|
1324
|
+
)
|
|
1325
|
+
|
|
1326
|
+
if not host:
|
|
1327
|
+
host = os.getenv("PRESTO_HOST")
|
|
1328
|
+
|
|
1329
|
+
if not host:
|
|
1330
|
+
raise ImproperlyConfigured("Please set your presto host")
|
|
1331
|
+
|
|
1332
|
+
if not catalog:
|
|
1333
|
+
catalog = os.getenv("PRESTO_CATALOG")
|
|
1334
|
+
|
|
1335
|
+
if not catalog:
|
|
1336
|
+
raise ImproperlyConfigured("Please set your presto catalog")
|
|
1337
|
+
|
|
1338
|
+
if not user:
|
|
1339
|
+
user = os.getenv("PRESTO_USER")
|
|
1340
|
+
|
|
1341
|
+
if not user:
|
|
1342
|
+
raise ImproperlyConfigured("Please set your presto user")
|
|
1343
|
+
|
|
1344
|
+
if not password:
|
|
1345
|
+
password = os.getenv("PRESTO_PASSWORD")
|
|
1346
|
+
|
|
1347
|
+
if not port:
|
|
1348
|
+
port = os.getenv("PRESTO_PORT")
|
|
1349
|
+
|
|
1350
|
+
if not port:
|
|
1351
|
+
raise ImproperlyConfigured("Please set your presto port")
|
|
1352
|
+
|
|
1353
|
+
conn = None
|
|
1354
|
+
|
|
1355
|
+
try:
|
|
1356
|
+
if requests_kwargs is None and combined_pem_path is not None:
|
|
1357
|
+
# use the combined pem file to verify the SSL connection
|
|
1358
|
+
requests_kwargs = {
|
|
1359
|
+
'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证
|
|
1360
|
+
}
|
|
1361
|
+
conn = presto.Connection(host=host,
|
|
1362
|
+
username=user,
|
|
1363
|
+
password=password,
|
|
1364
|
+
catalog=catalog,
|
|
1365
|
+
schema=schema,
|
|
1366
|
+
port=port,
|
|
1367
|
+
protocol=protocol,
|
|
1368
|
+
requests_kwargs=requests_kwargs)
|
|
1369
|
+
except presto.Error as e:
|
|
1370
|
+
raise ValidationError(e)
|
|
1371
|
+
|
|
1372
|
+
def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]:
|
|
1373
|
+
if conn:
|
|
1374
|
+
try:
|
|
1375
|
+
cs = conn.cursor()
|
|
1376
|
+
cs.execute(sql)
|
|
1377
|
+
results = cs.fetchall()
|
|
1378
|
+
|
|
1379
|
+
# Create a pandas dataframe from the results
|
|
1380
|
+
df = pd.DataFrame(
|
|
1381
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
1382
|
+
)
|
|
1383
|
+
return df
|
|
1384
|
+
|
|
1385
|
+
except presto.Error as e:
|
|
1386
|
+
print(e)
|
|
1387
|
+
raise ValidationError(e)
|
|
1388
|
+
|
|
1389
|
+
except Exception as e:
|
|
1390
|
+
print(e)
|
|
1391
|
+
raise e
|
|
1392
|
+
|
|
1393
|
+
self.run_sql_is_set = True
|
|
1394
|
+
self.run_sql = run_sql_presto
|
|
1289
1395
|
|
|
1290
1396
|
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
|
|
1291
1397
|
"""
|
vanna/ollama/ollama.py
CHANGED
|
@@ -24,7 +24,7 @@ class Ollama(VannaBase):
|
|
|
24
24
|
raise ValueError("config must contain at least Ollama model")
|
|
25
25
|
self.host = config.get("ollama_host", "http://localhost:11434")
|
|
26
26
|
self.model = config["model"]
|
|
27
|
-
if ":" in self.model:
|
|
27
|
+
if ":" not in self.model:
|
|
28
28
|
self.model += ":latest"
|
|
29
29
|
|
|
30
30
|
self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0))
|
|
vanna/opensearch/opensearch_vector.py
CHANGED
|
@@ -24,7 +24,77 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
24
24
|
self.document_index = document_index
|
|
25
25
|
self.ddl_index = ddl_index
|
|
26
26
|
self.question_sql_index = question_sql_index
|
|
27
|
-
print("OpenSearch_VectorStore initialized with document_index: ",
|
|
27
|
+
print("OpenSearch_VectorStore initialized with document_index: ",
|
|
28
|
+
document_index, " ddl_index: ", ddl_index, " question_sql_index: ",
|
|
29
|
+
question_sql_index)
|
|
30
|
+
|
|
31
|
+
document_index_settings = {
|
|
32
|
+
"settings": {
|
|
33
|
+
"index": {
|
|
34
|
+
"number_of_shards": 6,
|
|
35
|
+
"number_of_replicas": 2
|
|
36
|
+
}
|
|
37
|
+
},
|
|
38
|
+
"mappings": {
|
|
39
|
+
"properties": {
|
|
40
|
+
"question": {
|
|
41
|
+
"type": "text",
|
|
42
|
+
},
|
|
43
|
+
"doc": {
|
|
44
|
+
"type": "text",
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
ddl_index_settings = {
|
|
51
|
+
"settings": {
|
|
52
|
+
"index": {
|
|
53
|
+
"number_of_shards": 6,
|
|
54
|
+
"number_of_replicas": 2
|
|
55
|
+
}
|
|
56
|
+
},
|
|
57
|
+
"mappings": {
|
|
58
|
+
"properties": {
|
|
59
|
+
"ddl": {
|
|
60
|
+
"type": "text",
|
|
61
|
+
},
|
|
62
|
+
"doc": {
|
|
63
|
+
"type": "text",
|
|
64
|
+
}
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
question_sql_index_settings = {
|
|
70
|
+
"settings": {
|
|
71
|
+
"index": {
|
|
72
|
+
"number_of_shards": 6,
|
|
73
|
+
"number_of_replicas": 2
|
|
74
|
+
}
|
|
75
|
+
},
|
|
76
|
+
"mappings": {
|
|
77
|
+
"properties": {
|
|
78
|
+
"question": {
|
|
79
|
+
"type": "text",
|
|
80
|
+
},
|
|
81
|
+
"sql": {
|
|
82
|
+
"type": "text",
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
if config is not None and "es_document_index_settings" in config:
|
|
89
|
+
document_index_settings = config["es_document_index_settings"]
|
|
90
|
+
if config is not None and "es_ddl_index_settings" in config:
|
|
91
|
+
ddl_index_settings = config["es_ddl_index_settings"]
|
|
92
|
+
if config is not None and "es_question_sql_index_settings" in config:
|
|
93
|
+
question_sql_index_settings = config["es_question_sql_index_settings"]
|
|
94
|
+
|
|
95
|
+
self.document_index_settings = document_index_settings
|
|
96
|
+
self.ddl_index_settings = ddl_index_settings
|
|
97
|
+
self.question_sql_index_settings = question_sql_index_settings
|
|
28
98
|
|
|
29
99
|
es_urls = None
|
|
30
100
|
if config is not None and "es_urls" in config:
|
|
@@ -85,6 +155,9 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
85
155
|
else:
|
|
86
156
|
max_retries = 10
|
|
87
157
|
|
|
158
|
+
print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
|
|
159
|
+
" host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
|
|
160
|
+
verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
|
|
88
161
|
if es_urls is not None:
|
|
89
162
|
# Initialize the OpenSearch client by passing a list of URLs
|
|
90
163
|
self.client = OpenSearch(
|
|
@@ -112,18 +185,26 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
112
185
|
headers=headers
|
|
113
186
|
)
|
|
114
187
|
|
|
188
|
+
print("OpenSearch_VectorStore initialized with client over ")
|
|
189
|
+
|
|
115
190
|
# 执行一个简单的查询来检查连接
|
|
116
191
|
try:
|
|
192
|
+
print('Connected to OpenSearch cluster:')
|
|
117
193
|
info = self.client.info()
|
|
118
194
|
print('OpenSearch cluster info:', info)
|
|
119
195
|
except Exception as e:
|
|
120
196
|
print('Error connecting to OpenSearch cluster:', e)
|
|
121
197
|
|
|
122
198
|
# Create the indices if they don't exist
|
|
123
|
-
|
|
199
|
+
self.create_index_if_not_exists(self.document_index,
|
|
200
|
+
self.document_index_settings)
|
|
201
|
+
self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings)
|
|
202
|
+
self.create_index_if_not_exists(self.question_sql_index,
|
|
203
|
+
self.question_sql_index_settings)
|
|
124
204
|
|
|
125
205
|
def create_index(self):
|
|
126
|
-
for index in [self.document_index, self.ddl_index,
|
|
206
|
+
for index in [self.document_index, self.ddl_index,
|
|
207
|
+
self.question_sql_index]:
|
|
127
208
|
try:
|
|
128
209
|
self.client.indices.create(index)
|
|
129
210
|
except Exception as e:
|
|
@@ -131,6 +212,20 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
131
212
|
print(f"opensearch index {index} already exists")
|
|
132
213
|
pass
|
|
133
214
|
|
|
215
|
+
def create_index_if_not_exists(self, index_name: str,
|
|
216
|
+
index_settings: dict) -> bool:
|
|
217
|
+
try:
|
|
218
|
+
if not self.client.indices.exists(index_name):
|
|
219
|
+
print(f"Index {index_name} does not exist. Creating...")
|
|
220
|
+
self.client.indices.create(index=index_name, body=index_settings)
|
|
221
|
+
return True
|
|
222
|
+
else:
|
|
223
|
+
print(f"Index {index_name} already exists.")
|
|
224
|
+
return False
|
|
225
|
+
except Exception as e:
|
|
226
|
+
print(f"Error creating index: {index_name} ", e)
|
|
227
|
+
return False
|
|
228
|
+
|
|
134
229
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
135
230
|
# Assuming that you have a DDL index in your OpenSearch
|
|
136
231
|
id = str(uuid.uuid4()) + "-ddl"
|
|
@@ -278,7 +373,6 @@ class OpenSearch_VectorStore(VannaBase):
|
|
|
278
373
|
# opensearch doesn't need to generate embeddings
|
|
279
374
|
pass
|
|
280
375
|
|
|
281
|
-
|
|
282
376
|
# OpenSearch_VectorStore.__init__(self, config={'es_urls':
|
|
283
377
|
# "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
|
|
284
378
|
# "admin", 'es_password': "admin", 'es_verify_certs': True})
|
vanna/qdrant/qdrant.py
CHANGED
|
@@ -7,48 +7,51 @@ from qdrant_client import QdrantClient, grpc, models
|
|
|
7
7
|
from ..base import VannaBase
|
|
8
8
|
from ..utils import deterministic_uuid
|
|
9
9
|
|
|
10
|
-
DOCUMENTATION_COLLECTION_NAME = "documentation"
|
|
11
|
-
DDL_COLLECTION_NAME = "ddl"
|
|
12
|
-
SQL_COLLECTION_NAME = "sql"
|
|
13
10
|
SCROLL_SIZE = 1000
|
|
14
11
|
|
|
15
|
-
ID_SUFFIXES = {
|
|
16
|
-
DDL_COLLECTION_NAME: "ddl",
|
|
17
|
-
DOCUMENTATION_COLLECTION_NAME: "doc",
|
|
18
|
-
SQL_COLLECTION_NAME: "sql",
|
|
19
|
-
}
|
|
20
|
-
|
|
21
12
|
|
|
22
13
|
class Qdrant_VectorStore(VannaBase):
|
|
23
|
-
"""
|
|
14
|
+
"""
|
|
15
|
+
Vectorstore implementation using Qdrant - https://qdrant.tech/
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
|
|
19
|
+
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
|
|
20
|
+
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
|
|
21
|
+
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
|
|
22
|
+
- prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
|
|
23
|
+
- https: If `true` - use HTTPS(SSL) protocol. Default: `None`
|
|
24
|
+
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
|
|
25
|
+
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
|
|
26
|
+
- path: Persistence path for QdrantLocal. Default: `None`.
|
|
27
|
+
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
|
|
28
|
+
- n_results: Number of results to return from similarity search. Defaults to 10.
|
|
29
|
+
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
|
|
30
|
+
Defaults to `"BAAI/bge-small-en-v1.5"`.
|
|
31
|
+
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
|
|
32
|
+
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
|
|
33
|
+
- documentation_collection_name: Name of the collection to store documentation. Defaults to `"documentation"`.
|
|
34
|
+
- ddl_collection_name: Name of the collection to store DDL. Defaults to `"ddl"`.
|
|
35
|
+
- sql_collection_name: Name of the collection to store SQL. Defaults to `"sql"`.
|
|
36
|
+
|
|
37
|
+
Raises:
|
|
38
|
+
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
documentation_collection_name = "documentation"
|
|
42
|
+
ddl_collection_name = "ddl"
|
|
43
|
+
sql_collection_name = "sql"
|
|
44
|
+
|
|
45
|
+
id_suffixes = {
|
|
46
|
+
ddl_collection_name: "ddl",
|
|
47
|
+
documentation_collection_name: "doc",
|
|
48
|
+
sql_collection_name: "sql",
|
|
49
|
+
}
|
|
24
50
|
|
|
25
51
|
def __init__(
|
|
26
52
|
self,
|
|
27
53
|
config={},
|
|
28
54
|
):
|
|
29
|
-
"""
|
|
30
|
-
Vectorstore implementation using Qdrant - https://qdrant.tech/
|
|
31
|
-
|
|
32
|
-
Args:
|
|
33
|
-
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
|
|
34
|
-
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
|
|
35
|
-
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
|
|
36
|
-
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
|
|
37
|
-
- prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
|
|
38
|
-
- https: If `true` - use HTTPS(SSL) protocol. Default: `None`
|
|
39
|
-
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
|
|
40
|
-
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
|
|
41
|
-
- path: Persistence path for QdrantLocal. Default: `None`.
|
|
42
|
-
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
|
|
43
|
-
- n_results: Number of results to return from similarity search. Defaults to 10.
|
|
44
|
-
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
|
|
45
|
-
Defaults to `"BAAI/bge-small-en-v1.5"`.
|
|
46
|
-
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
|
|
47
|
-
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
|
|
48
|
-
|
|
49
|
-
Raises:
|
|
50
|
-
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
|
|
51
|
-
"""
|
|
52
55
|
VannaBase.__init__(self, config=config)
|
|
53
56
|
client = config.get("client")
|
|
54
57
|
|
|
@@ -75,6 +78,15 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
75
78
|
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
|
|
76
79
|
self.collection_params = config.get("collection_params", {})
|
|
77
80
|
self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
|
|
81
|
+
self.documentation_collection_name = config.get(
|
|
82
|
+
"documentation_collection_name", self.documentation_collection_name
|
|
83
|
+
)
|
|
84
|
+
self.ddl_collection_name = config.get(
|
|
85
|
+
"ddl_collection_name", self.ddl_collection_name
|
|
86
|
+
)
|
|
87
|
+
self.sql_collection_name = config.get(
|
|
88
|
+
"sql_collection_name", self.sql_collection_name
|
|
89
|
+
)
|
|
78
90
|
|
|
79
91
|
self._setup_collections()
|
|
80
92
|
|
|
@@ -83,7 +95,7 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
83
95
|
id = deterministic_uuid(question_answer)
|
|
84
96
|
|
|
85
97
|
self._client.upsert(
|
|
86
|
-
|
|
98
|
+
self.sql_collection_name,
|
|
87
99
|
points=[
|
|
88
100
|
models.PointStruct(
|
|
89
101
|
id=id,
|
|
@@ -96,12 +108,12 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
96
108
|
],
|
|
97
109
|
)
|
|
98
110
|
|
|
99
|
-
return self._format_point_id(id,
|
|
111
|
+
return self._format_point_id(id, self.sql_collection_name)
|
|
100
112
|
|
|
101
113
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
102
114
|
id = deterministic_uuid(ddl)
|
|
103
115
|
self._client.upsert(
|
|
104
|
-
|
|
116
|
+
self.ddl_collection_name,
|
|
105
117
|
points=[
|
|
106
118
|
models.PointStruct(
|
|
107
119
|
id=id,
|
|
@@ -112,13 +124,13 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
112
124
|
)
|
|
113
125
|
],
|
|
114
126
|
)
|
|
115
|
-
return self._format_point_id(id,
|
|
127
|
+
return self._format_point_id(id, self.ddl_collection_name)
|
|
116
128
|
|
|
117
129
|
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
118
130
|
id = deterministic_uuid(documentation)
|
|
119
131
|
|
|
120
132
|
self._client.upsert(
|
|
121
|
-
|
|
133
|
+
self.documentation_collection_name,
|
|
122
134
|
points=[
|
|
123
135
|
models.PointStruct(
|
|
124
136
|
id=id,
|
|
@@ -130,16 +142,17 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
130
142
|
],
|
|
131
143
|
)
|
|
132
144
|
|
|
133
|
-
return self._format_point_id(id,
|
|
145
|
+
return self._format_point_id(id, self.documentation_collection_name)
|
|
134
146
|
|
|
135
147
|
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
136
148
|
df = pd.DataFrame()
|
|
137
149
|
|
|
138
|
-
if sql_data := self._get_all_points(
|
|
150
|
+
if sql_data := self._get_all_points(self.sql_collection_name):
|
|
139
151
|
question_list = [data.payload["question"] for data in sql_data]
|
|
140
152
|
sql_list = [data.payload["sql"] for data in sql_data]
|
|
141
153
|
id_list = [
|
|
142
|
-
self._format_point_id(data.id,
|
|
154
|
+
self._format_point_id(data.id, self.sql_collection_name)
|
|
155
|
+
for data in sql_data
|
|
143
156
|
]
|
|
144
157
|
|
|
145
158
|
df_sql = pd.DataFrame(
|
|
@@ -154,10 +167,11 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
154
167
|
|
|
155
168
|
df = pd.concat([df, df_sql])
|
|
156
169
|
|
|
157
|
-
if ddl_data := self._get_all_points(
|
|
170
|
+
if ddl_data := self._get_all_points(self.ddl_collection_name):
|
|
158
171
|
ddl_list = [data.payload["ddl"] for data in ddl_data]
|
|
159
172
|
id_list = [
|
|
160
|
-
self._format_point_id(data.id,
|
|
173
|
+
self._format_point_id(data.id, self.ddl_collection_name)
|
|
174
|
+
for data in ddl_data
|
|
161
175
|
]
|
|
162
176
|
|
|
163
177
|
df_ddl = pd.DataFrame(
|
|
@@ -172,10 +186,10 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
172
186
|
|
|
173
187
|
df = pd.concat([df, df_ddl])
|
|
174
188
|
|
|
175
|
-
if doc_data := self._get_all_points(
|
|
189
|
+
if doc_data := self._get_all_points(self.documentation_collection_name):
|
|
176
190
|
document_list = [data.payload["documentation"] for data in doc_data]
|
|
177
191
|
id_list = [
|
|
178
|
-
self._format_point_id(data.id,
|
|
192
|
+
self._format_point_id(data.id, self.documentation_collection_name)
|
|
179
193
|
for data in doc_data
|
|
180
194
|
]
|
|
181
195
|
|
|
@@ -210,7 +224,7 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
210
224
|
Returns:
|
|
211
225
|
bool: True if collection is deleted, False otherwise
|
|
212
226
|
"""
|
|
213
|
-
if collection_name in
|
|
227
|
+
if collection_name in self.id_suffixes.keys():
|
|
214
228
|
self._client.delete_collection(collection_name)
|
|
215
229
|
self._setup_collections()
|
|
216
230
|
return True
|
|
@@ -223,7 +237,7 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
223
237
|
|
|
224
238
|
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
225
239
|
results = self._client.search(
|
|
226
|
-
|
|
240
|
+
self.sql_collection_name,
|
|
227
241
|
query_vector=self.generate_embedding(question),
|
|
228
242
|
limit=self.n_results,
|
|
229
243
|
with_payload=True,
|
|
@@ -233,7 +247,7 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
233
247
|
|
|
234
248
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
235
249
|
results = self._client.search(
|
|
236
|
-
|
|
250
|
+
self.ddl_collection_name,
|
|
237
251
|
query_vector=self.generate_embedding(question),
|
|
238
252
|
limit=self.n_results,
|
|
239
253
|
with_payload=True,
|
|
@@ -243,7 +257,7 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
243
257
|
|
|
244
258
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
245
259
|
results = self._client.search(
|
|
246
|
-
|
|
260
|
+
self.documentation_collection_name,
|
|
247
261
|
query_vector=self.generate_embedding(question),
|
|
248
262
|
limit=self.n_results,
|
|
249
263
|
with_payload=True,
|
|
@@ -282,9 +296,9 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
282
296
|
return results
|
|
283
297
|
|
|
284
298
|
def _setup_collections(self):
|
|
285
|
-
if not self._client.collection_exists(
|
|
299
|
+
if not self._client.collection_exists(self.sql_collection_name):
|
|
286
300
|
self._client.create_collection(
|
|
287
|
-
collection_name=
|
|
301
|
+
collection_name=self.sql_collection_name,
|
|
288
302
|
vectors_config=models.VectorParams(
|
|
289
303
|
size=self.embeddings_dimension,
|
|
290
304
|
distance=self.distance_metric,
|
|
@@ -292,18 +306,18 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
292
306
|
**self.collection_params,
|
|
293
307
|
)
|
|
294
308
|
|
|
295
|
-
if not self._client.collection_exists(
|
|
309
|
+
if not self._client.collection_exists(self.ddl_collection_name):
|
|
296
310
|
self._client.create_collection(
|
|
297
|
-
collection_name=
|
|
311
|
+
collection_name=self.ddl_collection_name,
|
|
298
312
|
vectors_config=models.VectorParams(
|
|
299
313
|
size=self.embeddings_dimension,
|
|
300
314
|
distance=self.distance_metric,
|
|
301
315
|
),
|
|
302
316
|
**self.collection_params,
|
|
303
317
|
)
|
|
304
|
-
if not self._client.collection_exists(
|
|
318
|
+
if not self._client.collection_exists(self.documentation_collection_name):
|
|
305
319
|
self._client.create_collection(
|
|
306
|
-
collection_name=
|
|
320
|
+
collection_name=self.documentation_collection_name,
|
|
307
321
|
vectors_config=models.VectorParams(
|
|
308
322
|
size=self.embeddings_dimension,
|
|
309
323
|
distance=self.distance_metric,
|
|
@@ -312,11 +326,11 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
312
326
|
)
|
|
313
327
|
|
|
314
328
|
def _format_point_id(self, id: str, collection_name: str) -> str:
|
|
315
|
-
return "{0}-{1}".format(id,
|
|
329
|
+
return "{0}-{1}".format(id, self.id_suffixes[collection_name])
|
|
316
330
|
|
|
317
331
|
def _parse_point_id(self, id: str) -> Tuple[str, str]:
|
|
318
|
-
id,
|
|
319
|
-
for collection_name, suffix in
|
|
320
|
-
if
|
|
332
|
+
id, curr_suffix = id.rsplit("-", 1)
|
|
333
|
+
for collection_name, suffix in self.id_suffixes.items():
|
|
334
|
+
if curr_suffix == suffix:
|
|
321
335
|
return id, collection_name
|
|
322
336
|
raise ValueError(f"Invalid id {id}")
|
|
{vanna-0.5.1.dist-info → vanna-0.5.3.dist-info}/RECORD
CHANGED
|
@@ -8,7 +8,7 @@ vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
|
|
|
8
8
|
vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
|
|
9
9
|
vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
|
|
10
10
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
11
|
-
vanna/base/base.py,sha256=
|
|
11
|
+
vanna/base/base.py,sha256=hUvb94NSVjSIsIa_x38xx8OOdBwL4mmuGmuDj5ednjo,65593
|
|
12
12
|
vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
|
|
13
13
|
vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
|
|
14
14
|
vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
|
|
@@ -24,19 +24,19 @@ vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
|
|
|
24
24
|
vanna/mistral/__init__.py,sha256=70rTY-69Z2ehkkMj84dNMCukPo6AWdflBGvIB_pztS0,29
|
|
25
25
|
vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
|
|
26
26
|
vanna/ollama/__init__.py,sha256=4xyu8aHPdnEHg5a-QAMwr5o0ns5wevsp_zkI-ndMO2k,27
|
|
27
|
-
vanna/ollama/ollama.py,sha256=
|
|
27
|
+
vanna/ollama/ollama.py,sha256=rXa7cfvdlO1E5SLysXIl3IZpIaA2r0RBvV5jX2-upiE,3794
|
|
28
28
|
vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
|
|
29
29
|
vanna/openai/openai_chat.py,sha256=lm-hUsQxu6Q1t06A2csC037zI4VkMk0wFbQ-_Lj74Wg,4764
|
|
30
30
|
vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
|
|
31
31
|
vanna/opensearch/__init__.py,sha256=0unDevWOTs7o8S79TOHUKF1mSiuQbBUVm-7k9jV5WW4,54
|
|
32
|
-
vanna/opensearch/opensearch_vector.py,sha256=
|
|
32
|
+
vanna/opensearch/opensearch_vector.py,sha256=90-nJuRkgOBh49VH5Lknw5qfBAlfSuqvUAqvHFbfa7g,11980
|
|
33
33
|
vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
|
|
34
|
-
vanna/qdrant/qdrant.py,sha256=
|
|
34
|
+
vanna/qdrant/qdrant.py,sha256=6M00nMiuOuftTDf3NsOrOcG9BA4DlIIDck2MNp9iEyg,12613
|
|
35
35
|
vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
|
|
36
36
|
vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
|
|
37
37
|
vanna/vannadb/vannadb_vector.py,sha256=9YwTO3Lh5owWQE7KPMBqLp2EkiGV0RC1sEYhslzJzgI,6168
|
|
38
38
|
vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
|
|
39
39
|
vanna/vllm/vllm.py,sha256=QerC3xF5eNzE_nGBDl6YrPYF4WYnjf0hHxxlDWdKX-0,2427
|
|
40
|
-
vanna-0.5.
|
|
41
|
-
vanna-0.5.
|
|
42
|
-
vanna-0.5.
|
|
40
|
+
vanna-0.5.3.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
41
|
+
vanna-0.5.3.dist-info/METADATA,sha256=JPRFSFUNqsLazznvDMWQIUy_bL9CKRQKCSgPlIwpucU,11248
|
|
42
|
+
vanna-0.5.3.dist-info/RECORD,,
|
|
File without changes
|