vanna 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/azuresearch/azuresearch_vector.py +2 -2
- vanna/faiss/__init__.py +1 -0
- vanna/faiss/faiss.py +176 -0
- vanna/qdrant/qdrant.py +9 -10
- {vanna-0.7.1.dist-info → vanna-0.7.3.dist-info}/METADATA +1 -1
- {vanna-0.7.1.dist-info → vanna-0.7.3.dist-info}/RECORD +7 -5
- {vanna-0.7.1.dist-info → vanna-0.7.3.dist-info}/WHEEL +0 -0
|
@@ -186,10 +186,10 @@ class AzureAISearch_VectorStore(VannaBase):
|
|
|
186
186
|
result = df["document"].tolist()
|
|
187
187
|
return result
|
|
188
188
|
|
|
189
|
-
def get_similar_question_sql(self,
|
|
189
|
+
def get_similar_question_sql(self, question: str) -> List[str]:
|
|
190
190
|
result = []
|
|
191
191
|
# Vectorize the text
|
|
192
|
-
vector_query = VectorizedQuery(vector=self.generate_embedding(
|
|
192
|
+
vector_query = VectorizedQuery(vector=self.generate_embedding(question), fields="document_vector")
|
|
193
193
|
df = pd.DataFrame(
|
|
194
194
|
self.search_client.search(
|
|
195
195
|
top=self.n_results_sql,
|
vanna/faiss/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .faiss import FAISS
|
vanna/faiss/faiss.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import List, Dict, Any
|
|
5
|
+
|
|
6
|
+
import faiss
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from ..base import VannaBase
|
|
11
|
+
from ..exceptions import DependencyError
|
|
12
|
+
|
|
13
|
+
class FAISS(VannaBase):
    """FAISS-backed vector store for Vanna training data.

    Maintains three ``IndexFlatL2`` indexes (question/SQL pairs, DDL,
    documentation) with parallel JSON metadata lists.  Storage mode is
    selected via ``config["client"]``:

    - ``"persistent"`` (default): indexes and metadata are read from and
      written to ``config["path"]`` on every mutation.
    - ``"in-memory"``: nothing touches disk.
    - a list of exactly three pre-built ``faiss.Index`` objects, taken in
      (sql, ddl, doc) order.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        VannaBase.__init__(self, config=config)

        # Import lazily so a missing optional dependency surfaces as a
        # DependencyError with install instructions rather than a bare
        # ImportError.
        try:
            import faiss
        except ImportError:
            raise DependencyError(
                "FAISS is not installed. Please install it with 'pip install faiss-cpu' or 'pip install faiss-gpu'"
            )

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise DependencyError(
                "SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
            )

        self.path = config.get("path", ".")
        self.embedding_dim = config.get('embedding_dim', 384)
        self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10))
        self.n_results_ddl = config.get('n_results_ddl', config.get("n_results", 10))
        self.n_results_documentation = config.get('n_results_documentation', config.get("n_results", 10))
        self.curr_client = config.get("client", "persistent")

        if self.curr_client == 'persistent':
            self.sql_index = self._load_or_create_index('sql_index.faiss')
            self.ddl_index = self._load_or_create_index('ddl_index.faiss')
            self.doc_index = self._load_or_create_index('doc_index.faiss')
        elif self.curr_client == 'in-memory':
            self.sql_index = faiss.IndexFlatL2(self.embedding_dim)
            self.ddl_index = faiss.IndexFlatL2(self.embedding_dim)
            self.doc_index = faiss.IndexFlatL2(self.embedding_dim)
        elif isinstance(self.curr_client, list) and len(self.curr_client) == 3 and all(isinstance(idx, faiss.Index) for idx in self.curr_client):
            self.sql_index = self.curr_client[0]
            self.ddl_index = self.curr_client[1]
            self.doc_index = self.curr_client[2]
        else:
            raise ValueError(f"Unsupported storage type was set in config: {self.curr_client}")

        self.sql_metadata: List[Dict[str, Any]] = self._load_or_create_metadata('sql_metadata.json')
        self.ddl_metadata: List[Dict[str, str]] = self._load_or_create_metadata('ddl_metadata.json')
        self.doc_metadata: List[Dict[str, str]] = self._load_or_create_metadata('doc_metadata.json')

        model_name = config.get('embedding_model', 'all-MiniLM-L6-v2')
        self.embedding_model = SentenceTransformer(model_name)

    def _load_or_create_index(self, filename):
        """Read a FAISS index from disk, or create an empty IndexFlatL2."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            return faiss.read_index(filepath)
        return faiss.IndexFlatL2(self.embedding_dim)

    def _load_or_create_metadata(self, filename):
        """Read a JSON metadata list from disk, or return an empty list."""
        filepath = os.path.join(self.path, filename)
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                return json.load(f)
        return []

    def _save_index(self, index, filename):
        """Persist *index* under ``self.path`` (no-op unless persistent mode)."""
        if self.curr_client == 'persistent':
            filepath = os.path.join(self.path, filename)
            faiss.write_index(index, filepath)

    def _save_metadata(self, metadata, filename):
        """Persist *metadata* as JSON under ``self.path`` (no-op unless persistent mode)."""
        if self.curr_client == 'persistent':
            filepath = os.path.join(self.path, filename)
            with open(filepath, 'w') as f:
                json.dump(metadata, f)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Encode *data* with the sentence-transformer model.

        Raises:
            ValueError: if the model's output dimension does not match
                ``embedding_dim``.  (Was an ``assert``, which is stripped
                under ``python -O`` and would let a mismatched vector
                corrupt the index silently.)
        """
        embedding = self.embedding_model.encode(data)
        if embedding.shape[0] != self.embedding_dim:
            raise ValueError(
                f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
            )
        return embedding.tolist()

    @staticmethod
    def _entry_text(metadata: Dict[str, Any]) -> str:
        """Reconstruct the exact text that was embedded when an entry was added."""
        if "sql" in metadata:
            return metadata["question"] + " " + metadata["sql"]
        if "ddl" in metadata:
            return metadata["ddl"]
        return metadata["documentation"]

    def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
        """Embed *text*, add it to *index*, record its metadata, and return the new id."""
        embedding = self.generate_embedding(text)
        index.add(np.array([embedding], dtype=np.float32))
        entry_id = str(uuid.uuid4())
        metadata_list.append({"id": entry_id, **(extra_metadata or {})})
        return entry_id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair; returns the entry id."""
        entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql})
        self._save_index(self.sql_index, 'sql_index.faiss')
        self._save_metadata(self.sql_metadata, 'sql_metadata.json')
        return entry_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns the entry id."""
        entry_id = self._add_to_index(self.ddl_index, self.ddl_metadata, ddl, {"ddl": ddl})
        self._save_index(self.ddl_index, 'ddl_index.faiss')
        self._save_metadata(self.ddl_metadata, 'ddl_metadata.json')
        return entry_id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """Store a documentation snippet; returns the entry id."""
        entry_id = self._add_to_index(self.doc_index, self.doc_metadata, documentation, {"documentation": documentation})
        self._save_index(self.doc_index, 'doc_index.faiss')
        self._save_metadata(self.doc_metadata, 'doc_metadata.json')
        return entry_id

    def _get_similar(self, index, metadata_list, text, n_results) -> list:
        """Return the metadata entries nearest to *text* (up to n_results)."""
        embedding = self.generate_embedding(text)
        _, indices = index.search(np.array([embedding], dtype=np.float32), k=n_results)
        # FAISS pads the result with -1 when the index holds fewer than k
        # vectors.  The old code only checked the FIRST slot, so a -1 in a
        # later slot indexed metadata_list[-1] and returned a wrong entry;
        # filter every slot instead.
        return [metadata_list[i] for i in indices[0] if i != -1]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return self._get_similar(self.sql_index, self.sql_metadata, question, self.n_results_sql)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        return [metadata["ddl"] for metadata in self._get_similar(self.ddl_index, self.ddl_metadata, question, self.n_results_ddl)]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        return [metadata["documentation"] for metadata in self._get_similar(self.doc_index, self.doc_metadata, question, self.n_results_documentation)]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored entries as one DataFrame tagged by training_data_type."""
        sql_data = pd.DataFrame(self.sql_metadata)
        sql_data['training_data_type'] = 'sql'

        ddl_data = pd.DataFrame(self.ddl_metadata)
        ddl_data['training_data_type'] = 'ddl'

        doc_data = pd.DataFrame(self.doc_metadata)
        doc_data['training_data_type'] = 'documentation'

        return pd.concat([sql_data, ddl_data, doc_data], ignore_index=True)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the entry with *id* and rebuild the affected index.

        Returns:
            True if an entry was found and removed, False otherwise.
        """
        for metadata_list, _old_index, index_name in [
            (self.sql_metadata, self.sql_index, 'sql_index.faiss'),
            (self.ddl_metadata, self.ddl_index, 'ddl_index.faiss'),
            (self.doc_metadata, self.doc_index, 'doc_index.faiss')
        ]:
            for i, item in enumerate(metadata_list):
                if item['id'] == id:
                    del metadata_list[i]
                    new_index = faiss.IndexFlatL2(self.embedding_dim)
                    # Rebuild from the same text that was embedded on insert.
                    # (The old code embedded json.dumps(m), producing vectors
                    # inconsistent with the ones added via _add_to_index, so
                    # similarity search degraded after any removal.)
                    embeddings = [self.generate_embedding(self._entry_text(m)) for m in metadata_list]
                    if embeddings:
                        new_index.add(np.array(embeddings, dtype=np.float32))
                    setattr(self, index_name.split('.')[0], new_index)

                    if self.curr_client == 'persistent':
                        self._save_index(new_index, index_name)
                        # Save under the file the metadata is LOADED from
                        # ('sql_metadata.json').  The old code wrote
                        # 'sql_index_metadata.json', so deletions reappeared
                        # after a restart.
                        self._save_metadata(metadata_list, f"{index_name.split('_')[0]}_metadata.json")

                    return True
        return False

    def remove_collection(self, collection_name: str) -> bool:
        """Reset one collection ('sql', 'ddl' or 'documentation') to empty.

        Returns:
            True if *collection_name* is recognized, False otherwise.
        """
        # Attribute and file names use the short 'doc' prefix for the
        # documentation collection.  The old code built 'documentation_index',
        # which never matched the real attributes (doc_index / doc_metadata)
        # or the files on disk, so the reset had no effect.
        prefixes = {"sql": "sql", "ddl": "ddl", "documentation": "doc"}
        prefix = prefixes.get(collection_name)
        if prefix is None:
            return False

        setattr(self, f"{prefix}_index", faiss.IndexFlatL2(self.embedding_dim))
        setattr(self, f"{prefix}_metadata", [])

        if self.curr_client == 'persistent':
            self._save_index(getattr(self, f"{prefix}_index"), f"{prefix}_index.faiss")
            self._save_metadata([], f"{prefix}_metadata.json")

        return True
|
vanna/qdrant/qdrant.py
CHANGED
|
@@ -3,7 +3,6 @@ from typing import List, Tuple
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from qdrant_client import QdrantClient, grpc, models
|
|
6
|
-
from qdrant_client.http.models.models import UpdateStatus
|
|
7
6
|
|
|
8
7
|
from ..base import VannaBase
|
|
9
8
|
from ..utils import deterministic_uuid
|
|
@@ -234,32 +233,32 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
234
233
|
return len(self.generate_embedding("ABCDEF"))
|
|
235
234
|
|
|
236
235
|
def get_similar_question_sql(self, question: str, **kwargs) -> list:
    """Return stored question/SQL payloads most similar to *question*."""
    query_vector = self.generate_embedding(question)
    response = self._client.query_points(
        self.sql_collection_name,
        query=query_vector,
        limit=self.n_results,
        with_payload=True,
    )
    return [dict(point.payload) for point in response.points]
|
|
245
244
|
|
|
246
245
|
def get_related_ddl(self, question: str, **kwargs) -> list:
    """Return the DDL statements most relevant to *question*."""
    query_vector = self.generate_embedding(question)
    response = self._client.query_points(
        self.ddl_collection_name,
        query=query_vector,
        limit=self.n_results,
        with_payload=True,
    )
    return [point.payload["ddl"] for point in response.points]
|
|
255
254
|
|
|
256
255
|
def get_related_documentation(self, question: str, **kwargs) -> list:
    """Return the documentation snippets most relevant to *question*."""
    query_vector = self.generate_embedding(question)
    response = self._client.query_points(
        self.documentation_collection_name,
        query=query_vector,
        limit=self.n_results,
        with_payload=True,
    )
    return [point.payload["documentation"] for point in response.points]
|
|
265
264
|
|
|
@@ -9,7 +9,7 @@ vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,65
|
|
|
9
9
|
vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
|
|
10
10
|
vanna/anthropic/anthropic_chat.py,sha256=7X3x8SYwDY28aGyBnt0YNRMG8YY1p_t-foMfKGj8_Oo,2627
|
|
11
11
|
vanna/azuresearch/__init__.py,sha256=tZfvsrCJESiL3EnxA4PrOc5NoO8MXEzCfHX_hnj8n-c,58
|
|
12
|
-
vanna/azuresearch/azuresearch_vector.py,sha256=
|
|
12
|
+
vanna/azuresearch/azuresearch_vector.py,sha256=_-t53PUnJM914GYbTYlyee06ocfu7l2NkZerBQtlJcs,9566
|
|
13
13
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
14
14
|
vanna/base/base.py,sha256=j5xQmK-MeFKAuPjgYLSl1ThCHZieG-ab-RFFSkDlbiw,73679
|
|
15
15
|
vanna/bedrock/__init__.py,sha256=hRT2bgJbHEqViLdL-t9hfjSfFdIOkPU2ADBt-B1En-8,46
|
|
@@ -17,6 +17,8 @@ vanna/bedrock/bedrock_converse.py,sha256=Nx5kYm-diAfYmsWAnTP5xnv7V84Og69-AP9b3se
|
|
|
17
17
|
vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
|
|
18
18
|
vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
|
|
19
19
|
vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
|
|
20
|
+
vanna/faiss/__init__.py,sha256=MXuojmLPt4kUtkES9XKWJcCDHVa4L5a6YF5gebhmKLw,24
|
|
21
|
+
vanna/faiss/faiss.py,sha256=HLUO5PQdnJio9OXJiJcgmRuxVWXvg_XRBnnohS21Z0w,8304
|
|
20
22
|
vanna/flask/__init__.py,sha256=jcdaau1tQ142nL1ZsDklk0ilMkEyRxgQZdmsl1IN4LQ,43866
|
|
21
23
|
vanna/flask/assets.py,sha256=af-vact_5HSftltugBpPxzLkAI14Z0lVWcObyVe6eKE,453462
|
|
22
24
|
vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
|
|
@@ -45,7 +47,7 @@ vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8
|
|
|
45
47
|
vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
|
|
46
48
|
vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
|
|
47
49
|
vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
|
|
48
|
-
vanna/qdrant/qdrant.py,sha256=
|
|
50
|
+
vanna/qdrant/qdrant.py,sha256=Acl_jN-ZrtoQav_G3FuKypXiuYSo_hlP5lyOOwTxCWM,12527
|
|
49
51
|
vanna/qianfan/Qianfan_Chat.py,sha256=Z-s9MwH22T4KMR8AViAjms6qoj67pHeQkMsbK-aXf1M,5273
|
|
50
52
|
vanna/qianfan/Qianfan_embeddings.py,sha256=TYynAJXlyuZfmoj49h8nU6bXu_GjlXREp3tgfQUca04,954
|
|
51
53
|
vanna/qianfan/__init__.py,sha256=QpR43BjZQZcrcDRkyYcYiS-kyqtYmu23AHDzK0Wy1D0,90
|
|
@@ -59,6 +61,6 @@ vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
|
|
|
59
61
|
vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
|
|
60
62
|
vanna/weaviate/__init__.py,sha256=HL6PAl7ePBAkeG8uln-BmM7IUtWohyTPvDfcPzSGSCg,46
|
|
61
63
|
vanna/weaviate/weaviate_vector.py,sha256=tUJIZjEy2mda8CB6C8zeN2SKkEO-UJdLsIqy69skuF0,7584
|
|
62
|
-
vanna-0.7.
|
|
63
|
-
vanna-0.7.
|
|
64
|
-
vanna-0.7.
|
|
64
|
+
vanna-0.7.3.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
65
|
+
vanna-0.7.3.dist-info/METADATA,sha256=BOfBtwy1ENcdHApatLWXjqvKj8Zl3bti1hlueVoplR8,12407
|
|
66
|
+
vanna-0.7.3.dist-info/RECORD,,
|
|
File without changes
|