vanna 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/anthropic/anthropic_chat.py +9 -9
- vanna/base/base.py +82 -53
- vanna/flask/__init__.py +603 -133
- vanna/openai/openai_chat.py +0 -9
- {vanna-0.6.3.dist-info → vanna-0.6.4.dist-info}/METADATA +2 -1
- {vanna-0.6.3.dist-info → vanna-0.6.4.dist-info}/RECORD +7 -7
- {vanna-0.6.3.dist-info → vanna-0.6.4.dist-info}/WHEEL +0 -0
|
@@ -8,15 +8,7 @@ from ..base import VannaBase
|
|
|
8
8
|
class Anthropic_Chat(VannaBase):
|
|
9
9
|
def __init__(self, client=None, config=None):
|
|
10
10
|
VannaBase.__init__(self, config=config)
|
|
11
|
-
|
|
12
|
-
if client is not None:
|
|
13
|
-
self.client = client
|
|
14
|
-
return
|
|
15
|
-
|
|
16
|
-
if config is None and client is None:
|
|
17
|
-
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
|
18
|
-
return
|
|
19
|
-
|
|
11
|
+
|
|
20
12
|
# default parameters - can be overrided using config
|
|
21
13
|
self.temperature = 0.7
|
|
22
14
|
self.max_tokens = 500
|
|
@@ -27,6 +19,14 @@ class Anthropic_Chat(VannaBase):
|
|
|
27
19
|
if "max_tokens" in config:
|
|
28
20
|
self.max_tokens = config["max_tokens"]
|
|
29
21
|
|
|
22
|
+
if client is not None:
|
|
23
|
+
self.client = client
|
|
24
|
+
return
|
|
25
|
+
|
|
26
|
+
if config is None and client is None:
|
|
27
|
+
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
|
28
|
+
return
|
|
29
|
+
|
|
30
30
|
if "api_key" in config:
|
|
31
31
|
self.client = anthropic.Anthropic(api_key=config["api_key"])
|
|
32
32
|
|
vanna/base/base.py
CHANGED
|
@@ -577,6 +577,7 @@ class VannaBase(ABC):
|
|
|
577
577
|
"3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
|
578
578
|
"4. Please use the most relevant table(s). \n"
|
|
579
579
|
"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
|
580
|
+
f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
|
580
581
|
)
|
|
581
582
|
|
|
582
583
|
message_log = [self.system_message(initial_prompt)]
|
|
@@ -718,6 +719,7 @@ class VannaBase(ABC):
|
|
|
718
719
|
database: str,
|
|
719
720
|
role: Union[str, None] = None,
|
|
720
721
|
warehouse: Union[str, None] = None,
|
|
722
|
+
**kwargs
|
|
721
723
|
):
|
|
722
724
|
try:
|
|
723
725
|
snowflake = __import__("snowflake.connector")
|
|
@@ -735,7 +737,7 @@ class VannaBase(ABC):
|
|
|
735
737
|
else:
|
|
736
738
|
raise ImproperlyConfigured("Please set your Snowflake username.")
|
|
737
739
|
|
|
738
|
-
if password == "
|
|
740
|
+
if password == "mypassword":
|
|
739
741
|
password_env = os.getenv("SNOWFLAKE_PASSWORD")
|
|
740
742
|
|
|
741
743
|
if password_env is not None:
|
|
@@ -764,7 +766,8 @@ class VannaBase(ABC):
|
|
|
764
766
|
password=password,
|
|
765
767
|
account=account,
|
|
766
768
|
database=database,
|
|
767
|
-
client_session_keep_alive=True
|
|
769
|
+
client_session_keep_alive=True,
|
|
770
|
+
**kwargs
|
|
768
771
|
)
|
|
769
772
|
|
|
770
773
|
def run_sql_snowflake(sql: str) -> pd.DataFrame:
|
|
@@ -790,13 +793,13 @@ class VannaBase(ABC):
|
|
|
790
793
|
self.run_sql = run_sql_snowflake
|
|
791
794
|
self.run_sql_is_set = True
|
|
792
795
|
|
|
793
|
-
def connect_to_sqlite(self, url: str):
|
|
796
|
+
def connect_to_sqlite(self, url: str, check_same_thread: bool = False, **kwargs):
|
|
794
797
|
"""
|
|
795
798
|
Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
796
799
|
|
|
797
800
|
Args:
|
|
798
801
|
url (str): The URL of the database to connect to.
|
|
799
|
-
|
|
802
|
+
check_same_thread (str): Allow the connection may be accessed in multiple threads.
|
|
800
803
|
Returns:
|
|
801
804
|
None
|
|
802
805
|
"""
|
|
@@ -815,7 +818,11 @@ class VannaBase(ABC):
|
|
|
815
818
|
url = path
|
|
816
819
|
|
|
817
820
|
# Connect to the database
|
|
818
|
-
conn = sqlite3.connect(
|
|
821
|
+
conn = sqlite3.connect(
|
|
822
|
+
url,
|
|
823
|
+
check_same_thread=check_same_thread,
|
|
824
|
+
**kwargs
|
|
825
|
+
)
|
|
819
826
|
|
|
820
827
|
def run_sql_sqlite(sql: str):
|
|
821
828
|
return pd.read_sql_query(sql, conn)
|
|
@@ -831,6 +838,7 @@ class VannaBase(ABC):
|
|
|
831
838
|
user: str = None,
|
|
832
839
|
password: str = None,
|
|
833
840
|
port: int = None,
|
|
841
|
+
**kwargs
|
|
834
842
|
):
|
|
835
843
|
"""
|
|
836
844
|
Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
@@ -900,6 +908,7 @@ class VannaBase(ABC):
|
|
|
900
908
|
user=user,
|
|
901
909
|
password=password,
|
|
902
910
|
port=port,
|
|
911
|
+
**kwargs
|
|
903
912
|
)
|
|
904
913
|
except psycopg2.Error as e:
|
|
905
914
|
raise ValidationError(e)
|
|
@@ -931,12 +940,13 @@ class VannaBase(ABC):
|
|
|
931
940
|
|
|
932
941
|
|
|
933
942
|
def connect_to_mysql(
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
943
|
+
self,
|
|
944
|
+
host: str = None,
|
|
945
|
+
dbname: str = None,
|
|
946
|
+
user: str = None,
|
|
947
|
+
password: str = None,
|
|
948
|
+
port: int = None,
|
|
949
|
+
**kwargs
|
|
940
950
|
):
|
|
941
951
|
|
|
942
952
|
try:
|
|
@@ -980,12 +990,15 @@ class VannaBase(ABC):
|
|
|
980
990
|
conn = None
|
|
981
991
|
|
|
982
992
|
try:
|
|
983
|
-
conn = pymysql.connect(
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
993
|
+
conn = pymysql.connect(
|
|
994
|
+
host=host,
|
|
995
|
+
user=user,
|
|
996
|
+
password=password,
|
|
997
|
+
database=dbname,
|
|
998
|
+
port=port,
|
|
999
|
+
cursorclass=pymysql.cursors.DictCursor,
|
|
1000
|
+
**kwargs
|
|
1001
|
+
)
|
|
989
1002
|
except pymysql.Error as e:
|
|
990
1003
|
raise ValidationError(e)
|
|
991
1004
|
|
|
@@ -1015,12 +1028,13 @@ class VannaBase(ABC):
|
|
|
1015
1028
|
self.run_sql = run_sql_mysql
|
|
1016
1029
|
|
|
1017
1030
|
def connect_to_clickhouse(
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1031
|
+
self,
|
|
1032
|
+
host: str = None,
|
|
1033
|
+
dbname: str = None,
|
|
1034
|
+
user: str = None,
|
|
1035
|
+
password: str = None,
|
|
1036
|
+
port: int = None,
|
|
1037
|
+
**kwargs
|
|
1024
1038
|
):
|
|
1025
1039
|
|
|
1026
1040
|
try:
|
|
@@ -1070,6 +1084,7 @@ class VannaBase(ABC):
|
|
|
1070
1084
|
username=user,
|
|
1071
1085
|
password=password,
|
|
1072
1086
|
database=dbname,
|
|
1087
|
+
**kwargs
|
|
1073
1088
|
)
|
|
1074
1089
|
print(conn)
|
|
1075
1090
|
except Exception as e:
|
|
@@ -1087,15 +1102,16 @@ class VannaBase(ABC):
|
|
|
1087
1102
|
|
|
1088
1103
|
except Exception as e:
|
|
1089
1104
|
raise e
|
|
1090
|
-
|
|
1105
|
+
|
|
1091
1106
|
self.run_sql_is_set = True
|
|
1092
1107
|
self.run_sql = run_sql_clickhouse
|
|
1093
1108
|
|
|
1094
1109
|
def connect_to_oracle(
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1110
|
+
self,
|
|
1111
|
+
user: str = None,
|
|
1112
|
+
password: str = None,
|
|
1113
|
+
dsn: str = None,
|
|
1114
|
+
**kwargs
|
|
1099
1115
|
):
|
|
1100
1116
|
|
|
1101
1117
|
"""
|
|
@@ -1148,7 +1164,8 @@ class VannaBase(ABC):
|
|
|
1148
1164
|
user=user,
|
|
1149
1165
|
password=password,
|
|
1150
1166
|
dsn=dsn,
|
|
1151
|
-
|
|
1167
|
+
**kwargs
|
|
1168
|
+
)
|
|
1152
1169
|
except oracledb.Error as e:
|
|
1153
1170
|
raise ValidationError(e)
|
|
1154
1171
|
|
|
@@ -1180,7 +1197,12 @@ class VannaBase(ABC):
|
|
|
1180
1197
|
self.run_sql_is_set = True
|
|
1181
1198
|
self.run_sql = run_sql_oracle
|
|
1182
1199
|
|
|
1183
|
-
def connect_to_bigquery(
|
|
1200
|
+
def connect_to_bigquery(
|
|
1201
|
+
self,
|
|
1202
|
+
cred_file_path: str = None,
|
|
1203
|
+
project_id: str = None,
|
|
1204
|
+
**kwargs
|
|
1205
|
+
):
|
|
1184
1206
|
"""
|
|
1185
1207
|
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1186
1208
|
**Example:**
|
|
@@ -1242,7 +1264,11 @@ class VannaBase(ABC):
|
|
|
1242
1264
|
)
|
|
1243
1265
|
|
|
1244
1266
|
try:
|
|
1245
|
-
conn = bigquery.Client(
|
|
1267
|
+
conn = bigquery.Client(
|
|
1268
|
+
project=project_id,
|
|
1269
|
+
credentials=credentials,
|
|
1270
|
+
**kwargs
|
|
1271
|
+
)
|
|
1246
1272
|
except:
|
|
1247
1273
|
raise ImproperlyConfigured(
|
|
1248
1274
|
"Could not connect to bigquery please correct credentials"
|
|
@@ -1265,7 +1291,7 @@ class VannaBase(ABC):
|
|
|
1265
1291
|
self.run_sql_is_set = True
|
|
1266
1292
|
self.run_sql = run_sql_bigquery
|
|
1267
1293
|
|
|
1268
|
-
def connect_to_duckdb(self, url: str, init_sql: str = None):
|
|
1294
|
+
def connect_to_duckdb(self, url: str, init_sql: str = None, **kwargs):
|
|
1269
1295
|
"""
|
|
1270
1296
|
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1271
1297
|
|
|
@@ -1303,7 +1329,7 @@ class VannaBase(ABC):
|
|
|
1303
1329
|
f.write(response.content)
|
|
1304
1330
|
|
|
1305
1331
|
# Connect to the database
|
|
1306
|
-
conn = duckdb.connect(path)
|
|
1332
|
+
conn = duckdb.connect(path, **kwargs)
|
|
1307
1333
|
if init_sql:
|
|
1308
1334
|
conn.query(init_sql)
|
|
1309
1335
|
|
|
@@ -1314,7 +1340,7 @@ class VannaBase(ABC):
|
|
|
1314
1340
|
self.run_sql = run_sql_duckdb
|
|
1315
1341
|
self.run_sql_is_set = True
|
|
1316
1342
|
|
|
1317
|
-
def connect_to_mssql(self, odbc_conn_str: str):
|
|
1343
|
+
def connect_to_mssql(self, odbc_conn_str: str, **kwargs):
|
|
1318
1344
|
"""
|
|
1319
1345
|
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
1320
1346
|
|
|
@@ -1347,7 +1373,7 @@ class VannaBase(ABC):
|
|
|
1347
1373
|
|
|
1348
1374
|
from sqlalchemy import create_engine
|
|
1349
1375
|
|
|
1350
|
-
engine = create_engine(connection_url)
|
|
1376
|
+
engine = create_engine(connection_url, **kwargs)
|
|
1351
1377
|
|
|
1352
1378
|
def run_sql_mssql(sql: str):
|
|
1353
1379
|
# Execute the SQL statement and return the result as a pandas DataFrame
|
|
@@ -1362,16 +1388,17 @@ class VannaBase(ABC):
|
|
|
1362
1388
|
self.run_sql = run_sql_mssql
|
|
1363
1389
|
self.run_sql_is_set = True
|
|
1364
1390
|
def connect_to_presto(
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1391
|
+
self,
|
|
1392
|
+
host: str,
|
|
1393
|
+
catalog: str = 'hive',
|
|
1394
|
+
schema: str = 'default',
|
|
1395
|
+
user: str = None,
|
|
1396
|
+
password: str = None,
|
|
1397
|
+
port: int = None,
|
|
1398
|
+
combined_pem_path: str = None,
|
|
1399
|
+
protocol: str = 'https',
|
|
1400
|
+
requests_kwargs: dict = None,
|
|
1401
|
+
**kwargs
|
|
1375
1402
|
):
|
|
1376
1403
|
"""
|
|
1377
1404
|
Connect to a Presto database using the specified parameters.
|
|
@@ -1444,7 +1471,8 @@ class VannaBase(ABC):
|
|
|
1444
1471
|
schema=schema,
|
|
1445
1472
|
port=port,
|
|
1446
1473
|
protocol=protocol,
|
|
1447
|
-
requests_kwargs=requests_kwargs
|
|
1474
|
+
requests_kwargs=requests_kwargs,
|
|
1475
|
+
**kwargs)
|
|
1448
1476
|
except presto.Error as e:
|
|
1449
1477
|
raise ValidationError(e)
|
|
1450
1478
|
|
|
@@ -1477,13 +1505,14 @@ class VannaBase(ABC):
|
|
|
1477
1505
|
self.run_sql = run_sql_presto
|
|
1478
1506
|
|
|
1479
1507
|
def connect_to_hive(
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1508
|
+
self,
|
|
1509
|
+
host: str = None,
|
|
1510
|
+
dbname: str = 'default',
|
|
1511
|
+
user: str = None,
|
|
1512
|
+
password: str = None,
|
|
1513
|
+
port: int = None,
|
|
1514
|
+
auth: str = 'CUSTOM',
|
|
1515
|
+
**kwargs
|
|
1487
1516
|
):
|
|
1488
1517
|
"""
|
|
1489
1518
|
Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
vanna/flask/__init__.py
CHANGED
|
@@ -8,11 +8,13 @@ from functools import wraps
|
|
|
8
8
|
|
|
9
9
|
import flask
|
|
10
10
|
import requests
|
|
11
|
+
from flasgger import Swagger
|
|
11
12
|
from flask import Flask, Response, jsonify, request, send_from_directory
|
|
12
13
|
from flask_sock import Sock
|
|
13
14
|
|
|
14
15
|
from .assets import css_content, html_content, js_content
|
|
15
16
|
from .auth import AuthInterface, NoAuth
|
|
17
|
+
from ..base import VannaBase
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class Cache(ABC):
|
|
@@ -88,7 +90,8 @@ class MemoryCache(Cache):
|
|
|
88
90
|
if id in self.cache:
|
|
89
91
|
del self.cache[id]
|
|
90
92
|
|
|
91
|
-
|
|
93
|
+
|
|
94
|
+
class VannaFlaskAPI:
|
|
92
95
|
flask_app = None
|
|
93
96
|
|
|
94
97
|
def requires_cache(self, required_fields, optional_fields=[]):
|
|
@@ -135,30 +138,17 @@ class VannaFlaskApp:
|
|
|
135
138
|
|
|
136
139
|
return decorated
|
|
137
140
|
|
|
138
|
-
def __init__(
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
sql=True,
|
|
148
|
-
table=True,
|
|
149
|
-
csv_download=True,
|
|
150
|
-
chart=True,
|
|
151
|
-
redraw_chart=True,
|
|
152
|
-
auto_fix_sql=True,
|
|
153
|
-
ask_results_correct=True,
|
|
154
|
-
followup_questions=True,
|
|
155
|
-
summarization=True,
|
|
156
|
-
function_generation=True,
|
|
157
|
-
index_html_path=None,
|
|
158
|
-
assets_folder=None,
|
|
159
|
-
):
|
|
141
|
+
def __init__(
|
|
142
|
+
self,
|
|
143
|
+
vn: VannaBase,
|
|
144
|
+
cache: Cache = MemoryCache(),
|
|
145
|
+
auth: AuthInterface = NoAuth(),
|
|
146
|
+
debug=True,
|
|
147
|
+
allow_llm_to_see_data=False,
|
|
148
|
+
chart=True,
|
|
149
|
+
):
|
|
160
150
|
"""
|
|
161
|
-
Expose a Flask
|
|
151
|
+
Expose a Flask API that can be used to interact with a Vanna instance.
|
|
162
152
|
|
|
163
153
|
Args:
|
|
164
154
|
vn: The Vanna instance to interact with.
|
|
@@ -166,52 +156,30 @@ class VannaFlaskApp:
|
|
|
166
156
|
auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
|
|
167
157
|
debug: Show the debug console. Defaults to True.
|
|
168
158
|
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
|
|
169
|
-
logo: The logo to display in the UI. Defaults to the Vanna logo.
|
|
170
|
-
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
|
|
171
|
-
subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
|
|
172
|
-
show_training_data: Whether to show the training data in the UI. Defaults to True.
|
|
173
|
-
suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
|
|
174
|
-
sql: Whether to show the SQL input in the UI. Defaults to True.
|
|
175
|
-
table: Whether to show the table output in the UI. Defaults to True.
|
|
176
|
-
csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
|
|
177
159
|
chart: Whether to show the chart output in the UI. Defaults to True.
|
|
178
|
-
redraw_chart: Whether to allow redrawing the chart. Defaults to True.
|
|
179
|
-
auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
|
|
180
|
-
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
|
|
181
|
-
followup_questions: Whether to show followup questions. Defaults to True.
|
|
182
|
-
summarization: Whether to show summarization. Defaults to True.
|
|
183
|
-
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
|
|
184
|
-
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
|
|
185
160
|
|
|
186
161
|
Returns:
|
|
187
162
|
None
|
|
188
163
|
"""
|
|
164
|
+
|
|
189
165
|
self.flask_app = Flask(__name__)
|
|
166
|
+
|
|
167
|
+
self.swagger = Swagger(
|
|
168
|
+
self.flask_app, template={"info": {"title": "Vanna API"}}
|
|
169
|
+
)
|
|
190
170
|
self.sock = Sock(self.flask_app)
|
|
191
171
|
self.ws_clients = []
|
|
192
172
|
self.vn = vn
|
|
193
|
-
self.debug = debug
|
|
194
173
|
self.auth = auth
|
|
195
174
|
self.cache = cache
|
|
175
|
+
self.debug = debug
|
|
196
176
|
self.allow_llm_to_see_data = allow_llm_to_see_data
|
|
197
|
-
self.logo = logo
|
|
198
|
-
self.title = title
|
|
199
|
-
self.subtitle = subtitle
|
|
200
|
-
self.show_training_data = show_training_data
|
|
201
|
-
self.suggested_questions = suggested_questions
|
|
202
|
-
self.sql = sql
|
|
203
|
-
self.table = table
|
|
204
|
-
self.csv_download = csv_download
|
|
205
177
|
self.chart = chart
|
|
206
|
-
self.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
self.function_generation = function_generation and hasattr(vn, "get_function")
|
|
212
|
-
self.index_html_path = index_html_path
|
|
213
|
-
self.assets_folder = assets_folder
|
|
214
|
-
|
|
178
|
+
self.config = {
|
|
179
|
+
"debug": debug,
|
|
180
|
+
"allow_llm_to_see_data": allow_llm_to_see_data,
|
|
181
|
+
"chart": chart,
|
|
182
|
+
}
|
|
215
183
|
log = logging.getLogger("werkzeug")
|
|
216
184
|
log.setLevel(logging.ERROR)
|
|
217
185
|
|
|
@@ -225,42 +193,27 @@ class VannaFlaskApp:
|
|
|
225
193
|
|
|
226
194
|
self.vn.log = log
|
|
227
195
|
|
|
228
|
-
@self.flask_app.route("/auth/login", methods=["POST"])
|
|
229
|
-
def login():
|
|
230
|
-
return self.auth.login_handler(flask.request)
|
|
231
|
-
|
|
232
|
-
@self.flask_app.route("/auth/callback", methods=["GET"])
|
|
233
|
-
def callback():
|
|
234
|
-
return self.auth.callback_handler(flask.request)
|
|
235
|
-
|
|
236
|
-
@self.flask_app.route("/auth/logout", methods=["GET"])
|
|
237
|
-
def logout():
|
|
238
|
-
return self.auth.logout_handler(flask.request)
|
|
239
|
-
|
|
240
196
|
@self.flask_app.route("/api/v0/get_config", methods=["GET"])
|
|
241
197
|
@self.requires_auth
|
|
242
198
|
def get_config(user: any):
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
config = self.auth.override_config_for_user(user, config)
|
|
263
|
-
|
|
199
|
+
"""
|
|
200
|
+
Get the configuration for a user
|
|
201
|
+
---
|
|
202
|
+
parameters:
|
|
203
|
+
- name: user
|
|
204
|
+
in: query
|
|
205
|
+
responses:
|
|
206
|
+
200:
|
|
207
|
+
schema:
|
|
208
|
+
type: object
|
|
209
|
+
properties:
|
|
210
|
+
type:
|
|
211
|
+
type: string
|
|
212
|
+
default: config
|
|
213
|
+
config:
|
|
214
|
+
type: object
|
|
215
|
+
"""
|
|
216
|
+
config = self.auth.override_config_for_user(user, self.config)
|
|
264
217
|
return jsonify(
|
|
265
218
|
{
|
|
266
219
|
"type": "config",
|
|
@@ -271,6 +224,28 @@ class VannaFlaskApp:
|
|
|
271
224
|
@self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
|
|
272
225
|
@self.requires_auth
|
|
273
226
|
def generate_questions(user: any):
|
|
227
|
+
"""
|
|
228
|
+
Generate questions
|
|
229
|
+
---
|
|
230
|
+
parameters:
|
|
231
|
+
- name: user
|
|
232
|
+
in: query
|
|
233
|
+
responses:
|
|
234
|
+
200:
|
|
235
|
+
schema:
|
|
236
|
+
type: object
|
|
237
|
+
properties:
|
|
238
|
+
type:
|
|
239
|
+
type: string
|
|
240
|
+
default: question_list
|
|
241
|
+
questions:
|
|
242
|
+
type: array
|
|
243
|
+
items:
|
|
244
|
+
type: string
|
|
245
|
+
header:
|
|
246
|
+
type: string
|
|
247
|
+
default: Here are some questions you can ask
|
|
248
|
+
"""
|
|
274
249
|
# If self has an _model attribute and model=='chinook'
|
|
275
250
|
if hasattr(self.vn, "_model") and self.vn._model == "chinook":
|
|
276
251
|
return jsonify(
|
|
@@ -327,6 +302,29 @@ class VannaFlaskApp:
|
|
|
327
302
|
@self.flask_app.route("/api/v0/generate_sql", methods=["GET"])
|
|
328
303
|
@self.requires_auth
|
|
329
304
|
def generate_sql(user: any):
|
|
305
|
+
"""
|
|
306
|
+
Generate SQL from a question
|
|
307
|
+
---
|
|
308
|
+
parameters:
|
|
309
|
+
- name: user
|
|
310
|
+
in: query
|
|
311
|
+
- name: question
|
|
312
|
+
in: query
|
|
313
|
+
type: string
|
|
314
|
+
required: true
|
|
315
|
+
responses:
|
|
316
|
+
200:
|
|
317
|
+
schema:
|
|
318
|
+
type: object
|
|
319
|
+
properties:
|
|
320
|
+
type:
|
|
321
|
+
type: string
|
|
322
|
+
default: sql
|
|
323
|
+
id:
|
|
324
|
+
type: string
|
|
325
|
+
text:
|
|
326
|
+
type: string
|
|
327
|
+
"""
|
|
330
328
|
question = flask.request.args.get("question")
|
|
331
329
|
|
|
332
330
|
if question is None:
|
|
@@ -358,6 +356,29 @@ class VannaFlaskApp:
|
|
|
358
356
|
@self.flask_app.route("/api/v0/get_function", methods=["GET"])
|
|
359
357
|
@self.requires_auth
|
|
360
358
|
def get_function(user: any):
|
|
359
|
+
"""
|
|
360
|
+
Get a function from a question
|
|
361
|
+
---
|
|
362
|
+
parameters:
|
|
363
|
+
- name: user
|
|
364
|
+
in: query
|
|
365
|
+
- name: question
|
|
366
|
+
in: query
|
|
367
|
+
type: string
|
|
368
|
+
required: true
|
|
369
|
+
responses:
|
|
370
|
+
200:
|
|
371
|
+
schema:
|
|
372
|
+
type: object
|
|
373
|
+
properties:
|
|
374
|
+
type:
|
|
375
|
+
type: string
|
|
376
|
+
default: function
|
|
377
|
+
id:
|
|
378
|
+
type: object
|
|
379
|
+
function:
|
|
380
|
+
type: string
|
|
381
|
+
"""
|
|
361
382
|
question = flask.request.args.get("question")
|
|
362
383
|
|
|
363
384
|
if question is None:
|
|
@@ -393,6 +414,23 @@ class VannaFlaskApp:
|
|
|
393
414
|
@self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
|
|
394
415
|
@self.requires_auth
|
|
395
416
|
def get_all_functions(user: any):
|
|
417
|
+
"""
|
|
418
|
+
Get all the functions
|
|
419
|
+
---
|
|
420
|
+
parameters:
|
|
421
|
+
- name: user
|
|
422
|
+
in: query
|
|
423
|
+
responses:
|
|
424
|
+
200:
|
|
425
|
+
schema:
|
|
426
|
+
type: object
|
|
427
|
+
properties:
|
|
428
|
+
type:
|
|
429
|
+
type: string
|
|
430
|
+
default: functions
|
|
431
|
+
functions:
|
|
432
|
+
type: array
|
|
433
|
+
"""
|
|
396
434
|
if not hasattr(vn, "get_all_functions"):
|
|
397
435
|
return jsonify({"type": "error", "error": "This setup does not support function generation."})
|
|
398
436
|
|
|
@@ -409,6 +447,31 @@ class VannaFlaskApp:
|
|
|
409
447
|
@self.requires_auth
|
|
410
448
|
@self.requires_cache(["sql"])
|
|
411
449
|
def run_sql(user: any, id: str, sql: str):
|
|
450
|
+
"""
|
|
451
|
+
Run SQL
|
|
452
|
+
---
|
|
453
|
+
parameters:
|
|
454
|
+
- name: user
|
|
455
|
+
in: query
|
|
456
|
+
- name: id
|
|
457
|
+
in: query|body
|
|
458
|
+
type: string
|
|
459
|
+
required: true
|
|
460
|
+
responses:
|
|
461
|
+
200:
|
|
462
|
+
schema:
|
|
463
|
+
type: object
|
|
464
|
+
properties:
|
|
465
|
+
type:
|
|
466
|
+
type: string
|
|
467
|
+
default: df
|
|
468
|
+
id:
|
|
469
|
+
type: string
|
|
470
|
+
df:
|
|
471
|
+
type: object
|
|
472
|
+
should_generate_chart:
|
|
473
|
+
type: boolean
|
|
474
|
+
"""
|
|
412
475
|
try:
|
|
413
476
|
if not vn.run_sql_is_set:
|
|
414
477
|
return jsonify(
|
|
@@ -437,7 +500,34 @@ class VannaFlaskApp:
|
|
|
437
500
|
@self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
|
|
438
501
|
@self.requires_auth
|
|
439
502
|
@self.requires_cache(["question", "sql"])
|
|
440
|
-
def fix_sql(user: any, id: str, question:str, sql: str):
|
|
503
|
+
def fix_sql(user: any, id: str, question: str, sql: str):
|
|
504
|
+
"""
|
|
505
|
+
Fix SQL
|
|
506
|
+
---
|
|
507
|
+
parameters:
|
|
508
|
+
- name: user
|
|
509
|
+
in: query
|
|
510
|
+
- name: id
|
|
511
|
+
in: query|body
|
|
512
|
+
type: string
|
|
513
|
+
required: true
|
|
514
|
+
- name: error
|
|
515
|
+
in: body
|
|
516
|
+
type: string
|
|
517
|
+
required: true
|
|
518
|
+
responses:
|
|
519
|
+
200:
|
|
520
|
+
schema:
|
|
521
|
+
type: object
|
|
522
|
+
properties:
|
|
523
|
+
type:
|
|
524
|
+
type: string
|
|
525
|
+
default: sql
|
|
526
|
+
id:
|
|
527
|
+
type: string
|
|
528
|
+
text:
|
|
529
|
+
type: string
|
|
530
|
+
"""
|
|
441
531
|
error = flask.request.json.get("error")
|
|
442
532
|
|
|
443
533
|
if error is None:
|
|
@@ -462,6 +552,33 @@ class VannaFlaskApp:
|
|
|
462
552
|
@self.requires_auth
|
|
463
553
|
@self.requires_cache([])
|
|
464
554
|
def update_sql(user: any, id: str):
|
|
555
|
+
"""
|
|
556
|
+
Update SQL
|
|
557
|
+
---
|
|
558
|
+
parameters:
|
|
559
|
+
- name: user
|
|
560
|
+
in: query
|
|
561
|
+
- name: id
|
|
562
|
+
in: query|body
|
|
563
|
+
type: string
|
|
564
|
+
required: true
|
|
565
|
+
- name: sql
|
|
566
|
+
in: body
|
|
567
|
+
type: string
|
|
568
|
+
required: true
|
|
569
|
+
responses:
|
|
570
|
+
200:
|
|
571
|
+
schema:
|
|
572
|
+
type: object
|
|
573
|
+
properties:
|
|
574
|
+
type:
|
|
575
|
+
type: string
|
|
576
|
+
default: sql
|
|
577
|
+
id:
|
|
578
|
+
type: string
|
|
579
|
+
text:
|
|
580
|
+
type: string
|
|
581
|
+
"""
|
|
465
582
|
sql = flask.request.json.get('sql')
|
|
466
583
|
|
|
467
584
|
if sql is None:
|
|
@@ -480,6 +597,20 @@ class VannaFlaskApp:
|
|
|
480
597
|
@self.requires_auth
|
|
481
598
|
@self.requires_cache(["df"])
|
|
482
599
|
def download_csv(user: any, id: str, df):
|
|
600
|
+
"""
|
|
601
|
+
Download CSV
|
|
602
|
+
---
|
|
603
|
+
parameters:
|
|
604
|
+
- name: user
|
|
605
|
+
in: query
|
|
606
|
+
- name: id
|
|
607
|
+
in: query|body
|
|
608
|
+
type: string
|
|
609
|
+
required: true
|
|
610
|
+
responses:
|
|
611
|
+
200:
|
|
612
|
+
description: download CSV
|
|
613
|
+
"""
|
|
483
614
|
csv = df.to_csv()
|
|
484
615
|
|
|
485
616
|
return Response(
|
|
@@ -492,6 +623,32 @@ class VannaFlaskApp:
|
|
|
492
623
|
@self.requires_auth
|
|
493
624
|
@self.requires_cache(["df", "question", "sql"])
|
|
494
625
|
def generate_plotly_figure(user: any, id: str, df, question, sql):
|
|
626
|
+
"""
|
|
627
|
+
Generate plotly figure
|
|
628
|
+
---
|
|
629
|
+
parameters:
|
|
630
|
+
- name: user
|
|
631
|
+
in: query
|
|
632
|
+
- name: id
|
|
633
|
+
in: query|body
|
|
634
|
+
type: string
|
|
635
|
+
required: true
|
|
636
|
+
- name: chart_instructions
|
|
637
|
+
in: body
|
|
638
|
+
type: string
|
|
639
|
+
responses:
|
|
640
|
+
200:
|
|
641
|
+
schema:
|
|
642
|
+
type: object
|
|
643
|
+
properties:
|
|
644
|
+
type:
|
|
645
|
+
type: string
|
|
646
|
+
default: plotly_figure
|
|
647
|
+
id:
|
|
648
|
+
type: string
|
|
649
|
+
fig:
|
|
650
|
+
type: object
|
|
651
|
+
"""
|
|
495
652
|
chart_instructions = flask.request.args.get('chart_instructions')
|
|
496
653
|
|
|
497
654
|
try:
|
|
@@ -530,6 +687,26 @@ class VannaFlaskApp:
|
|
|
530
687
|
@self.flask_app.route("/api/v0/get_training_data", methods=["GET"])
|
|
531
688
|
@self.requires_auth
|
|
532
689
|
def get_training_data(user: any):
|
|
690
|
+
"""
|
|
691
|
+
Get all training data
|
|
692
|
+
---
|
|
693
|
+
parameters:
|
|
694
|
+
- name: user
|
|
695
|
+
in: query
|
|
696
|
+
responses:
|
|
697
|
+
200:
|
|
698
|
+
schema:
|
|
699
|
+
type: object
|
|
700
|
+
properties:
|
|
701
|
+
type:
|
|
702
|
+
type: string
|
|
703
|
+
default: df
|
|
704
|
+
id:
|
|
705
|
+
type: string
|
|
706
|
+
default: training_data
|
|
707
|
+
df:
|
|
708
|
+
type: object
|
|
709
|
+
"""
|
|
533
710
|
df = vn.get_training_data()
|
|
534
711
|
|
|
535
712
|
if df is None or len(df) == 0:
|
|
@@ -551,6 +728,24 @@ class VannaFlaskApp:
|
|
|
551
728
|
@self.flask_app.route("/api/v0/remove_training_data", methods=["POST"])
|
|
552
729
|
@self.requires_auth
|
|
553
730
|
def remove_training_data(user: any):
|
|
731
|
+
"""
|
|
732
|
+
Remove training data
|
|
733
|
+
---
|
|
734
|
+
parameters:
|
|
735
|
+
- name: user
|
|
736
|
+
in: query
|
|
737
|
+
- name: id
|
|
738
|
+
in: body
|
|
739
|
+
type: string
|
|
740
|
+
required: true
|
|
741
|
+
responses:
|
|
742
|
+
200:
|
|
743
|
+
schema:
|
|
744
|
+
type: object
|
|
745
|
+
properties:
|
|
746
|
+
success:
|
|
747
|
+
type: boolean
|
|
748
|
+
"""
|
|
554
749
|
# Get id from the JSON body
|
|
555
750
|
id = flask.request.json.get("id")
|
|
556
751
|
|
|
@@ -567,6 +762,32 @@ class VannaFlaskApp:
|
|
|
567
762
|
@self.flask_app.route("/api/v0/train", methods=["POST"])
|
|
568
763
|
@self.requires_auth
|
|
569
764
|
def add_training_data(user: any):
|
|
765
|
+
"""
|
|
766
|
+
Add training data
|
|
767
|
+
---
|
|
768
|
+
parameters:
|
|
769
|
+
- name: user
|
|
770
|
+
in: query
|
|
771
|
+
- name: question
|
|
772
|
+
in: body
|
|
773
|
+
type: string
|
|
774
|
+
- name: sql
|
|
775
|
+
in: body
|
|
776
|
+
type: string
|
|
777
|
+
- name: ddl
|
|
778
|
+
in: body
|
|
779
|
+
type: string
|
|
780
|
+
- name: documentation
|
|
781
|
+
in: body
|
|
782
|
+
type: string
|
|
783
|
+
responses:
|
|
784
|
+
200:
|
|
785
|
+
schema:
|
|
786
|
+
type: object
|
|
787
|
+
properties:
|
|
788
|
+
id:
|
|
789
|
+
type: string
|
|
790
|
+
"""
|
|
570
791
|
question = flask.request.json.get("question")
|
|
571
792
|
sql = flask.request.json.get("sql")
|
|
572
793
|
ddl = flask.request.json.get("ddl")
|
|
@@ -586,6 +807,29 @@ class VannaFlaskApp:
|
|
|
586
807
|
@self.requires_auth
|
|
587
808
|
@self.requires_cache(["question", "sql"])
|
|
588
809
|
def create_function(user: any, id: str, question: str, sql: str):
|
|
810
|
+
"""
|
|
811
|
+
Create function
|
|
812
|
+
---
|
|
813
|
+
parameters:
|
|
814
|
+
- name: user
|
|
815
|
+
in: query
|
|
816
|
+
- name: id
|
|
817
|
+
in: query|body
|
|
818
|
+
type: string
|
|
819
|
+
required: true
|
|
820
|
+
responses:
|
|
821
|
+
200:
|
|
822
|
+
schema:
|
|
823
|
+
type: object
|
|
824
|
+
properties:
|
|
825
|
+
type:
|
|
826
|
+
type: string
|
|
827
|
+
default: function_template
|
|
828
|
+
id:
|
|
829
|
+
type: string
|
|
830
|
+
function_template:
|
|
831
|
+
type: object
|
|
832
|
+
"""
|
|
589
833
|
plotly_code = self.cache.get(id=id, field="plotly_code")
|
|
590
834
|
|
|
591
835
|
if plotly_code is None:
|
|
@@ -604,6 +848,28 @@ class VannaFlaskApp:
|
|
|
604
848
|
@self.flask_app.route("/api/v0/update_function", methods=["POST"])
|
|
605
849
|
@self.requires_auth
|
|
606
850
|
def update_function(user: any):
|
|
851
|
+
"""
|
|
852
|
+
Update function
|
|
853
|
+
---
|
|
854
|
+
parameters:
|
|
855
|
+
- name: user
|
|
856
|
+
in: query
|
|
857
|
+
- name: old_function_name
|
|
858
|
+
in: body
|
|
859
|
+
type: string
|
|
860
|
+
required: true
|
|
861
|
+
- name: updated_function
|
|
862
|
+
in: body
|
|
863
|
+
type: object
|
|
864
|
+
required: true
|
|
865
|
+
responses:
|
|
866
|
+
200:
|
|
867
|
+
schema:
|
|
868
|
+
type: object
|
|
869
|
+
properties:
|
|
870
|
+
success:
|
|
871
|
+
type: boolean
|
|
872
|
+
"""
|
|
607
873
|
old_function_name = flask.request.json.get("old_function_name")
|
|
608
874
|
updated_function = flask.request.json.get("updated_function")
|
|
609
875
|
|
|
@@ -617,15 +883,57 @@ class VannaFlaskApp:
|
|
|
617
883
|
@self.flask_app.route("/api/v0/delete_function", methods=["POST"])
|
|
618
884
|
@self.requires_auth
|
|
619
885
|
def delete_function(user: any):
|
|
886
|
+
"""
|
|
887
|
+
Delete function
|
|
888
|
+
---
|
|
889
|
+
parameters:
|
|
890
|
+
- name: user
|
|
891
|
+
in: query
|
|
892
|
+
- name: function_name
|
|
893
|
+
in: body
|
|
894
|
+
type: string
|
|
895
|
+
required: true
|
|
896
|
+
responses:
|
|
897
|
+
200:
|
|
898
|
+
schema:
|
|
899
|
+
type: object
|
|
900
|
+
properties:
|
|
901
|
+
success:
|
|
902
|
+
type: boolean
|
|
903
|
+
"""
|
|
620
904
|
function_name = flask.request.json.get("function_name")
|
|
621
905
|
|
|
622
906
|
return jsonify({"success": vn.delete_function(function_name=function_name)})
|
|
623
907
|
|
|
624
|
-
|
|
625
908
|
@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
|
|
626
909
|
@self.requires_auth
|
|
627
910
|
@self.requires_cache(["df", "question", "sql"])
|
|
628
911
|
def generate_followup_questions(user: any, id: str, df, question, sql):
|
|
912
|
+
"""
|
|
913
|
+
Generate followup questions
|
|
914
|
+
---
|
|
915
|
+
parameters:
|
|
916
|
+
- name: user
|
|
917
|
+
in: query
|
|
918
|
+
- name: id
|
|
919
|
+
in: query|body
|
|
920
|
+
type: string
|
|
921
|
+
required: true
|
|
922
|
+
responses:
|
|
923
|
+
200:
|
|
924
|
+
schema:
|
|
925
|
+
type: object
|
|
926
|
+
properties:
|
|
927
|
+
type:
|
|
928
|
+
type: string
|
|
929
|
+
default: question_list
|
|
930
|
+
questions:
|
|
931
|
+
type: array
|
|
932
|
+
items:
|
|
933
|
+
type: string
|
|
934
|
+
header:
|
|
935
|
+
type: string
|
|
936
|
+
"""
|
|
629
937
|
if self.allow_llm_to_see_data:
|
|
630
938
|
followup_questions = vn.generate_followup_questions(
|
|
631
939
|
question=question, sql=sql, df=df
|
|
@@ -658,6 +966,29 @@ class VannaFlaskApp:
|
|
|
658
966
|
@self.requires_auth
|
|
659
967
|
@self.requires_cache(["df", "question"])
|
|
660
968
|
def generate_summary(user: any, id: str, df, question):
|
|
969
|
+
"""
|
|
970
|
+
Generate summary
|
|
971
|
+
---
|
|
972
|
+
parameters:
|
|
973
|
+
- name: user
|
|
974
|
+
in: query
|
|
975
|
+
- name: id
|
|
976
|
+
in: query|body
|
|
977
|
+
type: string
|
|
978
|
+
required: true
|
|
979
|
+
responses:
|
|
980
|
+
200:
|
|
981
|
+
schema:
|
|
982
|
+
type: object
|
|
983
|
+
properties:
|
|
984
|
+
type:
|
|
985
|
+
type: string
|
|
986
|
+
default: text
|
|
987
|
+
id:
|
|
988
|
+
type: string
|
|
989
|
+
text:
|
|
990
|
+
type: string
|
|
991
|
+
"""
|
|
661
992
|
if self.allow_llm_to_see_data:
|
|
662
993
|
summary = vn.generate_summary(question=question, df=df)
|
|
663
994
|
|
|
@@ -686,6 +1017,37 @@ class VannaFlaskApp:
|
|
|
686
1017
|
optional_fields=["summary", "fig_json"]
|
|
687
1018
|
)
|
|
688
1019
|
def load_question(user: any, id: str, question, sql, df, fig_json, summary):
|
|
1020
|
+
"""
|
|
1021
|
+
Load question
|
|
1022
|
+
---
|
|
1023
|
+
parameters:
|
|
1024
|
+
- name: user
|
|
1025
|
+
in: query
|
|
1026
|
+
- name: id
|
|
1027
|
+
in: query|body
|
|
1028
|
+
type: string
|
|
1029
|
+
required: true
|
|
1030
|
+
responses:
|
|
1031
|
+
200:
|
|
1032
|
+
schema:
|
|
1033
|
+
type: object
|
|
1034
|
+
properties:
|
|
1035
|
+
type:
|
|
1036
|
+
type: string
|
|
1037
|
+
default: question_cache
|
|
1038
|
+
id:
|
|
1039
|
+
type: string
|
|
1040
|
+
question:
|
|
1041
|
+
type: string
|
|
1042
|
+
sql:
|
|
1043
|
+
type: string
|
|
1044
|
+
df:
|
|
1045
|
+
type: object
|
|
1046
|
+
fig:
|
|
1047
|
+
type: object
|
|
1048
|
+
summary:
|
|
1049
|
+
type: string
|
|
1050
|
+
"""
|
|
689
1051
|
try:
|
|
690
1052
|
return jsonify(
|
|
691
1053
|
{
|
|
@@ -705,6 +1067,25 @@ class VannaFlaskApp:
|
|
|
705
1067
|
@self.flask_app.route("/api/v0/get_question_history", methods=["GET"])
|
|
706
1068
|
@self.requires_auth
|
|
707
1069
|
def get_question_history(user: any):
|
|
1070
|
+
"""
|
|
1071
|
+
Get question history
|
|
1072
|
+
---
|
|
1073
|
+
parameters:
|
|
1074
|
+
- name: user
|
|
1075
|
+
in: query
|
|
1076
|
+
responses:
|
|
1077
|
+
200:
|
|
1078
|
+
schema:
|
|
1079
|
+
type: object
|
|
1080
|
+
properties:
|
|
1081
|
+
type:
|
|
1082
|
+
type: string
|
|
1083
|
+
default: question_history
|
|
1084
|
+
questions:
|
|
1085
|
+
type: array
|
|
1086
|
+
items:
|
|
1087
|
+
type: string
|
|
1088
|
+
"""
|
|
708
1089
|
return jsonify(
|
|
709
1090
|
{
|
|
710
1091
|
"type": "question_history",
|
|
@@ -718,6 +1099,136 @@ class VannaFlaskApp:
|
|
|
718
1099
|
{"type": "error", "error": "The rest of the API is not ported yet."}
|
|
719
1100
|
)
|
|
720
1101
|
|
|
1102
|
+
if self.debug:
|
|
1103
|
+
@self.sock.route("/api/v0/log")
|
|
1104
|
+
def sock_log(ws):
|
|
1105
|
+
self.ws_clients.append(ws)
|
|
1106
|
+
|
|
1107
|
+
try:
|
|
1108
|
+
while True:
|
|
1109
|
+
message = ws.receive() # This example just reads and ignores to keep the socket open
|
|
1110
|
+
finally:
|
|
1111
|
+
self.ws_clients.remove(ws)
|
|
1112
|
+
|
|
1113
|
+
def run(self, *args, **kwargs):
|
|
1114
|
+
"""
|
|
1115
|
+
Run the Flask app.
|
|
1116
|
+
|
|
1117
|
+
Args:
|
|
1118
|
+
*args: Arguments to pass to Flask's run method.
|
|
1119
|
+
**kwargs: Keyword arguments to pass to Flask's run method.
|
|
1120
|
+
|
|
1121
|
+
Returns:
|
|
1122
|
+
None
|
|
1123
|
+
"""
|
|
1124
|
+
if args or kwargs:
|
|
1125
|
+
self.flask_app.run(*args, **kwargs)
|
|
1126
|
+
|
|
1127
|
+
else:
|
|
1128
|
+
try:
|
|
1129
|
+
from google.colab import output
|
|
1130
|
+
|
|
1131
|
+
output.serve_kernel_port_as_window(8084)
|
|
1132
|
+
from google.colab.output import eval_js
|
|
1133
|
+
|
|
1134
|
+
print("Your app is running at:")
|
|
1135
|
+
print(eval_js("google.colab.kernel.proxyPort(8084)"))
|
|
1136
|
+
except:
|
|
1137
|
+
print("Your app is running at:")
|
|
1138
|
+
print("http://localhost:8084")
|
|
1139
|
+
|
|
1140
|
+
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
class VannaFlaskApp(VannaFlaskAPI):
|
|
1144
|
+
def __init__(
|
|
1145
|
+
self,
|
|
1146
|
+
vn: VannaBase,
|
|
1147
|
+
cache: Cache = MemoryCache(),
|
|
1148
|
+
auth: AuthInterface = NoAuth(),
|
|
1149
|
+
debug=True,
|
|
1150
|
+
allow_llm_to_see_data=False,
|
|
1151
|
+
logo="https://img.vanna.ai/vanna-flask.svg",
|
|
1152
|
+
title="Welcome to Vanna.AI",
|
|
1153
|
+
subtitle="Your AI-powered copilot for SQL queries.",
|
|
1154
|
+
show_training_data=True,
|
|
1155
|
+
suggested_questions=True,
|
|
1156
|
+
sql=True,
|
|
1157
|
+
table=True,
|
|
1158
|
+
csv_download=True,
|
|
1159
|
+
chart=True,
|
|
1160
|
+
redraw_chart=True,
|
|
1161
|
+
auto_fix_sql=True,
|
|
1162
|
+
ask_results_correct=True,
|
|
1163
|
+
followup_questions=True,
|
|
1164
|
+
summarization=True,
|
|
1165
|
+
function_generation=True,
|
|
1166
|
+
index_html_path=None,
|
|
1167
|
+
assets_folder=None,
|
|
1168
|
+
):
|
|
1169
|
+
"""
|
|
1170
|
+
Expose a Flask app that can be used to interact with a Vanna instance.
|
|
1171
|
+
|
|
1172
|
+
Args:
|
|
1173
|
+
vn: The Vanna instance to interact with.
|
|
1174
|
+
cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
|
|
1175
|
+
auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
|
|
1176
|
+
debug: Show the debug console. Defaults to True.
|
|
1177
|
+
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
|
|
1178
|
+
logo: The logo to display in the UI. Defaults to the Vanna logo.
|
|
1179
|
+
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
|
|
1180
|
+
subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
|
|
1181
|
+
show_training_data: Whether to show the training data in the UI. Defaults to True.
|
|
1182
|
+
suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
|
|
1183
|
+
sql: Whether to show the SQL input in the UI. Defaults to True.
|
|
1184
|
+
table: Whether to show the table output in the UI. Defaults to True.
|
|
1185
|
+
csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
|
|
1186
|
+
chart: Whether to show the chart output in the UI. Defaults to True.
|
|
1187
|
+
redraw_chart: Whether to allow redrawing the chart. Defaults to True.
|
|
1188
|
+
auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
|
|
1189
|
+
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
|
|
1190
|
+
followup_questions: Whether to show followup questions. Defaults to True.
|
|
1191
|
+
summarization: Whether to show summarization. Defaults to True.
|
|
1192
|
+
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
|
|
1193
|
+
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
|
|
1194
|
+
|
|
1195
|
+
Returns:
|
|
1196
|
+
None
|
|
1197
|
+
"""
|
|
1198
|
+
super().__init__(vn, cache, auth, debug, allow_llm_to_see_data, chart)
|
|
1199
|
+
|
|
1200
|
+
self.config["logo"] = logo
|
|
1201
|
+
self.config["title"] = title
|
|
1202
|
+
self.config["subtitle"] = subtitle
|
|
1203
|
+
self.config["show_training_data"] = show_training_data
|
|
1204
|
+
self.config["suggested_questions"] = suggested_questions
|
|
1205
|
+
self.config["sql"] = sql
|
|
1206
|
+
self.config["table"] = table
|
|
1207
|
+
self.config["csv_download"] = csv_download
|
|
1208
|
+
self.config["chart"] = chart
|
|
1209
|
+
self.config["redraw_chart"] = redraw_chart
|
|
1210
|
+
self.config["auto_fix_sql"] = auto_fix_sql
|
|
1211
|
+
self.config["ask_results_correct"] = ask_results_correct
|
|
1212
|
+
self.config["followup_questions"] = followup_questions
|
|
1213
|
+
self.config["summarization"] = summarization
|
|
1214
|
+
self.config["function_generation"] = function_generation
|
|
1215
|
+
|
|
1216
|
+
self.index_html_path = index_html_path
|
|
1217
|
+
self.assets_folder = assets_folder
|
|
1218
|
+
|
|
1219
|
+
@self.flask_app.route("/auth/login", methods=["POST"])
|
|
1220
|
+
def login():
|
|
1221
|
+
return self.auth.login_handler(flask.request)
|
|
1222
|
+
|
|
1223
|
+
@self.flask_app.route("/auth/callback", methods=["GET"])
|
|
1224
|
+
def callback():
|
|
1225
|
+
return self.auth.callback_handler(flask.request)
|
|
1226
|
+
|
|
1227
|
+
@self.flask_app.route("/auth/logout", methods=["GET"])
|
|
1228
|
+
def logout():
|
|
1229
|
+
return self.auth.logout_handler(flask.request)
|
|
1230
|
+
|
|
1231
|
+
|
|
721
1232
|
@self.flask_app.route("/assets/<path:filename>")
|
|
722
1233
|
def proxy_assets(filename):
|
|
723
1234
|
if self.assets_folder:
|
|
@@ -755,18 +1266,6 @@ class VannaFlaskApp:
|
|
|
755
1266
|
else:
|
|
756
1267
|
return "Error fetching file from remote server", response.status_code
|
|
757
1268
|
|
|
758
|
-
if self.debug:
|
|
759
|
-
@self.sock.route("/api/v0/log")
|
|
760
|
-
def sock_log(ws):
|
|
761
|
-
self.ws_clients.append(ws)
|
|
762
|
-
|
|
763
|
-
try:
|
|
764
|
-
while True:
|
|
765
|
-
message = ws.receive() # This example just reads and ignores to keep the socket open
|
|
766
|
-
finally:
|
|
767
|
-
self.ws_clients.remove(ws)
|
|
768
|
-
|
|
769
|
-
|
|
770
1269
|
@self.flask_app.route("/", defaults={"path": ""})
|
|
771
1270
|
@self.flask_app.route("/<path:path>")
|
|
772
1271
|
def hello(path: str):
|
|
@@ -775,32 +1274,3 @@ class VannaFlaskApp:
|
|
|
775
1274
|
filename = os.path.basename(self.index_html_path)
|
|
776
1275
|
return send_from_directory(directory=directory, path=filename)
|
|
777
1276
|
return html_content
|
|
778
|
-
|
|
779
|
-
def run(self, *args, **kwargs):
|
|
780
|
-
"""
|
|
781
|
-
Run the Flask app.
|
|
782
|
-
|
|
783
|
-
Args:
|
|
784
|
-
*args: Arguments to pass to Flask's run method.
|
|
785
|
-
**kwargs: Keyword arguments to pass to Flask's run method.
|
|
786
|
-
|
|
787
|
-
Returns:
|
|
788
|
-
None
|
|
789
|
-
"""
|
|
790
|
-
if args or kwargs:
|
|
791
|
-
self.flask_app.run(*args, **kwargs)
|
|
792
|
-
|
|
793
|
-
else:
|
|
794
|
-
try:
|
|
795
|
-
from google.colab import output
|
|
796
|
-
|
|
797
|
-
output.serve_kernel_port_as_window(8084)
|
|
798
|
-
from google.colab.output import eval_js
|
|
799
|
-
|
|
800
|
-
print("Your app is running at:")
|
|
801
|
-
print(eval_js("google.colab.kernel.proxyPort(8084)"))
|
|
802
|
-
except:
|
|
803
|
-
print("Your app is running at:")
|
|
804
|
-
print("http://localhost:8084")
|
|
805
|
-
|
|
806
|
-
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
|
vanna/openai/openai_chat.py
CHANGED
|
@@ -11,14 +11,10 @@ class OpenAI_Chat(VannaBase):
|
|
|
11
11
|
|
|
12
12
|
# default parameters - can be overrided using config
|
|
13
13
|
self.temperature = 0.7
|
|
14
|
-
self.max_tokens = 500
|
|
15
14
|
|
|
16
15
|
if "temperature" in config:
|
|
17
16
|
self.temperature = config["temperature"]
|
|
18
17
|
|
|
19
|
-
if "max_tokens" in config:
|
|
20
|
-
self.max_tokens = config["max_tokens"]
|
|
21
|
-
|
|
22
18
|
if "api_type" in config:
|
|
23
19
|
raise Exception(
|
|
24
20
|
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
|
|
@@ -75,7 +71,6 @@ class OpenAI_Chat(VannaBase):
|
|
|
75
71
|
response = self.client.chat.completions.create(
|
|
76
72
|
model=model,
|
|
77
73
|
messages=prompt,
|
|
78
|
-
max_tokens=self.max_tokens,
|
|
79
74
|
stop=None,
|
|
80
75
|
temperature=self.temperature,
|
|
81
76
|
)
|
|
@@ -87,7 +82,6 @@ class OpenAI_Chat(VannaBase):
|
|
|
87
82
|
response = self.client.chat.completions.create(
|
|
88
83
|
engine=engine,
|
|
89
84
|
messages=prompt,
|
|
90
|
-
max_tokens=self.max_tokens,
|
|
91
85
|
stop=None,
|
|
92
86
|
temperature=self.temperature,
|
|
93
87
|
)
|
|
@@ -98,7 +92,6 @@ class OpenAI_Chat(VannaBase):
|
|
|
98
92
|
response = self.client.chat.completions.create(
|
|
99
93
|
engine=self.config["engine"],
|
|
100
94
|
messages=prompt,
|
|
101
|
-
max_tokens=self.max_tokens,
|
|
102
95
|
stop=None,
|
|
103
96
|
temperature=self.temperature,
|
|
104
97
|
)
|
|
@@ -109,7 +102,6 @@ class OpenAI_Chat(VannaBase):
|
|
|
109
102
|
response = self.client.chat.completions.create(
|
|
110
103
|
model=self.config["model"],
|
|
111
104
|
messages=prompt,
|
|
112
|
-
max_tokens=self.max_tokens,
|
|
113
105
|
stop=None,
|
|
114
106
|
temperature=self.temperature,
|
|
115
107
|
)
|
|
@@ -123,7 +115,6 @@ class OpenAI_Chat(VannaBase):
|
|
|
123
115
|
response = self.client.chat.completions.create(
|
|
124
116
|
model=model,
|
|
125
117
|
messages=prompt,
|
|
126
|
-
max_tokens=self.max_tokens,
|
|
127
118
|
stop=None,
|
|
128
119
|
temperature=self.temperature,
|
|
129
120
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.6.
|
|
3
|
+
Version: 0.6.4
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -16,6 +16,7 @@ Requires-Dist: sqlparse
|
|
|
16
16
|
Requires-Dist: kaleido
|
|
17
17
|
Requires-Dist: flask
|
|
18
18
|
Requires-Dist: flask-sock
|
|
19
|
+
Requires-Dist: flasgger
|
|
19
20
|
Requires-Dist: sqlalchemy
|
|
20
21
|
Requires-Dist: psycopg2-binary ; extra == "all"
|
|
21
22
|
Requires-Dist: db-dtypes ; extra == "all"
|
|
@@ -7,15 +7,15 @@ vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oE
|
|
|
7
7
|
vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
|
|
8
8
|
vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,658
|
|
9
9
|
vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
|
|
10
|
-
vanna/anthropic/anthropic_chat.py,sha256=
|
|
10
|
+
vanna/anthropic/anthropic_chat.py,sha256=7X3x8SYwDY28aGyBnt0YNRMG8YY1p_t-foMfKGj8_Oo,2627
|
|
11
11
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
12
|
-
vanna/base/base.py,sha256=
|
|
12
|
+
vanna/base/base.py,sha256=3Du70NrXQMn_LOif2YFPRRVKo4wH5-f6eZcLlXEX0X8,71705
|
|
13
13
|
vanna/bedrock/__init__.py,sha256=hRT2bgJbHEqViLdL-t9hfjSfFdIOkPU2ADBt-B1En-8,46
|
|
14
14
|
vanna/bedrock/bedrock_converse.py,sha256=Nx5kYm-diAfYmsWAnTP5xnv7V84Og69-AP9b3seIe0E,2869
|
|
15
15
|
vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
|
|
16
16
|
vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
|
|
17
17
|
vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
|
|
18
|
-
vanna/flask/__init__.py,sha256=
|
|
18
|
+
vanna/flask/__init__.py,sha256=r1ucQupb6wuTcjVVKpkdrg6R38FZe6KQoKw9AtcghDQ,42889
|
|
19
19
|
vanna/flask/assets.py,sha256=_UoUr57sS0QL2BuTxAOe9k4yy8T7-fp2NpbRSVtW3IM,451769
|
|
20
20
|
vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
|
|
21
21
|
vanna/google/__init__.py,sha256=M-dCxCZcKL4bTQyMLj6r6VRs65YNX9Tl2aoPCuqGm-8,41
|
|
@@ -35,7 +35,7 @@ vanna/mock/vectordb.py,sha256=h45znfYMUnttE2BBC8v6TKeMaA58pFJL-5B3OGeRNFI,2681
|
|
|
35
35
|
vanna/ollama/__init__.py,sha256=4xyu8aHPdnEHg5a-QAMwr5o0ns5wevsp_zkI-ndMO2k,27
|
|
36
36
|
vanna/ollama/ollama.py,sha256=rXa7cfvdlO1E5SLysXIl3IZpIaA2r0RBvV5jX2-upiE,3794
|
|
37
37
|
vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
|
|
38
|
-
vanna/openai/openai_chat.py,sha256=
|
|
38
|
+
vanna/openai/openai_chat.py,sha256=KU6ynOQ5v7vwrQQ13phXoUXeQUrH6_vmhfiPvWddTrQ,4427
|
|
39
39
|
vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
|
|
40
40
|
vanna/opensearch/__init__.py,sha256=0unDevWOTs7o8S79TOHUKF1mSiuQbBUVm-7k9jV5WW4,54
|
|
41
41
|
vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8QyVGI0eI,12226
|
|
@@ -50,6 +50,6 @@ vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
|
|
|
50
50
|
vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
|
|
51
51
|
vanna/weaviate/__init__.py,sha256=HL6PAl7ePBAkeG8uln-BmM7IUtWohyTPvDfcPzSGSCg,46
|
|
52
52
|
vanna/weaviate/weaviate_vector.py,sha256=GEiu4Vd9w-7j10aB-zTxJ8gefqe_F-LUUGvttFs1vlg,7539
|
|
53
|
-
vanna-0.6.
|
|
54
|
-
vanna-0.6.
|
|
55
|
-
vanna-0.6.
|
|
53
|
+
vanna-0.6.4.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
54
|
+
vanna-0.6.4.dist-info/METADATA,sha256=LqIi4Hg1y_aTEH79PX48nnY1TM-u6ese9K8Os9Cqkg0,11889
|
|
55
|
+
vanna-0.6.4.dist-info/RECORD,,
|
|
File without changes
|