vanna 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/advanced/__init__.py +26 -0
- vanna/base/base.py +26 -26
- vanna/flask/__init__.py +124 -10
- vanna/flask/assets.py +36 -16
- vanna/mock/__init__.py +3 -0
- vanna/mock/embedding.py +11 -0
- vanna/mock/llm.py +19 -0
- vanna/mock/vectordb.py +55 -0
- vanna/opensearch/opensearch_vector.py +11 -2
- vanna/pinecone/__init__.py +3 -0
- vanna/pinecone/pinecone_vector.py +275 -0
- vanna/qdrant/qdrant.py +3 -1
- vanna/vannadb/vannadb_vector.py +179 -1
- vanna/vllm/vllm.py +16 -1
- {vanna-0.5.4.dist-info → vanna-0.6.0.dist-info}/METADATA +6 -2
- {vanna-0.5.4.dist-info → vanna-0.6.0.dist-info}/RECORD +17 -10
- {vanna-0.5.4.dist-info → vanna-0.6.0.dist-info}/WHEEL +0 -0
vanna/mock/__init__.py
ADDED
vanna/mock/embedding.py
ADDED
vanna/mock/llm.py
ADDED
@@ -0,0 +1,19 @@
+
+from ..base import VannaBase
+
+
+class MockLLM(VannaBase):
+    def __init__(self, config=None):
+        pass
+
+    def system_message(self, message: str) -> any:
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        return {"role": "assistant", "content": message}
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        return "Mock LLM response"
vanna/mock/vectordb.py
ADDED
@@ -0,0 +1,55 @@
+import pandas as pd
+
+from ..base import VannaBase
+
+
+class MockVectorDB(VannaBase):
+    def __init__(self, config=None):
+        pass
+
+    def _get_id(self, value: str, **kwargs) -> str:
+        # Hash the value and return the ID
+        return str(hash(value))
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        return self._get_id(ddl)
+
+    def add_documentation(self, doc: str, **kwargs) -> str:
+        return self._get_id(doc)
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        return self._get_id(question)
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        return []
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        return []
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        return []
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        return pd.DataFrame({'id': {0: '19546-ddl',
+                                    1: '91597-sql',
+                                    2: '133976-sql',
+                                    3: '59851-doc',
+                                    4: '73046-sql'},
+                             'training_data_type': {0: 'ddl',
+                                                    1: 'sql',
+                                                    2: 'sql',
+                                                    3: 'documentation',
+                                                    4: 'sql'},
+                             'question': {0: None,
+                                          1: 'What are the top selling genres?',
+                                          2: 'What are the low 7 artists by sales?',
+                                          3: None,
+                                          4: 'What is the total sales for each customer?'},
+                             'content': {0: 'CREATE TABLE [Invoice]\n(\n    [InvoiceId] INTEGER NOT NULL,\n    [CustomerId] INTEGER NOT NULL,\n    [InvoiceDate] DATETIME NOT NULL,\n    [BillingAddress] NVARCHAR(70),\n    [BillingCity] NVARCHAR(40),\n    [BillingState] NVARCHAR(40),\n    [BillingCountry] NVARCHAR(40),\n    [BillingPostalCode] NVARCHAR(10),\n    [Total] NUMERIC(10,2) NOT NULL,\n    CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n    FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
+                                         1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
+                                         2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
+                                         3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
+                                         4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})
+
+    def remove_training_data(id: str, **kwargs) -> bool:
+        return True
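Together with MockLLM above, these classes can stand in for a full backend in tests. A minimal sketch of wiring them into one offline Vanna instance, following the package's usual mixin pattern. Note the assumptions: this diff does not show the bodies of vanna/mock/__init__.py or vanna/mock/embedding.py, so the re-exported names and a MockEmbedding class (to satisfy VannaBase's abstract generate_embedding) are assumed here.

```python
# Sketch only: MockEmbedding and the package re-exports are assumptions;
# their definitions are not shown in this diff.
from vanna.mock import MockEmbedding, MockLLM, MockVectorDB


class OfflineVanna(MockVectorDB, MockEmbedding, MockLLM):
    """Fully offline Vanna instance for unit tests."""

    def __init__(self, config=None):
        MockVectorDB.__init__(self, config=config)
        MockEmbedding.__init__(self, config=config)  # assumed signature
        MockLLM.__init__(self, config=config)


vn = OfflineVanna()
print(vn.add_ddl("CREATE TABLE t (id INTEGER)"))  # hash-based id from _get_id
# The canned training data covers all three types: ddl, sql, documentation.
print(vn.get_training_data()["training_data_type"].value_counts())
```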
vanna/opensearch/opensearch_vector.py
CHANGED

@@ -155,6 +155,11 @@ class OpenSearch_VectorStore(VannaBase):
       else:
         max_retries = 10
 
+      if config is not None and "es_http_compress" in config:
+        es_http_compress = config["es_http_compress"]
+      else:
+        es_http_compress = False
+
       print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
             " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
             verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
@@ -162,7 +167,7 @@ class OpenSearch_VectorStore(VannaBase):
       # Initialize the OpenSearch client by passing a list of URLs
       self.client = OpenSearch(
         hosts=[es_urls],
-        http_compress=
+        http_compress=es_http_compress,
         use_ssl=ssl,
         verify_certs=verify_certs,
         timeout=timeout,
@@ -175,7 +180,7 @@ class OpenSearch_VectorStore(VannaBase):
       # Initialize the OpenSearch client by passing a host and port
       self.client = OpenSearch(
         hosts=[{'host': host, 'port': port}],
-        http_compress=
+        http_compress=es_http_compress,
         use_ssl=ssl,
         verify_certs=verify_certs,
         timeout=timeout,
@@ -267,6 +272,7 @@ class OpenSearch_VectorStore(VannaBase):
           }
         }
       }
+      print(query)
       response = self.client.search(index=self.ddl_index, body=query,
                                     **kwargs)
       return [hit['_source']['ddl'] for hit in response['hits']['hits']]
@@ -279,6 +285,7 @@ class OpenSearch_VectorStore(VannaBase):
           }
         }
       }
+      print(query)
       response = self.client.search(index=self.document_index,
                                     body=query,
                                     **kwargs)
@@ -292,6 +299,7 @@ class OpenSearch_VectorStore(VannaBase):
           }
         }
       }
+      print(query)
      response = self.client.search(index=self.question_sql_index,
                                    body=query,
                                    **kwargs)
@@ -307,6 +315,7 @@ class OpenSearch_VectorStore(VannaBase):
         body={"query": {"match_all": {}}},
         size=1000
       )
+      print(query)
       # records = [hit['_source'] for hit in response['hits']['hits']]
       for hit in response['hits']['hits']:
         data.append(
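The net effect of the first three hunks is a new optional es_http_compress config key, defaulting to False, that is now wired into both OpenSearch client constructors. A minimal sketch of passing it through, assuming the vanna.opensearch import path and the es_urls key name suggested by the init log line; the LLM mixin is arbitrary:

```python
from vanna.opensearch import OpenSearch_VectorStore  # import path assumed
from vanna.mock.llm import MockLLM  # any LLM mixin works; the mock keeps this offline


class MyVanna(OpenSearch_VectorStore, MockLLM):
    def __init__(self, config=None):
        OpenSearch_VectorStore.__init__(self, config=config)
        MockLLM.__init__(self, config=config)


vn = MyVanna(config={
    "es_urls": "https://localhost:9200",  # key name assumed from the log message
    "es_http_compress": True,             # new in 0.6.0; False when omitted
})
```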
vanna/pinecone/__init__.py
ADDED

vanna/pinecone/pinecone_vector.py
ADDED

@@ -0,0 +1,275 @@
+import json
+from typing import List
+
+from pinecone import Pinecone, PodSpec, ServerlessSpec
+import pandas as pd
+from ..base import VannaBase
+from ..utils import deterministic_uuid
+
+from fastembed import TextEmbedding
+
+
+class PineconeDB_VectorStore(VannaBase):
+    """
+    Vectorstore using PineconeDB
+
+    Args:
+        config (dict): Configuration dictionary. Defaults to {}. You must provide either a Pinecone client or an API key in the config.
+            - client (Pinecone, optional): Pinecone client. Defaults to None.
+            - api_key (str, optional): Pinecone API key. Defaults to None.
+            - n_results (int, optional): Number of results to return. Defaults to 10.
+            - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384, which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
+            - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
+            - documentation_namespace (str, optional): Namespace for documentation. Defaults to "documentation".
+            - distance_metric (str, optional): Distance metric to use. Defaults to "cosine".
+            - ddl_namespace (str, optional): Namespace for DDL. Defaults to "ddl".
+            - sql_namespace (str, optional): Namespace for SQL. Defaults to "sql".
+            - index_name (str, optional): Name of the index. Defaults to "vanna-index".
+            - metadata_config (dict, optional): Metadata configuration if using a pinecone pod. Defaults to {}.
+            - server_type (str, optional): Type of Pinecone server to use. Defaults to "serverless". Options are "serverless" or "pod".
+            - podspec (PodSpec, optional): PodSpec configuration if using a pinecone pod. Defaults to PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config).
+            - serverless_spec (ServerlessSpec, optional): ServerlessSpec configuration if using a pinecone serverless index. Defaults to ServerlessSpec(cloud="aws", region="us-west-2").
+    Raises:
+        ValueError: If config is None, api_key is not provided OR client is not provided, client is not an instance of Pinecone, or server_type is not "serverless" or "pod".
+    """
+
+    def __init__(self, config=None):
+        VannaBase.__init__(self, config=config)
+        if config is None:
+            raise ValueError(
+                "config is required, pass either a Pinecone client or an API key in the config."
+            )
+        client = config.get("client")
+        api_key = config.get("api_key")
+        if not api_key and not client:
+            raise ValueError(
+                "api_key is required in config or pass a configured client"
+            )
+        if not client and api_key:
+            self._client = Pinecone(api_key=api_key)
+        elif not isinstance(client, Pinecone):
+            raise ValueError("client must be an instance of Pinecone")
+        else:
+            self._client = client
+
+        self.n_results = config.get("n_results", 10)
+        self.dimensions = config.get("dimensions", 384)
+        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
+        self.documentation_namespace = config.get(
+            "documentation_namespace", "documentation"
+        )
+        self.distance_metric = config.get("distance_metric", "cosine")
+        self.ddl_namespace = config.get("ddl_namespace", "ddl")
+        self.sql_namespace = config.get("sql_namespace", "sql")
+        self.index_name = config.get("index_name", "vanna-index")
+        self.metadata_config = config.get("metadata_config", {})
+        self.server_type = config.get("server_type", "serverless")
+        if self.server_type not in ["serverless", "pod"]:
+            raise ValueError("server_type must be either 'serverless' or 'pod'")
+        self.podspec = config.get(
+            "podspec",
+            PodSpec(
+                environment="us-west-2",
+                pod_type="p1.x1",
+                metadata_config=self.metadata_config,
+            ),
+        )
+        self.serverless_spec = config.get(
+            "serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2")
+        )
+        self._setup_index()
+
+    def _set_index_host(self, host: str) -> None:
+        self.Index = self._client.Index(host=host)
+
+    def _setup_index(self) -> None:
+        existing_indexes = self._get_indexes()
+        if self.index_name not in existing_indexes and self.server_type == "serverless":
+            self._client.create_index(
+                name=self.index_name,
+                dimension=self.dimensions,
+                metric=self.distance_metric,
+                spec=self.serverless_spec,
+            )
+            pinecone_index_host = self._client.describe_index(self.index_name)["host"]
+            self._set_index_host(pinecone_index_host)
+        elif self.index_name not in existing_indexes and self.server_type == "pod":
+            self._client.create_index(
+                name=self.index_name,
+                dimension=self.dimensions,
+                metric=self.distance_metric,
+                spec=self.podspec,
+            )
+            pinecone_index_host = self._client.describe_index(self.index_name)["host"]
+            self._set_index_host(pinecone_index_host)
+        else:
+            pinecone_index_host = self._client.describe_index(self.index_name)["host"]
+            self._set_index_host(pinecone_index_host)
+
+    def _get_indexes(self) -> list:
+        return [index["name"] for index in self._client.list_indexes()]
+
+    def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
+        fetch_response = self.Index.fetch(ids=[id], namespace=namespace)
+        if fetch_response["vectors"] == {}:
+            return False
+        return True
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        id = deterministic_uuid(ddl) + "-ddl"
+        if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace):
+            print(f"DDL with id: {id} already exists in the index. Skipping...")
+            return id
+        self.Index.upsert(
+            vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})],
+            namespace=self.ddl_namespace,
+        )
+        return id
+
+    def add_documentation(self, doc: str, **kwargs) -> str:
+        id = deterministic_uuid(doc) + "-doc"
+
+        if self._check_if_embedding_exists(
+            id=id, namespace=self.documentation_namespace
+        ):
+            print(
+                f"Documentation with id: {id} already exists in the index. Skipping..."
+            )
+            return id
+        self.Index.upsert(
+            vectors=[(id, self.generate_embedding(doc), {"documentation": doc})],
+            namespace=self.documentation_namespace,
+        )
+        return id
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        question_sql_json = json.dumps(
+            {
+                "question": question,
+                "sql": sql,
+            },
+            ensure_ascii=False,
+        )
+        id = deterministic_uuid(question_sql_json) + "-sql"
+        if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace):
+            print(
+                f"Question-SQL with id: {id} already exists in the index. Skipping..."
+            )
+            return id
+        self.Index.upsert(
+            vectors=[
+                (
+                    id,
+                    self.generate_embedding(question_sql_json),
+                    {"sql": question_sql_json},
+                )
+            ],
+            namespace=self.sql_namespace,
+        )
+        return id
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        res = self.Index.query(
+            namespace=self.ddl_namespace,
+            vector=self.generate_embedding(question),
+            top_k=self.n_results,
+            include_values=True,
+            include_metadata=True,
+        )
+        return [match["metadata"]["ddl"] for match in res["matches"]] if res else []
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        res = self.Index.query(
+            namespace=self.documentation_namespace,
+            vector=self.generate_embedding(question),
+            top_k=self.n_results,
+            include_values=True,
+            include_metadata=True,
+        )
+        return (
+            [match["metadata"]["documentation"] for match in res["matches"]]
+            if res
+            else []
+        )
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        res = self.Index.query(
+            namespace=self.sql_namespace,
+            vector=self.generate_embedding(question),
+            top_k=self.n_results,
+            include_values=True,
+            include_metadata=True,
+        )
+        return (
+            [
+                {
+                    key: value
+                    for key, value in json.loads(match["metadata"]["sql"]).items()
+                }
+                for match in res["matches"]
+            ]
+            if res
+            else []
+        )
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        # Pinecone does not support getting all vectors in a namespace, so we have to query for the top_k vectors with a dummy vector
+        df = pd.DataFrame()
+        namespaces = {
+            "sql": self.sql_namespace,
+            "ddl": self.ddl_namespace,
+            "documentation": self.documentation_namespace,
+        }
+
+        for data_type, namespace in namespaces.items():
+            data = self.Index.query(
+                top_k=10000,  # max results that pinecone allows
+                namespace=namespace,
+                include_values=True,
+                include_metadata=True,
+                vector=[0.0] * self.dimensions,
+            )
+
+            if data is not None:
+                id_list = [match["id"] for match in data["matches"]]
+                content_list = [
+                    match["metadata"][data_type] for match in data["matches"]
+                ]
+                question_list = [
+                    (
+                        json.loads(match["metadata"][data_type])["question"]
+                        if data_type == "sql"
+                        else None
+                    )
+                    for match in data["matches"]
+                ]
+
+                df_data = pd.DataFrame(
+                    {
+                        "id": id_list,
+                        "question": question_list,
+                        "content": content_list,
+                    }
+                )
+                df_data["training_data_type"] = data_type
+                df = pd.concat([df, df_data])
+
+        return df
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        if id.endswith("-sql"):
+            self.Index.delete(ids=[id], namespace=self.sql_namespace)
+            return True
+        elif id.endswith("-ddl"):
+            self.Index.delete(ids=[id], namespace=self.ddl_namespace)
+            return True
+        elif id.endswith("-doc"):
+            self.Index.delete(ids=[id], namespace=self.documentation_namespace)
+            return True
+        else:
+            return False
+
+    def generate_embedding(self, data: str, **kwargs) -> List[float]:
+        embedding_model = TextEmbedding(model_name=self.fastembed_model)
+        embedding = next(embedding_model.embed(data))
+        return embedding.tolist()
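A minimal sketch of using the new store, following the mixin pattern vanna uses for other backends. The import path and the placeholder API key are assumptions; only config keys documented in the docstring above are used:

```python
from vanna.pinecone import PineconeDB_VectorStore  # import path assumed
from vanna.mock.llm import MockLLM  # any LLM mixin works


class MyVanna(PineconeDB_VectorStore, MockLLM):
    def __init__(self, config=None):
        PineconeDB_VectorStore.__init__(self, config=config)
        MockLLM.__init__(self, config=config)


vn = MyVanna(config={
    "api_key": "YOUR_PINECONE_API_KEY",  # or pass a configured Pinecone client
    "server_type": "serverless",         # "pod" requires a PodSpec instead
    "index_name": "vanna-index",         # created on first use if missing
})
doc_id = vn.add_documentation("Sales are recorded in the Invoice table.")
print(vn.get_related_documentation("Where are sales recorded?"))
```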
vanna/qdrant/qdrant.py
CHANGED
@@ -3,6 +3,7 @@ from typing import List, Tuple
 
 import pandas as pd
 from qdrant_client import QdrantClient, grpc, models
+from qdrant_client.http.models.models import UpdateStatus
 
 from ..base import VannaBase
 from ..utils import deterministic_uuid
@@ -210,7 +211,8 @@ class Qdrant_VectorStore(VannaBase):
     def remove_training_data(self, id: str, **kwargs) -> bool:
         try:
             id, collection_name = self._parse_point_id(id)
-            self._client.delete(collection_name, points_selector=[id])
+            res = self._client.delete(collection_name, points_selector=[id])
+            res == UpdateStatus.COMPLETED
         except ValueError:
             return False
 
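For reference, qdrant_client's delete returns an UpdateResult whose status field is the UpdateStatus the new import refers to. A standalone sketch of checking a deletion; the collection name, vector size, and point id are illustrative:

```python
from qdrant_client import QdrantClient, models
from qdrant_client.http.models.models import UpdateStatus

client = QdrantClient(":memory:")  # local in-memory instance for illustration
client.create_collection(
    collection_name="sql",
    vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
)
client.upsert(
    collection_name="sql",
    points=[models.PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4])],
)
res = client.delete(collection_name="sql", points_selector=[1])
print(res.status == UpdateStatus.COMPLETED)  # True once the delete is applied
```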
vanna/vannadb/vannadb_vector.py
CHANGED
@@ -5,6 +5,7 @@ from io import StringIO
 import pandas as pd
 import requests
 
+from ..advanced import VannaAdvanced
 from ..base import VannaBase
 from ..types import (
     DataFrameJSON,
@@ -20,7 +21,7 @@ from ..types import (
 from ..utils import sanitize_model_name
 
 
-class VannaDB_VectorStore(VannaBase):
+class VannaDB_VectorStore(VannaBase, VannaAdvanced):
     def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
         VannaBase.__init__(self, config=config)
 
@@ -33,6 +34,12 @@ class VannaDB_VectorStore(VannaBase):
             else config["endpoint"]
         )
         self.related_training_data = {}
+        self._graphql_endpoint = "https://functionrag.com/query"
+        self._graphql_headers = {
+            "Content-Type": "application/json",
+            "API-KEY": self._api_key,
+            "NAMESPACE": self._model,
+        }
 
     def _rpc_call(self, method, params):
         if method != "list_orgs":
@@ -59,6 +66,177 @@ class VannaDB_VectorStore(VannaBase):
     def _dataclass_to_dict(self, obj):
         return dataclasses.asdict(obj)
 
+    def get_all_functions(self) -> list:
+        query = """
+        {
+            get_all_sql_functions {
+                function_name
+                description
+                post_processing_code_template
+                arguments {
+                    name
+                    description
+                    general_type
+                    is_user_editable
+                    available_values
+                }
+                sql_template
+            }
+        }
+        """
+
+        response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})
+        response_json = response.json()
+        if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
+            self.log(response_json['data']['get_all_sql_functions'])
+            resp = response_json['data']['get_all_sql_functions']
+
+            print(resp)
+
+            return resp
+        else:
+            raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
+
+
+    def get_function(self, question: str, additional_data: dict = {}) -> dict:
+        query = """
+        query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
+            get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
+                ... on SQLFunction {
+                    function_name
+                    description
+                    post_processing_code_template
+                    instantiated_post_processing_code
+                    arguments {
+                        name
+                        description
+                        general_type
+                        is_user_editable
+                        instantiated_value
+                        available_values
+                    }
+                    sql_template
+                    instantiated_sql
+                }
+            }
+        }
+        """
+        static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
+        variables = {"question": question, "staticFunctionArguments": static_function_arguments}
+        response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
+        response_json = response.json()
+        if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
+            self.log(response_json['data']['get_and_instantiate_function'])
+            resp = response_json['data']['get_and_instantiate_function']
+
+            print(resp)
+
+            return resp
+        else:
+            raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
+
+    def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
+        query = """
+        mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
+            generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
+                function_name
+                description
+                arguments {
+                    name
+                    description
+                    general_type
+                    is_user_editable
+                }
+                sql_template
+                post_processing_code_template
+            }
+        }
+        """
+        variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
+        response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
+        response_json = response.json()
+        if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
+            resp = response_json['data']['generate_and_create_sql_function']
+
+            print(resp)
+
+            return resp
+        else:
+            raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
+
+    def update_function(self, old_function_name: str, updated_function: dict) -> bool:
+        """
+        Update an existing SQL function based on the provided parameters.
+
+        Args:
+            old_function_name (str): The current name of the function to be updated.
+            updated_function (dict): A dictionary containing the updated function details. Expected keys:
+                - 'function_name': The new name of the function.
+                - 'description': The new description of the function.
+                - 'arguments': A list of dictionaries describing the function arguments.
+                - 'sql_template': The new SQL template for the function.
+                - 'post_processing_code_template': The new post-processing code template.
+
+        Returns:
+            bool: True if the function was successfully updated, False otherwise.
+        """
+        mutation = """
+        mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
+            update_sql_function(input: $input)
+        }
+        """
+
+        SQLFunctionUpdate = {
+            'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
+        }
+
+        # Define the expected keys for each argument in the arguments list
+        ArgumentKeys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}
+
+        # Function to validate and transform arguments
+        def validate_arguments(args):
+            return [
+                {key: arg[key] for key in arg if key in ArgumentKeys}
+                for arg in args
+            ]
+
+        # Keep only the keys that conform to the SQLFunctionUpdate GraphQL input type
+        updated_function = {key: value for key, value in updated_function.items() if key in SQLFunctionUpdate}
+
+        # Special handling for 'arguments' to ensure they conform to the spec
+        if 'arguments' in updated_function:
+            updated_function['arguments'] = validate_arguments(updated_function['arguments'])
+
+        variables = {
+            "input": {
+                "old_function_name": old_function_name,
+                **updated_function
+            }
+        }
+
+        print("variables", variables)
+
+        response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
+        response_json = response.json()
+        if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
+            return response_json['data']['update_sql_function']
+        else:
+            raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
+
+    def delete_function(self, function_name: str) -> bool:
+        mutation = """
+        mutation DeleteSQLFunction($function_name: String!) {
+            delete_sql_function(function_name: $function_name)
+        }
+        """
+        variables = {"function_name": function_name}
+        response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
+        response_json = response.json()
+        if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
+            return response_json['data']['delete_sql_function']
+        else:
+            raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
+
     def create_model(self, model: str, **kwargs) -> bool:
         """
         **Example:**