vanna 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/ZhipuAI/ZhipuAI_Chat.py +3 -3
- vanna/base/base.py +105 -4
- vanna/chromadb/chromadb_vector.py +21 -15
- vanna/flask/__init__.py +117 -43
- vanna/flask/assets.py +17 -17
- vanna/flask/auth.py +55 -0
- vanna/google/__init__.py +1 -0
- vanna/google/gemini_chat.py +52 -0
- vanna/remote.py +3 -381
- vanna/vannadb/vannadb_vector.py +58 -27
- {vanna-0.3.3.dist-info → vanna-0.4.0.dist-info}/METADATA +6 -1
- {vanna-0.3.3.dist-info → vanna-0.4.0.dist-info}/RECORD +13 -10
- {vanna-0.3.3.dist-info → vanna-0.4.0.dist-info}/WHEEL +0 -0
vanna/ZhipuAI/ZhipuAI_Chat.py
CHANGED
|
@@ -40,7 +40,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
40
40
|
initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
|
|
41
41
|
) -> str:
|
|
42
42
|
if len(ddl_list) > 0:
|
|
43
|
-
initial_prompt +=
|
|
43
|
+
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
44
44
|
|
|
45
45
|
for ddl in ddl_list:
|
|
46
46
|
if (
|
|
@@ -57,7 +57,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
57
57
|
initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
|
|
58
58
|
) -> str:
|
|
59
59
|
if len(documentation_List) > 0:
|
|
60
|
-
initial_prompt +=
|
|
60
|
+
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
61
61
|
|
|
62
62
|
for documentation in documentation_List:
|
|
63
63
|
if (
|
|
@@ -74,7 +74,7 @@ class ZhipuAI_Chat(VannaBase):
|
|
|
74
74
|
initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
|
|
75
75
|
) -> str:
|
|
76
76
|
if len(sql_List) > 0:
|
|
77
|
-
initial_prompt +=
|
|
77
|
+
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
78
78
|
|
|
79
79
|
for question in sql_List:
|
|
80
80
|
if (
|
vanna/base/base.py
CHANGED
|
@@ -124,6 +124,18 @@ class VannaBase(ABC):
|
|
|
124
124
|
return self.extract_sql(llm_response)
|
|
125
125
|
|
|
126
126
|
def extract_sql(self, llm_response: str) -> str:
|
|
127
|
+
# If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
|
|
128
|
+
sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
|
|
129
|
+
if sql:
|
|
130
|
+
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
|
|
131
|
+
)
|
|
132
|
+
return sql.group(0)
|
|
133
|
+
|
|
134
|
+
# If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ;
|
|
135
|
+
sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
|
|
136
|
+
if sql:
|
|
137
|
+
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
|
|
138
|
+
return sql.group(0)
|
|
127
139
|
# If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
|
|
128
140
|
sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
|
|
129
141
|
if sql:
|
|
@@ -363,7 +375,7 @@ class VannaBase(ABC):
|
|
|
363
375
|
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
|
|
364
376
|
) -> str:
|
|
365
377
|
if len(ddl_list) > 0:
|
|
366
|
-
initial_prompt +=
|
|
378
|
+
initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
367
379
|
|
|
368
380
|
for ddl in ddl_list:
|
|
369
381
|
if (
|
|
@@ -382,7 +394,7 @@ class VannaBase(ABC):
|
|
|
382
394
|
max_tokens: int = 14000,
|
|
383
395
|
) -> str:
|
|
384
396
|
if len(documentation_list) > 0:
|
|
385
|
-
initial_prompt +=
|
|
397
|
+
initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
386
398
|
|
|
387
399
|
for documentation in documentation_list:
|
|
388
400
|
if (
|
|
@@ -398,7 +410,7 @@ class VannaBase(ABC):
|
|
|
398
410
|
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
|
|
399
411
|
) -> str:
|
|
400
412
|
if len(sql_list) > 0:
|
|
401
|
-
initial_prompt +=
|
|
413
|
+
initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
402
414
|
|
|
403
415
|
for question in sql_list:
|
|
404
416
|
if (
|
|
@@ -642,6 +654,7 @@ class VannaBase(ABC):
|
|
|
642
654
|
password=password,
|
|
643
655
|
account=account,
|
|
644
656
|
database=database,
|
|
657
|
+
client_session_keep_alive=True
|
|
645
658
|
)
|
|
646
659
|
|
|
647
660
|
def run_sql_snowflake(sql: str) -> pd.DataFrame:
|
|
@@ -890,6 +903,94 @@ class VannaBase(ABC):
|
|
|
890
903
|
self.run_sql_is_set = True
|
|
891
904
|
self.run_sql = run_sql_mysql
|
|
892
905
|
|
|
906
|
+
def connect_to_oracle(
|
|
907
|
+
self,
|
|
908
|
+
user: str = None,
|
|
909
|
+
password: str = None,
|
|
910
|
+
dsn: str = None,
|
|
911
|
+
):
|
|
912
|
+
|
|
913
|
+
"""
|
|
914
|
+
Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
915
|
+
**Example:**
|
|
916
|
+
```python
|
|
917
|
+
vn.connect_to_oracle(
|
|
918
|
+
user="username",
|
|
919
|
+
password="password",
|
|
920
|
+
dns="host:port/sid",
|
|
921
|
+
)
|
|
922
|
+
```
|
|
923
|
+
Args:
|
|
924
|
+
USER (str): Oracle db user name.
|
|
925
|
+
PASSWORD (str): Oracle db user password.
|
|
926
|
+
DSN (str): Oracle db host ip - host:port/sid.
|
|
927
|
+
"""
|
|
928
|
+
|
|
929
|
+
try:
|
|
930
|
+
import oracledb
|
|
931
|
+
except ImportError:
|
|
932
|
+
|
|
933
|
+
raise DependencyError(
|
|
934
|
+
"You need to install required dependencies to execute this method,"
|
|
935
|
+
" run command: \npip install oracledb"
|
|
936
|
+
)
|
|
937
|
+
|
|
938
|
+
if not dsn:
|
|
939
|
+
dsn = os.getenv("DSN")
|
|
940
|
+
|
|
941
|
+
if not dsn:
|
|
942
|
+
raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
|
|
943
|
+
|
|
944
|
+
if not user:
|
|
945
|
+
user = os.getenv("USER")
|
|
946
|
+
|
|
947
|
+
if not user:
|
|
948
|
+
raise ImproperlyConfigured("Please set your Oracle db user")
|
|
949
|
+
|
|
950
|
+
if not password:
|
|
951
|
+
password = os.getenv("PASSWORD")
|
|
952
|
+
|
|
953
|
+
if not password:
|
|
954
|
+
raise ImproperlyConfigured("Please set your Oracle db password")
|
|
955
|
+
|
|
956
|
+
conn = None
|
|
957
|
+
|
|
958
|
+
try:
|
|
959
|
+
conn = oracledb.connect(
|
|
960
|
+
user=user,
|
|
961
|
+
password=password,
|
|
962
|
+
dsn=dsn,
|
|
963
|
+
)
|
|
964
|
+
except oracledb.Error as e:
|
|
965
|
+
raise ValidationError(e)
|
|
966
|
+
|
|
967
|
+
def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
|
|
968
|
+
if conn:
|
|
969
|
+
try:
|
|
970
|
+
sql = sql.rstrip()
|
|
971
|
+
if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
|
|
972
|
+
sql = sql[:-1]
|
|
973
|
+
|
|
974
|
+
cs = conn.cursor()
|
|
975
|
+
cs.execute(sql)
|
|
976
|
+
results = cs.fetchall()
|
|
977
|
+
|
|
978
|
+
# Create a pandas dataframe from the results
|
|
979
|
+
df = pd.DataFrame(
|
|
980
|
+
results, columns=[desc[0] for desc in cs.description]
|
|
981
|
+
)
|
|
982
|
+
return df
|
|
983
|
+
|
|
984
|
+
except oracledb.Error as e:
|
|
985
|
+
conn.rollback()
|
|
986
|
+
raise ValidationError(e)
|
|
987
|
+
|
|
988
|
+
except Exception as e:
|
|
989
|
+
conn.rollback()
|
|
990
|
+
raise e
|
|
991
|
+
|
|
992
|
+
self.run_sql_is_set = True
|
|
993
|
+
self.run_sql = run_sql_oracle
|
|
893
994
|
|
|
894
995
|
def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
|
|
895
996
|
"""
|
|
@@ -1238,7 +1339,7 @@ class VannaBase(ABC):
|
|
|
1238
1339
|
"""
|
|
1239
1340
|
|
|
1240
1341
|
if question and not sql:
|
|
1241
|
-
raise ValidationError(
|
|
1342
|
+
raise ValidationError("Please also provide a SQL query")
|
|
1242
1343
|
|
|
1243
1344
|
if documentation:
|
|
1244
1345
|
print("Adding documentation....")
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import json
|
|
2
|
-
import uuid
|
|
3
2
|
from typing import List
|
|
4
3
|
|
|
5
4
|
import chromadb
|
|
@@ -16,17 +15,16 @@ default_ef = embedding_functions.DefaultEmbeddingFunction()
|
|
|
16
15
|
class ChromaDB_VectorStore(VannaBase):
|
|
17
16
|
def __init__(self, config=None):
|
|
18
17
|
VannaBase.__init__(self, config=config)
|
|
18
|
+
if config is None:
|
|
19
|
+
config = {}
|
|
19
20
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
self.embedding_function = default_ef
|
|
28
|
-
curr_client = "persistent" # defaults to persistent storage
|
|
29
|
-
self.n_results = 10 # defaults to 10 documents
|
|
21
|
+
path = config.get("path", ".")
|
|
22
|
+
self.embedding_function = config.get("embedding_function", default_ef)
|
|
23
|
+
curr_client = config.get("client", "persistent")
|
|
24
|
+
collection_metadata = config.get("collection_metadata", None)
|
|
25
|
+
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
|
|
26
|
+
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
|
|
27
|
+
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
|
|
30
28
|
|
|
31
29
|
if curr_client == "persistent":
|
|
32
30
|
self.chroma_client = chromadb.PersistentClient(
|
|
@@ -43,13 +41,19 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
43
41
|
raise ValueError(f"Unsupported client was set in config: {curr_client}")
|
|
44
42
|
|
|
45
43
|
self.documentation_collection = self.chroma_client.get_or_create_collection(
|
|
46
|
-
name="documentation",
|
|
44
|
+
name="documentation",
|
|
45
|
+
embedding_function=self.embedding_function,
|
|
46
|
+
metadata=collection_metadata,
|
|
47
47
|
)
|
|
48
48
|
self.ddl_collection = self.chroma_client.get_or_create_collection(
|
|
49
|
-
name="ddl",
|
|
49
|
+
name="ddl",
|
|
50
|
+
embedding_function=self.embedding_function,
|
|
51
|
+
metadata=collection_metadata,
|
|
50
52
|
)
|
|
51
53
|
self.sql_collection = self.chroma_client.get_or_create_collection(
|
|
52
|
-
name="sql",
|
|
54
|
+
name="sql",
|
|
55
|
+
embedding_function=self.embedding_function,
|
|
56
|
+
metadata=collection_metadata,
|
|
53
57
|
)
|
|
54
58
|
|
|
55
59
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
@@ -232,7 +236,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
232
236
|
return ChromaDB_VectorStore._extract_documents(
|
|
233
237
|
self.sql_collection.query(
|
|
234
238
|
query_texts=[question],
|
|
235
|
-
n_results=self.
|
|
239
|
+
n_results=self.n_results_sql,
|
|
236
240
|
)
|
|
237
241
|
)
|
|
238
242
|
|
|
@@ -240,6 +244,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
240
244
|
return ChromaDB_VectorStore._extract_documents(
|
|
241
245
|
self.ddl_collection.query(
|
|
242
246
|
query_texts=[question],
|
|
247
|
+
n_results=self.n_results_ddl,
|
|
243
248
|
)
|
|
244
249
|
)
|
|
245
250
|
|
|
@@ -247,5 +252,6 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
247
252
|
return ChromaDB_VectorStore._extract_documents(
|
|
248
253
|
self.documentation_collection.query(
|
|
249
254
|
query_texts=[question],
|
|
255
|
+
n_results=self.n_results_documentation,
|
|
250
256
|
)
|
|
251
257
|
)
|
vanna/flask/__init__.py
CHANGED
|
@@ -8,27 +8,47 @@ import requests
|
|
|
8
8
|
from flask import Flask, Response, jsonify, request
|
|
9
9
|
|
|
10
10
|
from .assets import css_content, html_content, js_content
|
|
11
|
+
from .auth import AuthInterface, NoAuth
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class Cache(ABC):
|
|
15
|
+
"""
|
|
16
|
+
Define the interface for a cache that can be used to store data in a Flask app.
|
|
17
|
+
"""
|
|
18
|
+
|
|
14
19
|
@abstractmethod
|
|
15
20
|
def generate_id(self, *args, **kwargs):
|
|
21
|
+
"""
|
|
22
|
+
Generate a unique ID for the cache.
|
|
23
|
+
"""
|
|
16
24
|
pass
|
|
17
25
|
|
|
18
26
|
@abstractmethod
|
|
19
27
|
def get(self, id, field):
|
|
28
|
+
"""
|
|
29
|
+
Get a value from the cache.
|
|
30
|
+
"""
|
|
20
31
|
pass
|
|
21
32
|
|
|
22
33
|
@abstractmethod
|
|
23
34
|
def get_all(self, field_list) -> list:
|
|
35
|
+
"""
|
|
36
|
+
Get all values from the cache.
|
|
37
|
+
"""
|
|
24
38
|
pass
|
|
25
39
|
|
|
26
40
|
@abstractmethod
|
|
27
41
|
def set(self, id, field, value):
|
|
42
|
+
"""
|
|
43
|
+
Set a value in the cache.
|
|
44
|
+
"""
|
|
28
45
|
pass
|
|
29
46
|
|
|
30
47
|
@abstractmethod
|
|
31
48
|
def delete(self, id):
|
|
49
|
+
"""
|
|
50
|
+
Delete a value from the cache.
|
|
51
|
+
"""
|
|
32
52
|
pass
|
|
33
53
|
|
|
34
54
|
|
|
@@ -64,11 +84,10 @@ class MemoryCache(Cache):
|
|
|
64
84
|
if id in self.cache:
|
|
65
85
|
del self.cache[id]
|
|
66
86
|
|
|
67
|
-
|
|
68
87
|
class VannaFlaskApp:
|
|
69
88
|
flask_app = None
|
|
70
89
|
|
|
71
|
-
def requires_cache(self,
|
|
90
|
+
def requires_cache(self, required_fields, optional_fields=[]):
|
|
72
91
|
def decorator(f):
|
|
73
92
|
@wraps(f)
|
|
74
93
|
def decorated(*args, **kwargs):
|
|
@@ -79,14 +98,17 @@ class VannaFlaskApp:
|
|
|
79
98
|
if id is None:
|
|
80
99
|
return jsonify({"type": "error", "error": "No id provided"})
|
|
81
100
|
|
|
82
|
-
for field in
|
|
101
|
+
for field in required_fields:
|
|
83
102
|
if self.cache.get(id=id, field=field) is None:
|
|
84
103
|
return jsonify({"type": "error", "error": f"No {field} found"})
|
|
85
104
|
|
|
86
105
|
field_values = {
|
|
87
|
-
field: self.cache.get(id=id, field=field) for field in
|
|
106
|
+
field: self.cache.get(id=id, field=field) for field in required_fields
|
|
88
107
|
}
|
|
89
108
|
|
|
109
|
+
for field in optional_fields:
|
|
110
|
+
field_values[field] = self.cache.get(id=id, field=field)
|
|
111
|
+
|
|
90
112
|
# Add the id to the field_values
|
|
91
113
|
field_values["id"] = id
|
|
92
114
|
|
|
@@ -96,7 +118,21 @@ class VannaFlaskApp:
|
|
|
96
118
|
|
|
97
119
|
return decorator
|
|
98
120
|
|
|
121
|
+
def requires_auth(self, f):
|
|
122
|
+
@wraps(f)
|
|
123
|
+
def decorated(*args, **kwargs):
|
|
124
|
+
user = self.auth.get_user(flask.request)
|
|
125
|
+
|
|
126
|
+
if not self.auth.is_logged_in(user):
|
|
127
|
+
return jsonify({"type": "not_logged_in", "html": self.auth.login_form()})
|
|
128
|
+
|
|
129
|
+
# Pass the user to the function
|
|
130
|
+
return f(*args, user=user, **kwargs)
|
|
131
|
+
|
|
132
|
+
return decorated
|
|
133
|
+
|
|
99
134
|
def __init__(self, vn, cache: Cache = MemoryCache(),
|
|
135
|
+
auth: AuthInterface = NoAuth(),
|
|
100
136
|
allow_llm_to_see_data=False,
|
|
101
137
|
logo="https://img.vanna.ai/vanna-flask.svg",
|
|
102
138
|
title="Welcome to Vanna.AI",
|
|
@@ -119,6 +155,7 @@ class VannaFlaskApp:
|
|
|
119
155
|
Args:
|
|
120
156
|
vn: The Vanna instance to interact with.
|
|
121
157
|
cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
|
|
158
|
+
auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
|
|
122
159
|
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
|
|
123
160
|
logo: The logo to display in the UI. Defaults to the Vanna logo.
|
|
124
161
|
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
|
|
@@ -140,6 +177,7 @@ class VannaFlaskApp:
|
|
|
140
177
|
"""
|
|
141
178
|
self.flask_app = Flask(__name__)
|
|
142
179
|
self.vn = vn
|
|
180
|
+
self.auth = auth
|
|
143
181
|
self.cache = cache
|
|
144
182
|
self.allow_llm_to_see_data = allow_llm_to_see_data
|
|
145
183
|
self.logo = logo
|
|
@@ -160,32 +198,50 @@ class VannaFlaskApp:
|
|
|
160
198
|
log = logging.getLogger("werkzeug")
|
|
161
199
|
log.setLevel(logging.ERROR)
|
|
162
200
|
|
|
201
|
+
@self.flask_app.route("/auth/login", methods=["POST"])
|
|
202
|
+
def login():
|
|
203
|
+
return self.auth.login_handler(flask.request)
|
|
204
|
+
|
|
205
|
+
@self.flask_app.route("/auth/callback", methods=["GET"])
|
|
206
|
+
def callback():
|
|
207
|
+
return self.auth.callback_handler(flask.request)
|
|
208
|
+
|
|
209
|
+
@self.flask_app.route("/auth/logout", methods=["GET"])
|
|
210
|
+
def logout():
|
|
211
|
+
return self.auth.logout_handler(flask.request)
|
|
212
|
+
|
|
163
213
|
@self.flask_app.route("/api/v0/get_config", methods=["GET"])
|
|
164
|
-
|
|
214
|
+
@self.requires_auth
|
|
215
|
+
def get_config(user: any):
|
|
216
|
+
config = {
|
|
217
|
+
"logo": self.logo,
|
|
218
|
+
"title": self.title,
|
|
219
|
+
"subtitle": self.subtitle,
|
|
220
|
+
"show_training_data": self.show_training_data,
|
|
221
|
+
"suggested_questions": self.suggested_questions,
|
|
222
|
+
"sql": self.sql,
|
|
223
|
+
"table": self.table,
|
|
224
|
+
"csv_download": self.csv_download,
|
|
225
|
+
"chart": self.chart,
|
|
226
|
+
"redraw_chart": self.redraw_chart,
|
|
227
|
+
"auto_fix_sql": self.auto_fix_sql,
|
|
228
|
+
"ask_results_correct": self.ask_results_correct,
|
|
229
|
+
"followup_questions": self.followup_questions,
|
|
230
|
+
"summarization": self.summarization,
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
config = self.auth.override_config_for_user(user, config)
|
|
234
|
+
|
|
165
235
|
return jsonify(
|
|
166
236
|
{
|
|
167
237
|
"type": "config",
|
|
168
|
-
"config":
|
|
169
|
-
"logo": self.logo,
|
|
170
|
-
"title": self.title,
|
|
171
|
-
"subtitle": self.subtitle,
|
|
172
|
-
"show_training_data": self.show_training_data,
|
|
173
|
-
"suggested_questions": self.suggested_questions,
|
|
174
|
-
"sql": self.sql,
|
|
175
|
-
"table": self.table,
|
|
176
|
-
"csv_download": self.csv_download,
|
|
177
|
-
"chart": self.chart,
|
|
178
|
-
"redraw_chart": self.redraw_chart,
|
|
179
|
-
"auto_fix_sql": self.auto_fix_sql,
|
|
180
|
-
"ask_results_correct": self.ask_results_correct,
|
|
181
|
-
"followup_questions": self.followup_questions,
|
|
182
|
-
"summarization": self.summarization,
|
|
183
|
-
},
|
|
238
|
+
"config": config
|
|
184
239
|
}
|
|
185
240
|
)
|
|
186
241
|
|
|
187
242
|
@self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
|
|
188
|
-
|
|
243
|
+
@self.requires_auth
|
|
244
|
+
def generate_questions(user: any):
|
|
189
245
|
# If self has an _model attribute and model=='chinook'
|
|
190
246
|
if hasattr(self.vn, "_model") and self.vn._model == "chinook":
|
|
191
247
|
return jsonify(
|
|
@@ -240,7 +296,8 @@ class VannaFlaskApp:
|
|
|
240
296
|
)
|
|
241
297
|
|
|
242
298
|
@self.flask_app.route("/api/v0/generate_sql", methods=["GET"])
|
|
243
|
-
|
|
299
|
+
@self.requires_auth
|
|
300
|
+
def generate_sql(user: any):
|
|
244
301
|
question = flask.request.args.get("question")
|
|
245
302
|
|
|
246
303
|
if question is None:
|
|
@@ -261,8 +318,9 @@ class VannaFlaskApp:
|
|
|
261
318
|
)
|
|
262
319
|
|
|
263
320
|
@self.flask_app.route("/api/v0/run_sql", methods=["GET"])
|
|
321
|
+
@self.requires_auth
|
|
264
322
|
@self.requires_cache(["sql"])
|
|
265
|
-
def run_sql(id: str, sql: str):
|
|
323
|
+
def run_sql(user: any, id: str, sql: str):
|
|
266
324
|
try:
|
|
267
325
|
if not vn.run_sql_is_set:
|
|
268
326
|
return jsonify(
|
|
@@ -274,7 +332,7 @@ class VannaFlaskApp:
|
|
|
274
332
|
|
|
275
333
|
df = vn.run_sql(sql=sql)
|
|
276
334
|
|
|
277
|
-
cache.set(id=id, field="df", value=df)
|
|
335
|
+
self.cache.set(id=id, field="df", value=df)
|
|
278
336
|
|
|
279
337
|
return jsonify(
|
|
280
338
|
{
|
|
@@ -288,8 +346,9 @@ class VannaFlaskApp:
|
|
|
288
346
|
return jsonify({"type": "sql_error", "error": str(e)})
|
|
289
347
|
|
|
290
348
|
@self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
|
|
349
|
+
@self.requires_auth
|
|
291
350
|
@self.requires_cache(["question", "sql"])
|
|
292
|
-
def fix_sql(id: str, question:str, sql: str):
|
|
351
|
+
def fix_sql(user: any, id: str, question:str, sql: str):
|
|
293
352
|
error = flask.request.json.get("error")
|
|
294
353
|
|
|
295
354
|
if error is None:
|
|
@@ -311,14 +370,15 @@ class VannaFlaskApp:
|
|
|
311
370
|
|
|
312
371
|
|
|
313
372
|
@self.flask_app.route('/api/v0/update_sql', methods=['POST'])
|
|
373
|
+
@self.requires_auth
|
|
314
374
|
@self.requires_cache([])
|
|
315
|
-
def update_sql(id: str):
|
|
375
|
+
def update_sql(user: any, id: str):
|
|
316
376
|
sql = flask.request.json.get('sql')
|
|
317
377
|
|
|
318
378
|
if sql is None:
|
|
319
379
|
return jsonify({"type": "error", "error": "No sql provided"})
|
|
320
380
|
|
|
321
|
-
cache.set(id=id, field='sql', value=sql)
|
|
381
|
+
self.cache.set(id=id, field='sql', value=sql)
|
|
322
382
|
|
|
323
383
|
return jsonify(
|
|
324
384
|
{
|
|
@@ -328,8 +388,9 @@ class VannaFlaskApp:
|
|
|
328
388
|
})
|
|
329
389
|
|
|
330
390
|
@self.flask_app.route("/api/v0/download_csv", methods=["GET"])
|
|
391
|
+
@self.requires_auth
|
|
331
392
|
@self.requires_cache(["df"])
|
|
332
|
-
def download_csv(id: str, df):
|
|
393
|
+
def download_csv(user: any, id: str, df):
|
|
333
394
|
csv = df.to_csv()
|
|
334
395
|
|
|
335
396
|
return Response(
|
|
@@ -339,8 +400,9 @@ class VannaFlaskApp:
|
|
|
339
400
|
)
|
|
340
401
|
|
|
341
402
|
@self.flask_app.route("/api/v0/generate_plotly_figure", methods=["GET"])
|
|
403
|
+
@self.requires_auth
|
|
342
404
|
@self.requires_cache(["df", "question", "sql"])
|
|
343
|
-
def generate_plotly_figure(id: str, df, question, sql):
|
|
405
|
+
def generate_plotly_figure(user: any, id: str, df, question, sql):
|
|
344
406
|
chart_instructions = flask.request.args.get('chart_instructions')
|
|
345
407
|
|
|
346
408
|
if chart_instructions is not None:
|
|
@@ -355,7 +417,7 @@ class VannaFlaskApp:
|
|
|
355
417
|
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
|
|
356
418
|
fig_json = fig.to_json()
|
|
357
419
|
|
|
358
|
-
cache.set(id=id, field="fig_json", value=fig_json)
|
|
420
|
+
self.cache.set(id=id, field="fig_json", value=fig_json)
|
|
359
421
|
|
|
360
422
|
return jsonify(
|
|
361
423
|
{
|
|
@@ -373,7 +435,8 @@ class VannaFlaskApp:
|
|
|
373
435
|
return jsonify({"type": "error", "error": str(e)})
|
|
374
436
|
|
|
375
437
|
@self.flask_app.route("/api/v0/get_training_data", methods=["GET"])
|
|
376
|
-
|
|
438
|
+
@self.requires_auth
|
|
439
|
+
def get_training_data(user: any):
|
|
377
440
|
df = vn.get_training_data()
|
|
378
441
|
|
|
379
442
|
if df is None or len(df) == 0:
|
|
@@ -393,7 +456,8 @@ class VannaFlaskApp:
|
|
|
393
456
|
)
|
|
394
457
|
|
|
395
458
|
@self.flask_app.route("/api/v0/remove_training_data", methods=["POST"])
|
|
396
|
-
|
|
459
|
+
@self.requires_auth
|
|
460
|
+
def remove_training_data(user: any):
|
|
397
461
|
# Get id from the JSON body
|
|
398
462
|
id = flask.request.json.get("id")
|
|
399
463
|
|
|
@@ -408,7 +472,8 @@ class VannaFlaskApp:
|
|
|
408
472
|
)
|
|
409
473
|
|
|
410
474
|
@self.flask_app.route("/api/v0/train", methods=["POST"])
|
|
411
|
-
|
|
475
|
+
@self.requires_auth
|
|
476
|
+
def add_training_data(user: any):
|
|
412
477
|
question = flask.request.json.get("question")
|
|
413
478
|
sql = flask.request.json.get("sql")
|
|
414
479
|
ddl = flask.request.json.get("ddl")
|
|
@@ -425,8 +490,9 @@ class VannaFlaskApp:
|
|
|
425
490
|
return jsonify({"type": "error", "error": str(e)})
|
|
426
491
|
|
|
427
492
|
@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
|
|
493
|
+
@self.requires_auth
|
|
428
494
|
@self.requires_cache(["df", "question", "sql"])
|
|
429
|
-
def generate_followup_questions(id: str, df, question, sql):
|
|
495
|
+
def generate_followup_questions(user: any, id: str, df, question, sql):
|
|
430
496
|
if self.allow_llm_to_see_data:
|
|
431
497
|
followup_questions = vn.generate_followup_questions(
|
|
432
498
|
question=question, sql=sql, df=df
|
|
@@ -434,7 +500,7 @@ class VannaFlaskApp:
|
|
|
434
500
|
if followup_questions is not None and len(followup_questions) > 5:
|
|
435
501
|
followup_questions = followup_questions[:5]
|
|
436
502
|
|
|
437
|
-
cache.set(id=id, field="followup_questions", value=followup_questions)
|
|
503
|
+
self.cache.set(id=id, field="followup_questions", value=followup_questions)
|
|
438
504
|
|
|
439
505
|
return jsonify(
|
|
440
506
|
{
|
|
@@ -445,7 +511,7 @@ class VannaFlaskApp:
|
|
|
445
511
|
}
|
|
446
512
|
)
|
|
447
513
|
else:
|
|
448
|
-
cache.set(id=id, field="followup_questions", value=[])
|
|
514
|
+
self.cache.set(id=id, field="followup_questions", value=[])
|
|
449
515
|
return jsonify(
|
|
450
516
|
{
|
|
451
517
|
"type": "question_list",
|
|
@@ -456,10 +522,14 @@ class VannaFlaskApp:
|
|
|
456
522
|
)
|
|
457
523
|
|
|
458
524
|
@self.flask_app.route("/api/v0/generate_summary", methods=["GET"])
|
|
525
|
+
@self.requires_auth
|
|
459
526
|
@self.requires_cache(["df", "question"])
|
|
460
|
-
def generate_summary(id: str, df, question):
|
|
527
|
+
def generate_summary(user: any, id: str, df, question):
|
|
461
528
|
if self.allow_llm_to_see_data:
|
|
462
529
|
summary = vn.generate_summary(question=question, df=df)
|
|
530
|
+
|
|
531
|
+
self.cache.set(id=id, field="summary", value=summary)
|
|
532
|
+
|
|
463
533
|
return jsonify(
|
|
464
534
|
{
|
|
465
535
|
"type": "text",
|
|
@@ -477,10 +547,12 @@ class VannaFlaskApp:
|
|
|
477
547
|
)
|
|
478
548
|
|
|
479
549
|
@self.flask_app.route("/api/v0/load_question", methods=["GET"])
|
|
550
|
+
@self.requires_auth
|
|
480
551
|
@self.requires_cache(
|
|
481
|
-
["question", "sql", "df", "fig_json"]
|
|
552
|
+
["question", "sql", "df", "fig_json"],
|
|
553
|
+
optional_fields=["summary"]
|
|
482
554
|
)
|
|
483
|
-
def load_question(id: str, question, sql, df, fig_json):
|
|
555
|
+
def load_question(user: any, id: str, question, sql, df, fig_json, summary):
|
|
484
556
|
try:
|
|
485
557
|
return jsonify(
|
|
486
558
|
{
|
|
@@ -488,8 +560,9 @@ class VannaFlaskApp:
|
|
|
488
560
|
"id": id,
|
|
489
561
|
"question": question,
|
|
490
562
|
"sql": sql,
|
|
491
|
-
"df": df.head(10).to_json(orient="records"),
|
|
563
|
+
"df": df.head(10).to_json(orient="records", date_format="iso"),
|
|
492
564
|
"fig": fig_json,
|
|
565
|
+
"summary": summary,
|
|
493
566
|
}
|
|
494
567
|
)
|
|
495
568
|
|
|
@@ -497,7 +570,8 @@ class VannaFlaskApp:
|
|
|
497
570
|
return jsonify({"type": "error", "error": str(e)})
|
|
498
571
|
|
|
499
572
|
@self.flask_app.route("/api/v0/get_question_history", methods=["GET"])
|
|
500
|
-
|
|
573
|
+
@self.requires_auth
|
|
574
|
+
def get_question_history(user: any):
|
|
501
575
|
return jsonify(
|
|
502
576
|
{
|
|
503
577
|
"type": "question_history",
|
|
@@ -525,7 +599,7 @@ class VannaFlaskApp:
|
|
|
525
599
|
# Proxy the /vanna.svg file to the remote server
|
|
526
600
|
@self.flask_app.route("/vanna.svg")
|
|
527
601
|
def proxy_vanna_svg():
|
|
528
|
-
remote_url =
|
|
602
|
+
remote_url = "https://vanna.ai/img/vanna.svg"
|
|
529
603
|
response = requests.get(remote_url, stream=True)
|
|
530
604
|
|
|
531
605
|
# Check if the request to the remote URL was successful
|