vanna 0.5.5__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/advanced/__init__.py +26 -0
- vanna/base/base.py +24 -24
- vanna/flask/__init__.py +124 -10
- vanna/flask/assets.py +36 -16
- vanna/qdrant/qdrant.py +3 -1
- vanna/vannadb/vannadb_vector.py +179 -1
- vanna/vllm/vllm.py +16 -1
- {vanna-0.5.5.dist-info → vanna-0.6.0.dist-info}/METADATA +2 -2
- {vanna-0.5.5.dist-info → vanna-0.6.0.dist-info}/RECORD +10 -9
- {vanna-0.5.5.dist-info → vanna-0.6.0.dist-info}/WHEEL +0 -0
vanna/qdrant/qdrant.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import List, Tuple
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from qdrant_client import QdrantClient, grpc, models
|
|
6
|
+
from qdrant_client.http.models.models import UpdateStatus
|
|
6
7
|
|
|
7
8
|
from ..base import VannaBase
|
|
8
9
|
from ..utils import deterministic_uuid
|
|
@@ -210,7 +211,8 @@ class Qdrant_VectorStore(VannaBase):
|
|
|
210
211
|
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete a single training-data point from its Qdrant collection.

    Args:
        id: Composite point id. ``self._parse_point_id`` splits it into the
            Qdrant point id and the name of the collection holding it.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        True when Qdrant reports the delete as ``UpdateStatus.COMPLETED``.
        False when parsing the id or the delete raises ``ValueError``, or
        when the operation did not report completion.
    """
    try:
        point_id, collection_name = self._parse_point_id(id)
        result = self._client.delete(collection_name, points_selector=[point_id])
        # The previous code evaluated `res == UpdateStatus.COMPLETED` as a bare
        # statement and discarded it, so success fell through and returned None.
        # qdrant-client's `delete` returns an UpdateResult whose `.status`
        # field carries the UpdateStatus, so compare that field.
        return result.status == UpdateStatus.COMPLETED
    except ValueError:
        return False
|
|
216
218
|
|
vanna/vannadb/vannadb_vector.py
CHANGED
|
@@ -5,6 +5,7 @@ from io import StringIO
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import requests
|
|
7
7
|
|
|
8
|
+
from ..advanced import VannaAdvanced
|
|
8
9
|
from ..base import VannaBase
|
|
9
10
|
from ..types import (
|
|
10
11
|
DataFrameJSON,
|
|
@@ -20,7 +21,7 @@ from ..types import (
|
|
|
20
21
|
from ..utils import sanitize_model_name
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
class VannaDB_VectorStore(VannaBase):
|
|
24
|
+
class VannaDB_VectorStore(VannaBase, VannaAdvanced):
|
|
24
25
|
def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
|
|
25
26
|
VannaBase.__init__(self, config=config)
|
|
26
27
|
|
|
@@ -33,6 +34,12 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
33
34
|
else config["endpoint"]
|
|
34
35
|
)
|
|
35
36
|
self.related_training_data = {}
|
|
37
|
+
self._graphql_endpoint = "https://functionrag.com/query"
|
|
38
|
+
self._graphql_headers = {
|
|
39
|
+
"Content-Type": "application/json",
|
|
40
|
+
"API-KEY": self._api_key,
|
|
41
|
+
"NAMESPACE": self._model,
|
|
42
|
+
}
|
|
36
43
|
|
|
37
44
|
def _rpc_call(self, method, params):
|
|
38
45
|
if method != "list_orgs":
|
|
@@ -59,6 +66,177 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
59
66
|
def _dataclass_to_dict(self, obj):
    """Return *obj* (a dataclass instance) converted to a plain dict via ``dataclasses.asdict``."""
    converted = dataclasses.asdict(obj)
    return converted
|
|
61
68
|
|
|
69
|
+
def get_all_functions(self) -> list:
    """Fetch every SQL function registered in this VannaDB namespace.

    Sends the ``get_all_sql_functions`` GraphQL query to
    ``self._graphql_endpoint`` with ``self._graphql_headers``.

    Returns:
        The value found at ``data.get_all_sql_functions`` in the response.

    Raises:
        Exception: if the HTTP status is not 200, the body is not JSON, or
            the response carries no ``data.get_all_sql_functions`` field
            (for example when GraphQL returns ``"data": null`` with errors).
    """
    query = """
    {
        get_all_sql_functions {
            function_name
            description
            post_processing_code_template
            arguments {
                name
                description
                general_type
                is_user_editable
                available_values
            }
            sql_template
        }
    }
    """

    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})

    # Parse defensively: a non-JSON error body (e.g. an HTML 5xx page) must not
    # mask the status code reported in the exception below.
    try:
        response_json = response.json()
    except ValueError:
        response_json = None

    # GraphQL may return "data": null alongside "errors"; guard before indexing.
    data = response_json.get('data') if isinstance(response_json, dict) else None
    if response.status_code == 200 and isinstance(data, dict) and 'get_all_sql_functions' in data:
        resp = data['get_all_sql_functions']
        # Logged via self.log only; the stray debug print() was removed.
        self.log(resp)
        return resp

    raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_function(self, question: str, additional_data: "dict | None" = None) -> dict:
    """Ask VannaDB to choose a SQL function for *question* and instantiate it.

    Args:
        question: Natural-language question, sent as the ``question`` variable.
        additional_data: Optional mapping of static argument names to values.
            Each value is converted with ``str()`` and sent as a
            ``{"name": ..., "value": ...}`` entry in ``staticFunctionArguments``.
            ``None`` (the default) means no static arguments.

    Returns:
        The object found at ``data.get_and_instantiate_function``.

    Raises:
        Exception: if the HTTP status is not 200, the body is not JSON, or
            the response carries no ``data.get_and_instantiate_function`` field.
    """
    # None default instead of a shared mutable `{}` default.
    if additional_data is None:
        additional_data = {}

    query = """
    query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
        get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
            ... on SQLFunction {
                function_name
                description
                post_processing_code_template
                instantiated_post_processing_code
                arguments {
                    name
                    description
                    general_type
                    is_user_editable
                    instantiated_value
                    available_values
                }
                sql_template
                instantiated_sql
            }
        }
    }
    """
    static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
    variables = {"question": question, "staticFunctionArguments": static_function_arguments}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})

    # Parse defensively so a non-JSON error body still reports the status code.
    try:
        response_json = response.json()
    except ValueError:
        response_json = None

    # GraphQL may return "data": null alongside "errors"; guard before indexing.
    data = response_json.get('data') if isinstance(response_json, dict) else None
    if response.status_code == 200 and isinstance(data, dict) and 'get_and_instantiate_function' in data:
        resp = data['get_and_instantiate_function']
        # Logged via self.log only; the stray debug print() was removed.
        self.log(resp)
        return resp

    raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
|
|
137
|
+
|
|
138
|
+
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
    """Have VannaDB generate and store a reusable SQL function.

    Args:
        question: The natural-language question the SQL answers.
        sql: The SQL statement to turn into a function template.
        plotly_code: Post-processing code, sent as ``post_processing_code``.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        The object found at ``data.generate_and_create_sql_function``.

    Raises:
        Exception: if the HTTP status is not 200, the body is not JSON, or
            the response carries no ``data.generate_and_create_sql_function``.
    """
    query = """
    mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
        generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
            function_name
            description
            arguments {
                name
                description
                general_type
                is_user_editable
            }
            sql_template
            post_processing_code_template
        }
    }
    """
    variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})

    # Parse defensively so a non-JSON error body still reports the status code.
    try:
        response_json = response.json()
    except ValueError:
        response_json = None

    data = response_json.get('data') if isinstance(response_json, dict) else None
    if response.status_code == 200 and isinstance(data, dict) and 'generate_and_create_sql_function' in data:
        # The stray debug print() of the result was removed.
        return data['generate_and_create_sql_function']

    raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
|
|
166
|
+
|
|
167
|
+
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
    """
    Update an existing SQL function based on the provided parameters.

    Args:
        old_function_name (str): The current name of the function to be updated.
        updated_function (dict): A dictionary containing the updated function details. Expected keys:
            - 'function_name': The new name of the function.
            - 'description': The new description of the function.
            - 'arguments': A list of dictionaries describing the function arguments.
            - 'sql_template': The new SQL template for the function.
            - 'post_processing_code_template': The new post-processing code template.
            Any other keys are dropped before sending. Each argument dict is
            reduced to 'name', 'general_type', 'description',
            'is_user_editable' and 'available_values'. The caller's dict is
            not modified.

    Returns:
        bool: The ``data.update_sql_function`` value returned by the server
        (intended to be True on success, False otherwise).

    Raises:
        Exception: if the HTTP status is not 200, the body is not JSON, or the
            response carries no ``data.update_sql_function`` field.
    """
    mutation = """
    mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
        update_sql_function(input: $input)
    }
    """

    # Top-level keys accepted by the SQLFunctionUpdate GraphQL input type.
    allowed_function_keys = {
        'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
    }
    # Keys accepted for each entry of the 'arguments' list.
    allowed_argument_keys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}

    # Build a new dict so the caller's mapping is never mutated.
    payload = {key: value for key, value in updated_function.items() if key in allowed_function_keys}
    if 'arguments' in payload:
        payload['arguments'] = [
            {key: value for key, value in arg.items() if key in allowed_argument_keys}
            for arg in payload['arguments']
        ]

    variables = {
        "input": {
            "old_function_name": old_function_name,
            **payload,
        }
    }

    # The debug print("variables", ...) that dumped the full payload to stdout was removed.
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})

    # Parse defensively so a non-JSON error body still reports the status code.
    try:
        response_json = response.json()
    except ValueError:
        response_json = None

    data = response_json.get('data') if isinstance(response_json, dict) else None
    if response.status_code == 200 and isinstance(data, dict) and 'update_sql_function' in data:
        return data['update_sql_function']

    raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
|
|
225
|
+
|
|
226
|
+
def delete_function(self, function_name: str) -> bool:
    """Delete the SQL function named *function_name* from VannaDB.

    Args:
        function_name: Name of the stored function to delete.

    Returns:
        The ``data.delete_sql_function`` value returned by the server.

    Raises:
        Exception: if the HTTP status is not 200, the body is not JSON, or the
            response carries no ``data.delete_sql_function`` field.
    """
    mutation = """
    mutation DeleteSQLFunction($function_name: String!) {
        delete_sql_function(function_name: $function_name)
    }
    """
    variables = {"function_name": function_name}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})

    # Parse defensively so a non-JSON error body still reports the status code.
    try:
        response_json = response.json()
    except ValueError:
        response_json = None

    data = response_json.get('data') if isinstance(response_json, dict) else None
    if response.status_code == 200 and isinstance(data, dict) and 'delete_sql_function' in data:
        return data['delete_sql_function']

    raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
|
|
239
|
+
|
|
62
240
|
def create_model(self, model: str, **kwargs) -> bool:
|
|
63
241
|
"""
|
|
64
242
|
**Example:**
|
vanna/vllm/vllm.py
CHANGED
|
@@ -17,6 +17,11 @@ class Vllm(VannaBase):
|
|
|
17
17
|
else:
|
|
18
18
|
self.model = config["model"]
|
|
19
19
|
|
|
20
|
+
if "auth-key" in config:
|
|
21
|
+
self.auth_key = config["auth-key"]
|
|
22
|
+
else:
|
|
23
|
+
self.auth_key = None
|
|
24
|
+
|
|
20
25
|
def system_message(self, message: str) -> any:
    """Wrap *message* as a chat entry whose role is ``system``."""
    entry = dict(role="system", content=message)
    return entry
|
|
22
27
|
|
|
@@ -67,7 +72,17 @@ class Vllm(VannaBase):
|
|
|
67
72
|
"messages": prompt,
|
|
68
73
|
}
|
|
69
74
|
|
|
70
|
-
|
|
75
|
+
if self.auth_key is not None:
|
|
76
|
+
headers = {
|
|
77
|
+
'Content-Type': 'application/json',
|
|
78
|
+
'Authorization': f'Bearer {self.auth_key}'
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
response = requests.post(url, headers=headers,json=data)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
else:
|
|
85
|
+
response = requests.post(url, json=data)
|
|
71
86
|
|
|
72
87
|
response_dict = response.json()
|
|
73
88
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.5.5
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -42,7 +42,7 @@ Requires-Dist: pinecone-client ; extra == "all"
|
|
|
42
42
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
43
43
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
44
44
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
45
|
-
Requires-Dist:
|
|
45
|
+
Requires-Dist: clickhouse_connect ; extra == "clickhouse"
|
|
46
46
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
47
47
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
48
48
|
Requires-Dist: google-generativeai ; extra == "google"
|
|
@@ -5,15 +5,16 @@ vanna/utils.py,sha256=cs0B_0MwhmPI2nWjVHifDYCmCR0kkddylQ2vloaPDSw,2247
|
|
|
5
5
|
vanna/ZhipuAI/ZhipuAI_Chat.py,sha256=WtZKUBIwlNH0BGbb4lZbVR7pTWIrn7b4RLIk-7u0SuQ,8725
|
|
6
6
|
vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oERsa8pE,2849
|
|
7
7
|
vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
|
|
8
|
+
vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,658
|
|
8
9
|
vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
|
|
9
10
|
vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
|
|
10
11
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
11
|
-
vanna/base/base.py,sha256=
|
|
12
|
+
vanna/base/base.py,sha256=XiBJU4LOKoFnY1vnkr_jz3soAdVdoL0D1WIZdzZCVyU,70942
|
|
12
13
|
vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
|
|
13
14
|
vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
|
|
14
15
|
vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
|
|
15
|
-
vanna/flask/__init__.py,sha256=
|
|
16
|
-
vanna/flask/assets.py,sha256=
|
|
16
|
+
vanna/flask/__init__.py,sha256=wpAZg6emB7l8TmyMlxVLFMjxkUQReIIPTfqlkF-onlc,30267
|
|
17
|
+
vanna/flask/assets.py,sha256=_UoUr57sS0QL2BuTxAOe9k4yy8T7-fp2NpbRSVtW3IM,451769
|
|
17
18
|
vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
|
|
18
19
|
vanna/google/__init__.py,sha256=M-dCxCZcKL4bTQyMLj6r6VRs65YNX9Tl2aoPCuqGm-8,41
|
|
19
20
|
vanna/google/gemini_chat.py,sha256=ps3A-afFbCo3HeFTLL_nMoQO1PsGvRUUPRUppbMcDew,1584
|
|
@@ -37,12 +38,12 @@ vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8
|
|
|
37
38
|
vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
|
|
38
39
|
vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
|
|
39
40
|
vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
|
|
40
|
-
vanna/qdrant/qdrant.py,sha256=
|
|
41
|
+
vanna/qdrant/qdrant.py,sha256=RyICUvOO_jt8u9MB4oIYhqv3BicZ0d9pQkSFwkIfUjg,12719
|
|
41
42
|
vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
|
|
42
43
|
vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
|
|
43
|
-
vanna/vannadb/vannadb_vector.py,sha256=
|
|
44
|
+
vanna/vannadb/vannadb_vector.py,sha256=N8poMYvAojoaOF5gI4STD5pZWK9lBKPvyIjbh9dPBa0,14189
|
|
44
45
|
vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
|
|
45
|
-
vanna/vllm/vllm.py,sha256=
|
|
46
|
-
vanna-0.
|
|
47
|
-
vanna-0.
|
|
48
|
-
vanna-0.
|
|
46
|
+
vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
|
|
47
|
+
vanna-0.6.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
48
|
+
vanna-0.6.0.dist-info/METADATA,sha256=Nu63PMm3HT3HAkpxKU4fte9CKpIDgnL90Ww2AljWSOs,11506
|
|
49
|
+
vanna-0.6.0.dist-info/RECORD,,
|
|
File without changes
|