vanna 0.5.5__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/qdrant/qdrant.py CHANGED
@@ -3,6 +3,7 @@ from typing import List, Tuple
3
3
 
4
4
  import pandas as pd
5
5
  from qdrant_client import QdrantClient, grpc, models
6
+ from qdrant_client.http.models.models import UpdateStatus
6
7
 
7
8
  from ..base import VannaBase
8
9
  from ..utils import deterministic_uuid
@@ -210,7 +211,8 @@ class Qdrant_VectorStore(VannaBase):
210
211
  def remove_training_data(self, id: str, **kwargs) -> bool:
211
212
  try:
212
213
  id, collection_name = self._parse_point_id(id)
213
- self._client.delete(collection_name, points_selector=[id])
214
+ res = self._client.delete(collection_name, points_selector=[id])
215
+ res == UpdateStatus.COMPLETED
214
216
  except ValueError:
215
217
  return False
216
218
 
@@ -5,6 +5,7 @@ from io import StringIO
5
5
  import pandas as pd
6
6
  import requests
7
7
 
8
+ from ..advanced import VannaAdvanced
8
9
  from ..base import VannaBase
9
10
  from ..types import (
10
11
  DataFrameJSON,
@@ -20,7 +21,7 @@ from ..types import (
20
21
  from ..utils import sanitize_model_name
21
22
 
22
23
 
23
- class VannaDB_VectorStore(VannaBase):
24
+ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
24
25
  def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
25
26
  VannaBase.__init__(self, config=config)
26
27
 
@@ -33,6 +34,12 @@ class VannaDB_VectorStore(VannaBase):
33
34
  else config["endpoint"]
34
35
  )
35
36
  self.related_training_data = {}
37
+ self._graphql_endpoint = "https://functionrag.com/query"
38
+ self._graphql_headers = {
39
+ "Content-Type": "application/json",
40
+ "API-KEY": self._api_key,
41
+ "NAMESPACE": self._model,
42
+ }
36
43
 
37
44
  def _rpc_call(self, method, params):
38
45
  if method != "list_orgs":
@@ -59,6 +66,177 @@ class VannaDB_VectorStore(VannaBase):
59
66
  def _dataclass_to_dict(self, obj):
60
67
  return dataclasses.asdict(obj)
61
68
 
69
+ def get_all_functions(self) -> list:
70
+ query = """
71
+ {
72
+ get_all_sql_functions {
73
+ function_name
74
+ description
75
+ post_processing_code_template
76
+ arguments {
77
+ name
78
+ description
79
+ general_type
80
+ is_user_editable
81
+ available_values
82
+ }
83
+ sql_template
84
+ }
85
+ }
86
+ """
87
+
88
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})
89
+ response_json = response.json()
90
+ if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
91
+ self.log(response_json['data']['get_all_sql_functions'])
92
+ resp = response_json['data']['get_all_sql_functions']
93
+
94
+ print(resp)
95
+
96
+ return resp
97
+ else:
98
+ raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
99
+
100
+
101
+ def get_function(self, question: str, additional_data: dict = {}) -> dict:
102
+ query = """
103
+ query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
104
+ get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
105
+ ... on SQLFunction {
106
+ function_name
107
+ description
108
+ post_processing_code_template
109
+ instantiated_post_processing_code
110
+ arguments {
111
+ name
112
+ description
113
+ general_type
114
+ is_user_editable
115
+ instantiated_value
116
+ available_values
117
+ }
118
+ sql_template
119
+ instantiated_sql
120
+ }
121
+ }
122
+ }
123
+ """
124
+ static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
125
+ variables = {"question": question, "staticFunctionArguments": static_function_arguments}
126
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
127
+ response_json = response.json()
128
+ if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
129
+ self.log(response_json['data']['get_and_instantiate_function'])
130
+ resp = response_json['data']['get_and_instantiate_function']
131
+
132
+ print(resp)
133
+
134
+ return resp
135
+ else:
136
+ raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
137
+
138
+ def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
139
+ query = """
140
+ mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
141
+ generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
142
+ function_name
143
+ description
144
+ arguments {
145
+ name
146
+ description
147
+ general_type
148
+ is_user_editable
149
+ }
150
+ sql_template
151
+ post_processing_code_template
152
+ }
153
+ }
154
+ """
155
+ variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
156
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
157
+ response_json = response.json()
158
+ if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
159
+ resp = response_json['data']['generate_and_create_sql_function']
160
+
161
+ print(resp)
162
+
163
+ return resp
164
+ else:
165
+ raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
166
+
167
+ def update_function(self, old_function_name: str, updated_function: dict) -> bool:
168
+ """
169
+ Update an existing SQL function based on the provided parameters.
170
+
171
+ Args:
172
+ old_function_name (str): The current name of the function to be updated.
173
+ updated_function (dict): A dictionary containing the updated function details. Expected keys:
174
+ - 'function_name': The new name of the function.
175
+ - 'description': The new description of the function.
176
+ - 'arguments': A list of dictionaries describing the function arguments.
177
+ - 'sql_template': The new SQL template for the function.
178
+ - 'post_processing_code_template': The new post-processing code template.
179
+
180
+ Returns:
181
+ bool: True if the function was successfully updated, False otherwise.
182
+ """
183
+ mutation = """
184
+ mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
185
+ update_sql_function(input: $input)
186
+ }
187
+ """
188
+
189
+ SQLFunctionUpdate = {
190
+ 'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
191
+ }
192
+
193
+ # Define the expected keys for each argument in the arguments list
194
+ ArgumentKeys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}
195
+
196
+ # Function to validate and transform arguments
197
+ def validate_arguments(args):
198
+ return [
199
+ {key: arg[key] for key in arg if key in ArgumentKeys}
200
+ for arg in args
201
+ ]
202
+
203
+ # Keep only the keys that conform to the SQLFunctionUpdate GraphQL input type
204
+ updated_function = {key: value for key, value in updated_function.items() if key in SQLFunctionUpdate}
205
+
206
+ # Special handling for 'arguments' to ensure they conform to the spec
207
+ if 'arguments' in updated_function:
208
+ updated_function['arguments'] = validate_arguments(updated_function['arguments'])
209
+
210
+ variables = {
211
+ "input": {
212
+ "old_function_name": old_function_name,
213
+ **updated_function
214
+ }
215
+ }
216
+
217
+ print("variables", variables)
218
+
219
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
220
+ response_json = response.json()
221
+ if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
222
+ return response_json['data']['update_sql_function']
223
+ else:
224
+ raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
225
+
226
+ def delete_function(self, function_name: str) -> bool:
227
+ mutation = """
228
+ mutation DeleteSQLFunction($function_name: String!) {
229
+ delete_sql_function(function_name: $function_name)
230
+ }
231
+ """
232
+ variables = {"function_name": function_name}
233
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
234
+ response_json = response.json()
235
+ if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
236
+ return response_json['data']['delete_sql_function']
237
+ else:
238
+ raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
239
+
62
240
  def create_model(self, model: str, **kwargs) -> bool:
63
241
  """
64
242
  **Example:**
vanna/vllm/vllm.py CHANGED
@@ -17,6 +17,11 @@ class Vllm(VannaBase):
17
17
  else:
18
18
  self.model = config["model"]
19
19
 
20
+ if "auth-key" in config:
21
+ self.auth_key = config["auth-key"]
22
+ else:
23
+ self.auth_key = None
24
+
20
25
  def system_message(self, message: str) -> any:
21
26
  return {"role": "system", "content": message}
22
27
 
@@ -67,7 +72,17 @@ class Vllm(VannaBase):
67
72
  "messages": prompt,
68
73
  }
69
74
 
70
- response = requests.post(url, json=data)
75
+ if self.auth_key is not None:
76
+ headers = {
77
+ 'Content-Type': 'application/json',
78
+ 'Authorization': f'Bearer {self.auth_key}'
79
+ }
80
+
81
+ response = requests.post(url, headers=headers,json=data)
82
+
83
+
84
+ else:
85
+ response = requests.post(url, json=data)
71
86
 
72
87
  response_dict = response.json()
73
88
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.5.5
3
+ Version: 0.6.0
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -42,7 +42,7 @@ Requires-Dist: pinecone-client ; extra == "all"
42
42
  Requires-Dist: anthropic ; extra == "anthropic"
43
43
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
44
44
  Requires-Dist: chromadb ; extra == "chromadb"
45
- Requires-Dist: clickhouse_driver ; extra == "clickhouse"
45
+ Requires-Dist: clickhouse_connect ; extra == "clickhouse"
46
46
  Requires-Dist: duckdb ; extra == "duckdb"
47
47
  Requires-Dist: google-generativeai ; extra == "gemini"
48
48
  Requires-Dist: google-generativeai ; extra == "google"
@@ -5,15 +5,16 @@ vanna/utils.py,sha256=cs0B_0MwhmPI2nWjVHifDYCmCR0kkddylQ2vloaPDSw,2247
5
5
  vanna/ZhipuAI/ZhipuAI_Chat.py,sha256=WtZKUBIwlNH0BGbb4lZbVR7pTWIrn7b4RLIk-7u0SuQ,8725
6
6
  vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oERsa8pE,2849
7
7
  vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
8
+ vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,658
8
9
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
9
10
  vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
10
11
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
11
- vanna/base/base.py,sha256=YCL9MhhrGeoVv9da85NdWvEQtnqfkWSyj5AZo_wQ0TU,70853
12
+ vanna/base/base.py,sha256=XiBJU4LOKoFnY1vnkr_jz3soAdVdoL0D1WIZdzZCVyU,70942
12
13
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
13
14
  vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
14
15
  vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
15
- vanna/flask/__init__.py,sha256=2t9DbgL2-Anm05ESmZ1WAYMd1VSkOUsTf_-exVCebwQ,25192
16
- vanna/flask/assets.py,sha256=7wMnTS-UPZkNL8SSTrd4U0amuMgeOIIj94WGzBQW1fw,187711
16
+ vanna/flask/__init__.py,sha256=wpAZg6emB7l8TmyMlxVLFMjxkUQReIIPTfqlkF-onlc,30267
17
+ vanna/flask/assets.py,sha256=_UoUr57sS0QL2BuTxAOe9k4yy8T7-fp2NpbRSVtW3IM,451769
17
18
  vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
18
19
  vanna/google/__init__.py,sha256=M-dCxCZcKL4bTQyMLj6r6VRs65YNX9Tl2aoPCuqGm-8,41
19
20
  vanna/google/gemini_chat.py,sha256=ps3A-afFbCo3HeFTLL_nMoQO1PsGvRUUPRUppbMcDew,1584
@@ -37,12 +38,12 @@ vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8
37
38
  vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
38
39
  vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
39
40
  vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
40
- vanna/qdrant/qdrant.py,sha256=6M00nMiuOuftTDf3NsOrOcG9BA4DlIIDck2MNp9iEyg,12613
41
+ vanna/qdrant/qdrant.py,sha256=RyICUvOO_jt8u9MB4oIYhqv3BicZ0d9pQkSFwkIfUjg,12719
41
42
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
42
43
  vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
43
- vanna/vannadb/vannadb_vector.py,sha256=9YwTO3Lh5owWQE7KPMBqLp2EkiGV0RC1sEYhslzJzgI,6168
44
+ vanna/vannadb/vannadb_vector.py,sha256=N8poMYvAojoaOF5gI4STD5pZWK9lBKPvyIjbh9dPBa0,14189
44
45
  vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
45
- vanna/vllm/vllm.py,sha256=QerC3xF5eNzE_nGBDl6YrPYF4WYnjf0hHxxlDWdKX-0,2427
46
- vanna-0.5.5.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
47
- vanna-0.5.5.dist-info/METADATA,sha256=gmPrDKsawOtYazNHG5CQ_iURmdVLt1Jc8Rfwneue1w0,11505
48
- vanna-0.5.5.dist-info/RECORD,,
46
+ vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
47
+ vanna-0.6.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
48
+ vanna-0.6.0.dist-info/METADATA,sha256=Nu63PMm3HT3HAkpxKU4fte9CKpIDgnL90Ww2AljWSOs,11506
49
+ vanna-0.6.0.dist-info/RECORD,,
File without changes