vanna 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/advanced/__init__.py +26 -0
- vanna/base/base.py +25 -24
- vanna/flask/__init__.py +124 -10
- vanna/flask/assets.py +36 -16
- vanna/milvus/__init__.py +1 -0
- vanna/milvus/milvus_vector.py +305 -0
- vanna/qdrant/qdrant.py +12 -14
- vanna/vannadb/vannadb_vector.py +179 -1
- vanna/vllm/vllm.py +16 -1
- {vanna-0.5.5.dist-info → vanna-0.6.1.dist-info}/METADATA +5 -2
- {vanna-0.5.5.dist-info → vanna-0.6.1.dist-info}/RECORD +12 -9
- {vanna-0.5.5.dist-info → vanna-0.6.1.dist-info}/WHEEL +0 -0
vanna/milvus/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .milvus_vector import Milvus_VectorStore
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from pymilvus import DataType, MilvusClient, model
|
|
6
|
+
|
|
7
|
+
from ..base import VannaBase
|
|
8
|
+
|
|
9
|
+
# A local file path such as `./milvus.db` is the simplest choice: Milvus Lite
# then keeps all of its data in that single file.
#
# For large datasets (e.g. more than a million documents), run a more
# performant Milvus server on Docker or Kubernetes instead and use its server
# URI, e.g. `http://localhost:19530`.
DEFAULT_MILVUS_URI = "./milvus.db"

# Maximum number of rows requested from each collection in a single query.
MAX_LIMIT_SIZE = 10000
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Milvus_VectorStore(VannaBase):
    """
    Vectorstore implementation using Milvus - https://milvus.io/docs/quickstart.md

    Training data is kept in three collections:
    ``vannasql`` holds question/SQL pairs (embedded on the question),
    ``vannaddl`` holds DDL statements and ``vannadoc`` holds documentation.
    Every generated id ends in ``-sql``, ``-ddl`` or ``-doc``, which is how
    ``remove_training_data`` knows which collection an id lives in.

    Args:
        - config (dict, optional): Dictionary of `Milvus_VectorStore config` options. Defaults to `None`.
            - milvus_client: A `pymilvus.MilvusClient` instance. Defaults to a
              client opened on `DEFAULT_MILVUS_URI`.
            - embedding_function:
                A `milvus_model.base.BaseEmbeddingFunction` instance. Defaults to `DefaultEmbeddingFunction()`.
                For more models, please refer to:
                https://milvus.io/docs/embeddings.md
            - n_results: Number of nearest neighbours returned by each
              similarity search. Defaults to 10.
    """

    # Maps the suffix of a training-data id to the collection that holds it.
    _COLLECTION_BY_ID_SUFFIX = {
        "-sql": "vannasql",
        "-ddl": "vannaddl",
        "-doc": "vannadoc",
    }

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        # `config` defaults to None; normalise it so the membership tests
        # below do not raise TypeError.
        if config is None:
            config = {}

        if "milvus_client" in config:
            self.milvus_client = config["milvus_client"]
        else:
            self.milvus_client = MilvusClient(uri=DEFAULT_MILVUS_URI)

        if "embedding_function" in config:
            self.embedding_function = config["embedding_function"]
        else:
            self.embedding_function = model.DefaultEmbeddingFunction()
        # Embed a probe string once to learn the vector width that the
        # collection schemas must declare.
        self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
        self._create_collections()
        self.n_results = config.get("n_results", 10)

    def _create_collections(self):
        """Create the SQL, DDL and documentation collections if they are missing."""
        self._create_sql_collection("vannasql")
        self._create_ddl_collection("vannaddl")
        self._create_doc_collection("vannadoc")

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the document embedding of `data` as a plain list of floats."""
        # encode_documents works on a batch (list) of strings and returns one
        # vector per string, so wrap `data` and unwrap the single result.
        return self.embedding_function.encode_documents([data])[0].tolist()

    def _create_collection(self, name: str, text_fields: List[str]):
        """
        Create collection `name` unless it already exists.

        Schema: VARCHAR primary key ``id``, one VARCHAR column per entry of
        `text_fields` (in order), and a FLOAT_VECTOR column ``vector`` of
        width ``self._embedding_dim`` with an AUTOINDEX/L2 index.
        """
        if self.milvus_client.has_collection(collection_name=name):
            return

        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=False,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
        for field_name in text_fields:
            schema.add_field(field_name=field_name, datatype=DataType.VARCHAR, max_length=65535)
        schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)

        index_params = self.milvus_client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector",
            index_type="AUTOINDEX",
            metric_type="L2",
        )
        self.milvus_client.create_collection(
            collection_name=name,
            schema=schema,
            index_params=index_params,
            consistency_level="Strong",
        )

    def _create_sql_collection(self, name: str):
        """Create the question/SQL collection (`text`, `sql` columns) if missing."""
        self._create_collection(name, ["text", "sql"])

    def _create_ddl_collection(self, name: str):
        """Create the DDL collection (`ddl` column) if missing."""
        self._create_collection(name, ["ddl"])

    def _create_doc_collection(self, name: str):
        """Create the documentation collection (`doc` column) if missing."""
        self._create_collection(name, ["doc"])

    def _insert_entry(self, collection_name: str, id_suffix: str, text_to_embed: str, fields: dict) -> str:
        """
        Embed `text_to_embed`, insert it together with `fields` into
        `collection_name` under a fresh UUID-based id ending in `id_suffix`,
        and return that id.
        """
        _id = str(uuid.uuid4()) + id_suffix
        embedding = self.embedding_function.encode_documents([text_to_embed])[0]
        self.milvus_client.insert(
            collection_name=collection_name,
            data={"id": _id, **fields, "vector": embedding},
        )
        return _id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """
        Store a question/SQL pair, embedded on the question.

        Returns:
            str: The new entry's id (ends in ``-sql``).

        Raises:
            Exception: If `question` or `sql` is empty.
        """
        if len(question) == 0 or len(sql) == 0:
            raise Exception("pair of question and sql can not be null")
        return self._insert_entry("vannasql", "-sql", question, {"text": question, "sql": sql})

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """
        Store a DDL statement.

        Returns:
            str: The new entry's id (ends in ``-ddl``).

        Raises:
            Exception: If `ddl` is empty.
        """
        if len(ddl) == 0:
            raise Exception("ddl can not be null")
        return self._insert_entry("vannaddl", "-ddl", ddl, {"ddl": ddl})

    def add_documentation(self, documentation: str, **kwargs) -> str:
        """
        Store a piece of documentation.

        Returns:
            str: The new entry's id (ends in ``-doc``).

        Raises:
            Exception: If `documentation` is empty.
        """
        if len(documentation) == 0:
            raise Exception("documentation can not be null")
        return self._insert_entry("vannadoc", "-doc", documentation, {"doc": documentation})

    def _query_all(self, collection_name: str) -> list:
        """Return up to MAX_LIMIT_SIZE rows (all fields) of `collection_name`."""
        return self.milvus_client.query(
            collection_name=collection_name,
            output_fields=["*"],
            limit=MAX_LIMIT_SIZE,
        )

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """
        Return all stored training data as one DataFrame.

        Columns are ``id``, ``question`` (None for DDL and documentation rows)
        and ``content`` (the SQL, DDL or documentation text). SQL rows come
        first, then DDL, then documentation. At most MAX_LIMIT_SIZE rows are
        read from each collection; anything beyond that is not returned.
        """
        sql_data = self._query_all("vannasql")
        df_sql = pd.DataFrame(
            {
                "id": [doc["id"] for doc in sql_data],
                "question": [doc["text"] for doc in sql_data],
                "content": [doc["sql"] for doc in sql_data],
            }
        )

        ddl_data = self._query_all("vannaddl")
        df_ddl = pd.DataFrame(
            {
                "id": [doc["id"] for doc in ddl_data],
                "question": [None for _ in ddl_data],
                "content": [doc["ddl"] for doc in ddl_data],
            }
        )

        doc_data = self._query_all("vannadoc")
        df_doc = pd.DataFrame(
            {
                "id": [doc["id"] for doc in doc_data],
                "question": [None for _ in doc_data],
                "content": [doc["doc"] for doc in doc_data],
            }
        )

        # One concat of the three frames; starting from an empty DataFrame
        # only adds an extra (deprecated-in-pandas) empty-frame concat.
        return pd.concat([df_sql, df_ddl, df_doc])

    def _search(self, collection_name: str, question: str, output_fields: List[str]) -> list:
        """
        Return the entity dicts (restricted to `output_fields`) of the
        ``self.n_results`` nearest neighbours, by L2 distance, of the query
        embedding of `question` in `collection_name`.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 128},
        }
        embeddings = self.embedding_function.encode_queries([question])
        res = self.milvus_client.search(
            collection_name=collection_name,
            anns_field="vector",
            data=embeddings,
            limit=self.n_results,
            output_fields=output_fields,
            search_params=search_params,
        )
        # A single query vector was sent, so only the first hit list matters.
        return [hit["entity"] for hit in res[0]]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return ``{"question": ..., "sql": ...}`` dicts for the stored pairs closest to `question`."""
        return [
            {"question": entity["text"], "sql": entity["sql"]}
            for entity in self._search("vannasql", question, ["text", "sql"])
        ]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the stored DDL strings closest to `question`."""
        return [entity["ddl"] for entity in self._search("vannaddl", question, ["ddl"])]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the stored documentation strings closest to `question`."""
        return [entity["doc"] for entity in self._search("vannadoc", question, ["doc"])]

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """
        Delete the training-data entry `id`.

        The collection is chosen from the id's suffix (``-sql``, ``-ddl`` or
        ``-doc``).

        Returns:
            bool: True if a delete was issued, False if the id has none of the
            known suffixes.
        """
        for suffix, collection_name in self._COLLECTION_BY_ID_SUFFIX.items():
            if id.endswith(suffix):
                self.milvus_client.delete(collection_name=collection_name, ids=[id])
                return True
        return False
|
vanna/qdrant/qdrant.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import List, Tuple
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from qdrant_client import QdrantClient, grpc, models
|
|
6
|
+
from qdrant_client.http.models.models import UpdateStatus
|
|
6
7
|
|
|
7
8
|
from ..base import VannaBase
|
|
8
9
|
from ..utils import deterministic_uuid
|
|
@@ -38,16 +39,6 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
38
39
|
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
|
|
39
40
|
"""
|
|
40
41
|
|
|
41
|
-
documentation_collection_name = "documentation"
|
|
42
|
-
ddl_collection_name = "ddl"
|
|
43
|
-
sql_collection_name = "sql"
|
|
44
|
-
|
|
45
|
-
id_suffixes = {
|
|
46
|
-
ddl_collection_name: "ddl",
|
|
47
|
-
documentation_collection_name: "doc",
|
|
48
|
-
sql_collection_name: "sql",
|
|
49
|
-
}
|
|
50
|
-
|
|
51
42
|
def __init__(
|
|
52
43
|
self,
|
|
53
44
|
config={},
|
|
@@ -79,15 +70,21 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
79
70
|
self.collection_params = config.get("collection_params", {})
|
|
80
71
|
self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
|
|
81
72
|
self.documentation_collection_name = config.get(
|
|
82
|
-
"documentation_collection_name",
|
|
73
|
+
"documentation_collection_name", "documentation"
|
|
83
74
|
)
|
|
84
75
|
self.ddl_collection_name = config.get(
|
|
85
|
-
"ddl_collection_name",
|
|
76
|
+
"ddl_collection_name", "ddl"
|
|
86
77
|
)
|
|
87
78
|
self.sql_collection_name = config.get(
|
|
88
|
-
"sql_collection_name",
|
|
79
|
+
"sql_collection_name", "sql"
|
|
89
80
|
)
|
|
90
81
|
|
|
82
|
+
self.id_suffixes = {
|
|
83
|
+
self.ddl_collection_name: "ddl",
|
|
84
|
+
self.documentation_collection_name: "doc",
|
|
85
|
+
self.sql_collection_name: "sql",
|
|
86
|
+
}
|
|
87
|
+
|
|
91
88
|
self._setup_collections()
|
|
92
89
|
|
|
93
90
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
@@ -210,7 +207,8 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
210
207
|
def remove_training_data(self, id: str, **kwargs) -> bool:
    """
    Delete one training-data point.

    `id` is split by ``self._parse_point_id`` into the Qdrant point id and the
    name of the collection that holds it.

    Returns:
        bool: True once the delete request has been issued; False if a
        ValueError is raised (e.g. `id` cannot be parsed).
    """
    try:
        # Use a distinct local instead of rebinding the `id` parameter (which
        # also shadows the builtin) to a different value.
        point_id, collection_name = self._parse_point_id(id)
        self._client.delete(collection_name, points_selector=[point_id])
        return True
    except ValueError:
        return False
|
|
216
214
|
|
vanna/vannadb/vannadb_vector.py
CHANGED
|
@@ -5,6 +5,7 @@ from io import StringIO
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import requests
|
|
7
7
|
|
|
8
|
+
from ..advanced import VannaAdvanced
|
|
8
9
|
from ..base import VannaBase
|
|
9
10
|
from ..types import (
|
|
10
11
|
DataFrameJSON,
|
|
@@ -20,7 +21,7 @@ from ..types import (
|
|
|
20
21
|
from ..utils import sanitize_model_name
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
class VannaDB_VectorStore(VannaBase):
|
|
24
|
+
class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
24
25
|
def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
|
|
25
26
|
VannaBase.__init__(self, config=config)
|
|
26
27
|
|
|
@@ -33,6 +34,12 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
33
34
|
else config["endpoint"]
|
|
34
35
|
)
|
|
35
36
|
self.related_training_data = {}
|
|
37
|
+
self._graphql_endpoint = "https://functionrag.com/query"
|
|
38
|
+
self._graphql_headers = {
|
|
39
|
+
"Content-Type": "application/json",
|
|
40
|
+
"API-KEY": self._api_key,
|
|
41
|
+
"NAMESPACE": self._model,
|
|
42
|
+
}
|
|
36
43
|
|
|
37
44
|
def _rpc_call(self, method, params):
|
|
38
45
|
if method != "list_orgs":
|
|
@@ -59,6 +66,177 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
59
66
|
def _dataclass_to_dict(self, obj):
|
|
60
67
|
return dataclasses.asdict(obj)
|
|
61
68
|
|
|
69
|
+
def get_all_functions(self) -> list:
    """
    Fetch every SQL function registered for this model from the GraphQL
    endpoint (``self._graphql_endpoint``).

    Returns:
        list: The ``get_all_sql_functions`` payload. Each entry has the fields
        requested below: ``function_name``, ``description``,
        ``post_processing_code_template``, ``arguments`` and ``sql_template``.

    Raises:
        Exception: If the request does not return HTTP 200 or the response has
        no ``data.get_all_sql_functions`` field.
    """
    query = """
    {
        get_all_sql_functions {
            function_name
            description
            post_processing_code_template
            arguments {
                name
                description
                general_type
                is_user_editable
                available_values
            }
            sql_template
        }
    }
    """

    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})
    # Decode only on success: error bodies are not guaranteed to be JSON, and
    # a GraphQL error response may carry "data": null.
    if response.status_code == 200:
        data = response.json().get('data') or {}
        if 'get_all_sql_functions' in data:
            resp = data['get_all_sql_functions']
            self.log(resp)
            return resp
    raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_function(self, question: str, additional_data: dict = None) -> dict:
    """
    Ask the GraphQL endpoint for the SQL function matching `question`,
    instantiated with its arguments (``get_and_instantiate_function``).

    Args:
        question: Natural-language question sent as ``$question``.
        additional_data: Optional static argument values. Each item is sent
            as ``{"name": key, "value": str(value)}``. Defaults to none.

    Returns:
        dict: The ``get_and_instantiate_function`` payload, including
        ``instantiated_sql`` and ``instantiated_post_processing_code``.

    Raises:
        Exception: If the request does not return HTTP 200 or the response has
        no ``data.get_and_instantiate_function`` field.
    """
    # None instead of a shared mutable `{}` default.
    if additional_data is None:
        additional_data = {}

    query = """
    query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
        get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
            ... on SQLFunction {
                function_name
                description
                post_processing_code_template
                instantiated_post_processing_code
                arguments {
                    name
                    description
                    general_type
                    is_user_editable
                    instantiated_value
                    available_values
                }
                sql_template
                instantiated_sql
            }
        }
    }
    """
    static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
    variables = {"question": question, "staticFunctionArguments": static_function_arguments}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
    # Decode only on success: error bodies are not guaranteed to be JSON, and
    # a GraphQL error response may carry "data": null.
    if response.status_code == 200:
        data = response.json().get('data') or {}
        if 'get_and_instantiate_function' in data:
            resp = data['get_and_instantiate_function']
            self.log(resp)
            return resp
    raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
|
|
137
|
+
|
|
138
|
+
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
    """
    Have the server generate and store a reusable SQL function from a
    question, its SQL and its plotting code
    (``generate_and_create_sql_function`` mutation).

    Args:
        question: The natural-language question.
        sql: The SQL answering it.
        plotly_code: Post-processing code, sent as ``post_processing_code``.

    Returns:
        dict: The created function (``function_name``, ``description``,
        ``arguments``, ``sql_template``, ``post_processing_code_template``).

    Raises:
        Exception: If the request does not return HTTP 200 or the response has
        no ``data.generate_and_create_sql_function`` field.
    """
    query = """
    mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
        generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
            function_name
            description
            arguments {
                name
                description
                general_type
                is_user_editable
            }
            sql_template
            post_processing_code_template
        }
    }
    """
    variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
    # Decode only on success: error bodies are not guaranteed to be JSON, and
    # a GraphQL error response may carry "data": null.
    if response.status_code == 200:
        data = response.json().get('data') or {}
        if 'generate_and_create_sql_function' in data:
            resp = data['generate_and_create_sql_function']
            # Report through the instance logger, like the sibling methods,
            # instead of a bare print.
            self.log(resp)
            return resp
    raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
|
|
166
|
+
|
|
167
|
+
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
    """
    Update an existing SQL function based on the provided parameters.

    Args:
        old_function_name (str): The current name of the function to be updated.
        updated_function (dict): A dictionary containing the updated function details. Expected keys:
            - 'function_name': The new name of the function.
            - 'description': The new description of the function.
            - 'arguments': A list of dictionaries describing the function arguments.
            - 'sql_template': The new SQL template for the function.
            - 'post_processing_code_template': The new post-processing code template.
            Other keys are dropped before sending. Within each argument dict,
            only 'name', 'general_type', 'description', 'is_user_editable' and
            'available_values' are kept. The caller's dict is not modified.

    Returns:
        bool: The ``update_sql_function`` value returned by the server
        (declared as bool; presumably whether the update succeeded).

    Raises:
        Exception: If the request does not return HTTP 200 or the response has
        no ``data.update_sql_function`` field.
    """
    mutation = """
    mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
        update_sql_function(input: $input)
    }
    """

    # Keys accepted by the SQLFunctionUpdate GraphQL input type.
    allowed_function_keys = {
        'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
    }
    # Keys accepted for each entry of the 'arguments' list.
    allowed_argument_keys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}

    function_input = {key: value for key, value in updated_function.items() if key in allowed_function_keys}

    # Filter each argument down to the allowed keys. A null 'arguments' value
    # is forwarded as-is instead of being iterated (which would raise).
    if function_input.get('arguments') is not None:
        function_input['arguments'] = [
            {key: value for key, value in arg.items() if key in allowed_argument_keys}
            for arg in function_input['arguments']
        ]

    variables = {
        "input": {
            "old_function_name": old_function_name,
            **function_input
        }
    }

    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
    # Decode only on success: error bodies are not guaranteed to be JSON, and
    # a GraphQL error response may carry "data": null.
    if response.status_code == 200:
        data = response.json().get('data') or {}
        if 'update_sql_function' in data:
            return data['update_sql_function']
    raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
|
|
225
|
+
|
|
226
|
+
def delete_function(self, function_name: str) -> bool:
    """
    Delete the SQL function named `function_name` (``delete_sql_function``
    mutation).

    Returns:
        bool: The ``delete_sql_function`` value returned by the server
        (declared as bool; presumably whether a function was deleted).

    Raises:
        Exception: If the request does not return HTTP 200 or the response has
        no ``data.delete_sql_function`` field.
    """
    mutation = """
    mutation DeleteSQLFunction($function_name: String!) {
        delete_sql_function(function_name: $function_name)
    }
    """
    variables = {"function_name": function_name}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
    # Decode only on success: error bodies are not guaranteed to be JSON, and
    # a GraphQL error response may carry "data": null.
    if response.status_code == 200:
        data = response.json().get('data') or {}
        if 'delete_sql_function' in data:
            return data['delete_sql_function']
    raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
|
|
239
|
+
|
|
62
240
|
def create_model(self, model: str, **kwargs) -> bool:
|
|
63
241
|
"""
|
|
64
242
|
**Example:**
|
vanna/vllm/vllm.py
CHANGED
|
@@ -17,6 +17,11 @@ class Vllm(VannaBase):
|
|
|
17
17
|
else:
|
|
18
18
|
self.model = config["model"]
|
|
19
19
|
|
|
20
|
+
if "auth-key" in config:
|
|
21
|
+
self.auth_key = config["auth-key"]
|
|
22
|
+
else:
|
|
23
|
+
self.auth_key = None
|
|
24
|
+
|
|
20
25
|
def system_message(self, message: str) -> any:
|
|
21
26
|
return {"role": "system", "content": message}
|
|
22
27
|
|
|
@@ -67,7 +72,17 @@ class Vllm(VannaBase):
|
|
|
67
72
|
"messages": prompt,
|
|
68
73
|
}
|
|
69
74
|
|
|
70
|
-
|
|
75
|
+
if self.auth_key is not None:
|
|
76
|
+
headers = {
|
|
77
|
+
'Content-Type': 'application/json',
|
|
78
|
+
'Authorization': f'Bearer {self.auth_key}'
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
response = requests.post(url, headers=headers,json=data)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
else:
|
|
85
|
+
response = requests.post(url, json=data)
|
|
71
86
|
|
|
72
87
|
response_dict = response.json()
|
|
73
88
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.5.5
|
|
3
|
+
Version: 0.6.1
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -39,16 +39,18 @@ Requires-Dist: opensearch-py ; extra == "all"
|
|
|
39
39
|
Requires-Dist: opensearch-dsl ; extra == "all"
|
|
40
40
|
Requires-Dist: transformers ; extra == "all"
|
|
41
41
|
Requires-Dist: pinecone-client ; extra == "all"
|
|
42
|
+
Requires-Dist: pymilvus[model] ; extra == "all"
|
|
42
43
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
43
44
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
44
45
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
45
|
-
Requires-Dist:
|
|
46
|
+
Requires-Dist: clickhouse_connect ; extra == "clickhouse"
|
|
46
47
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
47
48
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
48
49
|
Requires-Dist: google-generativeai ; extra == "google"
|
|
49
50
|
Requires-Dist: google-cloud-aiplatform ; extra == "google"
|
|
50
51
|
Requires-Dist: transformers ; extra == "hf"
|
|
51
52
|
Requires-Dist: marqo ; extra == "marqo"
|
|
53
|
+
Requires-Dist: pymilvus[model] ; extra == "milvus"
|
|
52
54
|
Requires-Dist: mistralai ; extra == "mistralai"
|
|
53
55
|
Requires-Dist: PyMySQL ; extra == "mysql"
|
|
54
56
|
Requires-Dist: ollama ; extra == "ollama"
|
|
@@ -78,6 +80,7 @@ Provides-Extra: gemini
|
|
|
78
80
|
Provides-Extra: google
|
|
79
81
|
Provides-Extra: hf
|
|
80
82
|
Provides-Extra: marqo
|
|
83
|
+
Provides-Extra: milvus
|
|
81
84
|
Provides-Extra: mistralai
|
|
82
85
|
Provides-Extra: mysql
|
|
83
86
|
Provides-Extra: ollama
|