vanna 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/mock/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .embedding import MockEmbedding
2
+ from .llm import MockLLM
3
+ from .vectordb import MockVectorDB
@@ -0,0 +1,11 @@
1
+ from typing import List
2
+
3
+ from ..base import VannaBase
4
+
5
+
6
class MockEmbedding(VannaBase):
    """Stub embedding provider for tests and demos.

    Ignores its input and always produces the same 5-dimensional vector,
    so downstream vector-store code can be exercised without a real model.
    """

    def __init__(self, config=None):
        # The mock needs no configuration or setup.
        pass

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return a constant embedding regardless of *data*."""
        return [float(i) for i in range(1, 6)]
vanna/mock/llm.py ADDED
@@ -0,0 +1,19 @@
1
+
2
+ from ..base import VannaBase
3
+
4
+
5
class MockLLM(VannaBase):
    """Stub LLM for tests and demos.

    Produces OpenAI-style chat-message dicts and a fixed canned response,
    so the surrounding pipeline can run without calling a real model.
    """

    def __init__(self, config=None):
        # The mock needs no configuration or setup.
        pass

    # NOTE: the original annotations used ``-> any``, which refers to the
    # builtin function ``any`` and is not a meaningful type; these methods
    # return plain dicts.
    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system chat message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user chat message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant chat message."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Ignore *prompt* and return a fixed canned response."""
        return "Mock LLM response"
vanna/mock/vectordb.py ADDED
@@ -0,0 +1,55 @@
1
+ import pandas as pd
2
+
3
+ from ..base import VannaBase
4
+
5
+
6
class MockVectorDB(VannaBase):
    """Stub vector store for tests and demos.

    Accepts training data (returning synthetic hash-based ids), retrieves
    nothing, and serves a small fixed sample of training data.
    """

    def __init__(self, config=None):
        # No backing store is needed for the mock.
        pass

    def _get_id(self, value: str, **kwargs) -> str:
        # Hash the value and return the ID
        return str(hash(value))

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Pretend to store a DDL statement; return its synthetic id."""
        return self._get_id(ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Pretend to store documentation; return its synthetic id."""
        return self._get_id(doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Pretend to store a question/SQL pair.

        The id is derived from the question only; the sql is discarded.
        """
        return self._get_id(question)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """The mock never retrieves anything; always empty."""
        return []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """The mock never retrieves anything; always empty."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """The mock never retrieves anything; always empty."""
        return []

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return a small fixed sample of training data (Chinook-style schema).

        The content strings are canned fixture data and are preserved
        verbatim, including the 'rememeber' typo in the documentation row.
        """
        return pd.DataFrame({'id': {0: '19546-ddl',
                                    1: '91597-sql',
                                    2: '133976-sql',
                                    3: '59851-doc',
                                    4: '73046-sql'},
                             'training_data_type': {0: 'ddl',
                                                    1: 'sql',
                                                    2: 'sql',
                                                    3: 'documentation',
                                                    4: 'sql'},
                             'question': {0: None,
                                          1: 'What are the top selling genres?',
                                          2: 'What are the low 7 artists by sales?',
                                          3: None,
                                          4: 'What is the total sales for each customer?'},
                             'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
                                         1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
                                         2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
                                         3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
                                         4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Pretend to delete training data; always succeeds.

        Fixed: the original signature omitted ``self``, so calling this
        method on an instance raised ``TypeError``.
        """
        return True
@@ -155,6 +155,11 @@ class OpenSearch_VectorStore(VannaBase):
155
155
  else:
156
156
  max_retries = 10
157
157
 
158
+ if config is not None and "es_http_compress" in config:
159
+ es_http_compress = config["es_http_compress"]
160
+ else:
161
+ es_http_compress = False
162
+
158
163
  print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
159
164
  " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
160
165
  verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
@@ -162,7 +167,7 @@ class OpenSearch_VectorStore(VannaBase):
162
167
  # Initialize the OpenSearch client by passing a list of URLs
163
168
  self.client = OpenSearch(
164
169
  hosts=[es_urls],
165
- http_compress=True,
170
+ http_compress=es_http_compress,
166
171
  use_ssl=ssl,
167
172
  verify_certs=verify_certs,
168
173
  timeout=timeout,
@@ -175,7 +180,7 @@ class OpenSearch_VectorStore(VannaBase):
175
180
  # Initialize the OpenSearch client by passing a host and port
176
181
  self.client = OpenSearch(
177
182
  hosts=[{'host': host, 'port': port}],
178
- http_compress=True,
183
+ http_compress=es_http_compress,
179
184
  use_ssl=ssl,
180
185
  verify_certs=verify_certs,
181
186
  timeout=timeout,
@@ -267,6 +272,7 @@ class OpenSearch_VectorStore(VannaBase):
267
272
  }
268
273
  }
269
274
  }
275
+ print(query)
270
276
  response = self.client.search(index=self.ddl_index, body=query,
271
277
  **kwargs)
272
278
  return [hit['_source']['ddl'] for hit in response['hits']['hits']]
@@ -279,6 +285,7 @@ class OpenSearch_VectorStore(VannaBase):
279
285
  }
280
286
  }
281
287
  }
288
+ print(query)
282
289
  response = self.client.search(index=self.document_index,
283
290
  body=query,
284
291
  **kwargs)
@@ -292,6 +299,7 @@ class OpenSearch_VectorStore(VannaBase):
292
299
  }
293
300
  }
294
301
  }
302
+ print(query)
295
303
  response = self.client.search(index=self.question_sql_index,
296
304
  body=query,
297
305
  **kwargs)
@@ -307,6 +315,7 @@ class OpenSearch_VectorStore(VannaBase):
307
315
  body={"query": {"match_all": {}}},
308
316
  size=1000
309
317
  )
318
+ print(query)
310
319
  # records = [hit['_source'] for hit in response['hits']['hits']]
311
320
  for hit in response['hits']['hits']:
312
321
  data.append(
@@ -0,0 +1,3 @@
1
+ from .pinecone_vector import PineconeDB_VectorStore
2
+
3
+ __all__ = ["PineconeDB_VectorStore"]
@@ -0,0 +1,275 @@
1
+ import json
2
+ from typing import List
3
+
4
+ from pinecone import Pinecone, PodSpec, ServerlessSpec
5
+ import pandas as pd
6
+ from ..base import VannaBase
7
+ from ..utils import deterministic_uuid
8
+
9
+ from fastembed import TextEmbedding
10
+
11
+
12
class PineconeDB_VectorStore(VannaBase):
    """
    Vectorstore using PineconeDB

    Args:
        config (dict): Configuration dictionary. Defaults to {}. You must provide either a Pinecone Client or an API key in the config.
        - client (Pinecone, optional): Pinecone client. Defaults to None.
        - api_key (str, optional): Pinecone API key. Defaults to None.
        - n_results (int, optional): Number of results to return. Defaults to 10.
        - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which coresponds to the dimensions of BAAI/bge-small-en-v1.5.
        - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
        - documentation_namespace (str, optional): Namespace for documentation. Defaults to "documentation".
        - distance_metric (str, optional): Distance metric to use. Defaults to "cosine".
        - ddl_namespace (str, optional): Namespace for DDL. Defaults to "ddl".
        - sql_namespace (str, optional): Namespace for SQL. Defaults to "sql".
        - index_name (str, optional): Name of the index. Defaults to "vanna-index".
        - metadata_config (dict, optional): Metadata configuration if using a pinecone pod. Defaults to {}.
        - server_type (str, optional): Type of Pinecone server to use. Defaults to "serverless". Options are "serverless" or "pod".
        - podspec (PodSpec, optional): PodSpec configuration if using a pinecone pod. Defaults to PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config).
        - serverless_spec (ServerlessSpec, optional): ServerlessSpec configuration if using a pinecone serverless index. Defaults to ServerlessSpec(cloud="aws", region="us-west-2").
    Raises:
        ValueError: If config is None, api_key is not provided OR client is not provided, client is not an instance of Pinecone, or server_type is not "serverless" or "pod".
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            raise ValueError(
                "config is required, pass either a Pinecone client or an API key in the config."
            )
        client = config.get("client")
        api_key = config.get("api_key")
        if not api_key and not client:
            raise ValueError(
                "api_key is required in config or pass a configured client"
            )
        if not client and api_key:
            self._client = Pinecone(api_key=api_key)
        elif not isinstance(client, Pinecone):
            raise ValueError("client must be an instance of Pinecone")
        else:
            self._client = client

        self.n_results = config.get("n_results", 10)
        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.documentation_namespace = config.get(
            "documentation_namespace", "documentation"
        )
        self.distance_metric = config.get("distance_metric", "cosine")
        self.ddl_namespace = config.get("ddl_namespace", "ddl")
        self.sql_namespace = config.get("sql_namespace", "sql")
        self.index_name = config.get("index_name", "vanna-index")
        self.metadata_config = config.get("metadata_config", {})
        self.server_type = config.get("server_type", "serverless")
        if self.server_type not in ["serverless", "pod"]:
            raise ValueError("server_type must be either 'serverless' or 'pod'")
        self.podspec = config.get(
            "podspec",
            PodSpec(
                environment="us-west-2",
                pod_type="p1.x1",
                metadata_config=self.metadata_config,
            ),
        )
        self.serverless_spec = config.get(
            "serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2")
        )
        # Embedding model is loaded lazily on first use (see generate_embedding)
        # because constructing TextEmbedding loads (or downloads) model weights.
        self._embedding_model = None
        self._setup_index()

    def _set_index_host(self, host: str) -> None:
        """Bind ``self.Index`` to the index living at *host*."""
        self.Index = self._client.Index(host=host)

    def _setup_index(self) -> None:
        """Create the index if it does not exist, then bind ``self.Index``.

        Fixed: the original repeated the describe/bind tail in all three
        branches; only index creation depends on the server type.
        """
        existing_indexes = self._get_indexes()
        if self.index_name not in existing_indexes:
            # server_type was validated in __init__ to be one of the two.
            spec = (
                self.serverless_spec
                if self.server_type == "serverless"
                else self.podspec
            )
            self._client.create_index(
                name=self.index_name,
                dimension=self.dimensions,
                metric=self.distance_metric,
                spec=spec,
            )
        pinecone_index_host = self._client.describe_index(self.index_name)["host"]
        self._set_index_host(pinecone_index_host)

    def _get_indexes(self) -> list:
        """Return the names of all indexes visible to the client."""
        return [index["name"] for index in self._client.list_indexes()]

    def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
        """Return True if a vector with *id* already exists in *namespace*."""
        fetch_response = self.Index.fetch(ids=[id], namespace=namespace)
        if fetch_response["vectors"] == {}:
            return False
        return True

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Upsert a DDL statement; returns its deterministic id (idempotent)."""
        id = deterministic_uuid(ddl) + "-ddl"
        if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace):
            print(f"DDL with id: {id} already exists in the index. Skipping...")
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})],
            namespace=self.ddl_namespace,
        )
        return id

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Upsert a documentation string; returns its deterministic id (idempotent)."""
        id = deterministic_uuid(doc) + "-doc"

        if self._check_if_embedding_exists(
            id=id, namespace=self.documentation_namespace
        ):
            print(
                f"Documentation with id: {id} already exists in the index. Skipping..."
            )
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(doc), {"documentation": doc})],
            namespace=self.documentation_namespace,
        )
        return id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Upsert a question/SQL pair stored as a JSON blob; idempotent by content."""
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = deterministic_uuid(question_sql_json) + "-sql"
        if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace):
            print(
                f"Question-SQL with id: {id} already exists in the index. Skipping..."
            )
            return id
        self.Index.upsert(
            vectors=[
                (
                    id,
                    self.generate_embedding(question_sql_json),
                    {"sql": question_sql_json},
                )
            ],
            namespace=self.sql_namespace,
        )
        return id

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL strings most similar to *question*."""
        res = self.Index.query(
            namespace=self.ddl_namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )
        return [match["metadata"]["ddl"] for match in res["matches"]] if res else []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation strings most similar to *question*."""
        res = self.Index.query(
            namespace=self.documentation_namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )
        return (
            [match["metadata"]["documentation"] for match in res["matches"]]
            if res
            else []
        )

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return question/SQL dicts most similar to *question*."""
        res = self.Index.query(
            namespace=self.sql_namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )
        # The stored metadata value is the JSON blob written by add_question_sql.
        return (
            [json.loads(match["metadata"]["sql"]) for match in res["matches"]]
            if res
            else []
        )

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training data across the three namespaces.

        Pinecone does not support listing every vector in a namespace, so we
        query for the top_k vectors against a zero (dummy) vector instead.
        """
        df = pd.DataFrame()
        namespaces = {
            "sql": self.sql_namespace,
            "ddl": self.ddl_namespace,
            "documentation": self.documentation_namespace,
        }

        for data_type, namespace in namespaces.items():
            data = self.Index.query(
                top_k=10000,  # max results that pinecone allows
                namespace=namespace,
                include_values=True,
                include_metadata=True,
                vector=[0.0] * self.dimensions,
            )

            if data is not None:
                id_list = [match["id"] for match in data["matches"]]
                content_list = [
                    match["metadata"][data_type] for match in data["matches"]
                ]
                # Only SQL entries carry a question (inside their JSON blob).
                question_list = [
                    (
                        json.loads(match["metadata"][data_type])["question"]
                        if data_type == "sql"
                        else None
                    )
                    for match in data["matches"]
                ]

                df_data = pd.DataFrame(
                    {
                        "id": id_list,
                        "question": question_list,
                        "content": content_list,
                    }
                )
                df_data["training_data_type"] = data_type
                df = pd.concat([df, df_data])

        return df

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the vector *id* from the namespace implied by its suffix.

        Returns False for an unrecognized id suffix.
        """
        if id.endswith("-sql"):
            self.Index.delete(ids=[id], namespace=self.sql_namespace)
            return True
        elif id.endswith("-ddl"):
            self.Index.delete(ids=[id], namespace=self.ddl_namespace)
            return True
        elif id.endswith("-doc"):
            self.Index.delete(ids=[id], namespace=self.documentation_namespace)
            return True
        else:
            return False

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed *data* with the configured fastembed model.

        Fixed: the original constructed a new TextEmbedding (model load /
        download) on every call; the model is now built once and cached.
        """
        if getattr(self, "_embedding_model", None) is None:
            self._embedding_model = TextEmbedding(model_name=self.fastembed_model)
        embedding = next(self._embedding_model.embed(data))
        return embedding.tolist()
vanna/qdrant/qdrant.py CHANGED
@@ -3,6 +3,7 @@ from typing import List, Tuple
3
3
 
4
4
  import pandas as pd
5
5
  from qdrant_client import QdrantClient, grpc, models
6
+ from qdrant_client.http.models.models import UpdateStatus
6
7
 
7
8
  from ..base import VannaBase
8
9
  from ..utils import deterministic_uuid
@@ -210,7 +211,8 @@ class Qdrant_VectorStore(VannaBase):
210
211
  def remove_training_data(self, id: str, **kwargs) -> bool:
211
212
  try:
212
213
  id, collection_name = self._parse_point_id(id)
213
- self._client.delete(collection_name, points_selector=[id])
214
+ res = self._client.delete(collection_name, points_selector=[id])
215
+ res == UpdateStatus.COMPLETED
214
216
  except ValueError:
215
217
  return False
216
218
 
@@ -5,6 +5,7 @@ from io import StringIO
5
5
  import pandas as pd
6
6
  import requests
7
7
 
8
+ from ..advanced import VannaAdvanced
8
9
  from ..base import VannaBase
9
10
  from ..types import (
10
11
  DataFrameJSON,
@@ -20,7 +21,7 @@ from ..types import (
20
21
  from ..utils import sanitize_model_name
21
22
 
22
23
 
23
- class VannaDB_VectorStore(VannaBase):
24
+ class VannaDB_VectorStore(VannaBase, VannaAdvanced):
24
25
  def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
25
26
  VannaBase.__init__(self, config=config)
26
27
 
@@ -33,6 +34,12 @@ class VannaDB_VectorStore(VannaBase):
33
34
  else config["endpoint"]
34
35
  )
35
36
  self.related_training_data = {}
37
+ self._graphql_endpoint = "https://functionrag.com/query"
38
+ self._graphql_headers = {
39
+ "Content-Type": "application/json",
40
+ "API-KEY": self._api_key,
41
+ "NAMESPACE": self._model,
42
+ }
36
43
 
37
44
  def _rpc_call(self, method, params):
38
45
  if method != "list_orgs":
@@ -59,6 +66,177 @@ class VannaDB_VectorStore(VannaBase):
59
66
  def _dataclass_to_dict(self, obj):
60
67
  return dataclasses.asdict(obj)
61
68
 
69
+ def get_all_functions(self) -> list:
70
+ query = """
71
+ {
72
+ get_all_sql_functions {
73
+ function_name
74
+ description
75
+ post_processing_code_template
76
+ arguments {
77
+ name
78
+ description
79
+ general_type
80
+ is_user_editable
81
+ available_values
82
+ }
83
+ sql_template
84
+ }
85
+ }
86
+ """
87
+
88
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})
89
+ response_json = response.json()
90
+ if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
91
+ self.log(response_json['data']['get_all_sql_functions'])
92
+ resp = response_json['data']['get_all_sql_functions']
93
+
94
+ print(resp)
95
+
96
+ return resp
97
+ else:
98
+ raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
99
+
100
+
101
+ def get_function(self, question: str, additional_data: dict = {}) -> dict:
102
+ query = """
103
+ query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
104
+ get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
105
+ ... on SQLFunction {
106
+ function_name
107
+ description
108
+ post_processing_code_template
109
+ instantiated_post_processing_code
110
+ arguments {
111
+ name
112
+ description
113
+ general_type
114
+ is_user_editable
115
+ instantiated_value
116
+ available_values
117
+ }
118
+ sql_template
119
+ instantiated_sql
120
+ }
121
+ }
122
+ }
123
+ """
124
+ static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
125
+ variables = {"question": question, "staticFunctionArguments": static_function_arguments}
126
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
127
+ response_json = response.json()
128
+ if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
129
+ self.log(response_json['data']['get_and_instantiate_function'])
130
+ resp = response_json['data']['get_and_instantiate_function']
131
+
132
+ print(resp)
133
+
134
+ return resp
135
+ else:
136
+ raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
137
+
138
+ def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
139
+ query = """
140
+ mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
141
+ generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
142
+ function_name
143
+ description
144
+ arguments {
145
+ name
146
+ description
147
+ general_type
148
+ is_user_editable
149
+ }
150
+ sql_template
151
+ post_processing_code_template
152
+ }
153
+ }
154
+ """
155
+ variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
156
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
157
+ response_json = response.json()
158
+ if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
159
+ resp = response_json['data']['generate_and_create_sql_function']
160
+
161
+ print(resp)
162
+
163
+ return resp
164
+ else:
165
+ raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
166
+
167
+ def update_function(self, old_function_name: str, updated_function: dict) -> bool:
168
+ """
169
+ Update an existing SQL function based on the provided parameters.
170
+
171
+ Args:
172
+ old_function_name (str): The current name of the function to be updated.
173
+ updated_function (dict): A dictionary containing the updated function details. Expected keys:
174
+ - 'function_name': The new name of the function.
175
+ - 'description': The new description of the function.
176
+ - 'arguments': A list of dictionaries describing the function arguments.
177
+ - 'sql_template': The new SQL template for the function.
178
+ - 'post_processing_code_template': The new post-processing code template.
179
+
180
+ Returns:
181
+ bool: True if the function was successfully updated, False otherwise.
182
+ """
183
+ mutation = """
184
+ mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
185
+ update_sql_function(input: $input)
186
+ }
187
+ """
188
+
189
+ SQLFunctionUpdate = {
190
+ 'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
191
+ }
192
+
193
+ # Define the expected keys for each argument in the arguments list
194
+ ArgumentKeys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}
195
+
196
+ # Function to validate and transform arguments
197
+ def validate_arguments(args):
198
+ return [
199
+ {key: arg[key] for key in arg if key in ArgumentKeys}
200
+ for arg in args
201
+ ]
202
+
203
+ # Keep only the keys that conform to the SQLFunctionUpdate GraphQL input type
204
+ updated_function = {key: value for key, value in updated_function.items() if key in SQLFunctionUpdate}
205
+
206
+ # Special handling for 'arguments' to ensure they conform to the spec
207
+ if 'arguments' in updated_function:
208
+ updated_function['arguments'] = validate_arguments(updated_function['arguments'])
209
+
210
+ variables = {
211
+ "input": {
212
+ "old_function_name": old_function_name,
213
+ **updated_function
214
+ }
215
+ }
216
+
217
+ print("variables", variables)
218
+
219
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
220
+ response_json = response.json()
221
+ if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
222
+ return response_json['data']['update_sql_function']
223
+ else:
224
+ raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
225
+
226
+ def delete_function(self, function_name: str) -> bool:
227
+ mutation = """
228
+ mutation DeleteSQLFunction($function_name: String!) {
229
+ delete_sql_function(function_name: $function_name)
230
+ }
231
+ """
232
+ variables = {"function_name": function_name}
233
+ response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
234
+ response_json = response.json()
235
+ if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
236
+ return response_json['data']['delete_sql_function']
237
+ else:
238
+ raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
239
+
62
240
  def create_model(self, model: str, **kwargs) -> bool:
63
241
  """
64
242
  **Example:**