vanna 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/flask/__init__.py CHANGED
@@ -8,11 +8,13 @@ from functools import wraps
8
8
 
9
9
  import flask
10
10
  import requests
11
+ from flasgger import Swagger
11
12
  from flask import Flask, Response, jsonify, request, send_from_directory
12
13
  from flask_sock import Sock
13
14
 
14
15
  from .assets import css_content, html_content, js_content
15
16
  from .auth import AuthInterface, NoAuth
17
+ from ..base import VannaBase
16
18
 
17
19
 
18
20
  class Cache(ABC):
@@ -88,7 +90,8 @@ class MemoryCache(Cache):
88
90
  if id in self.cache:
89
91
  del self.cache[id]
90
92
 
91
- class VannaFlaskApp:
93
+
94
+ class VannaFlaskAPI:
92
95
  flask_app = None
93
96
 
94
97
  def requires_cache(self, required_fields, optional_fields=[]):
@@ -135,30 +138,17 @@ class VannaFlaskApp:
135
138
 
136
139
  return decorated
137
140
 
138
- def __init__(self, vn, cache: Cache = MemoryCache(),
139
- auth: AuthInterface = NoAuth(),
140
- debug=True,
141
- allow_llm_to_see_data=False,
142
- logo="https://img.vanna.ai/vanna-flask.svg",
143
- title="Welcome to Vanna.AI",
144
- subtitle="Your AI-powered copilot for SQL queries.",
145
- show_training_data=True,
146
- suggested_questions=True,
147
- sql=True,
148
- table=True,
149
- csv_download=True,
150
- chart=True,
151
- redraw_chart=True,
152
- auto_fix_sql=True,
153
- ask_results_correct=True,
154
- followup_questions=True,
155
- summarization=True,
156
- function_generation=True,
157
- index_html_path=None,
158
- assets_folder=None,
159
- ):
141
+ def __init__(
142
+ self,
143
+ vn: VannaBase,
144
+ cache: Cache = MemoryCache(),
145
+ auth: AuthInterface = NoAuth(),
146
+ debug=True,
147
+ allow_llm_to_see_data=False,
148
+ chart=True,
149
+ ):
160
150
  """
161
- Expose a Flask app that can be used to interact with a Vanna instance.
151
+ Expose a Flask API that can be used to interact with a Vanna instance.
162
152
 
163
153
  Args:
164
154
  vn: The Vanna instance to interact with.
@@ -166,52 +156,30 @@ class VannaFlaskApp:
166
156
  auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
167
157
  debug: Show the debug console. Defaults to True.
168
158
  allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
169
- logo: The logo to display in the UI. Defaults to the Vanna logo.
170
- title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
171
- subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
172
- show_training_data: Whether to show the training data in the UI. Defaults to True.
173
- suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
174
- sql: Whether to show the SQL input in the UI. Defaults to True.
175
- table: Whether to show the table output in the UI. Defaults to True.
176
- csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
177
159
  chart: Whether to show the chart output in the UI. Defaults to True.
178
- redraw_chart: Whether to allow redrawing the chart. Defaults to True.
179
- auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
180
- ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
181
- followup_questions: Whether to show followup questions. Defaults to True.
182
- summarization: Whether to show summarization. Defaults to True.
183
- index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
184
- assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
185
160
 
186
161
  Returns:
187
162
  None
188
163
  """
164
+
189
165
  self.flask_app = Flask(__name__)
166
+
167
+ self.swagger = Swagger(
168
+ self.flask_app, template={"info": {"title": "Vanna API"}}
169
+ )
190
170
  self.sock = Sock(self.flask_app)
191
171
  self.ws_clients = []
192
172
  self.vn = vn
193
- self.debug = debug
194
173
  self.auth = auth
195
174
  self.cache = cache
175
+ self.debug = debug
196
176
  self.allow_llm_to_see_data = allow_llm_to_see_data
197
- self.logo = logo
198
- self.title = title
199
- self.subtitle = subtitle
200
- self.show_training_data = show_training_data
201
- self.suggested_questions = suggested_questions
202
- self.sql = sql
203
- self.table = table
204
- self.csv_download = csv_download
205
177
  self.chart = chart
206
- self.redraw_chart = redraw_chart
207
- self.auto_fix_sql = auto_fix_sql
208
- self.ask_results_correct = ask_results_correct
209
- self.followup_questions = followup_questions
210
- self.summarization = summarization
211
- self.function_generation = function_generation and hasattr(vn, "get_function")
212
- self.index_html_path = index_html_path
213
- self.assets_folder = assets_folder
214
-
178
+ self.config = {
179
+ "debug": debug,
180
+ "allow_llm_to_see_data": allow_llm_to_see_data,
181
+ "chart": chart,
182
+ }
215
183
  log = logging.getLogger("werkzeug")
216
184
  log.setLevel(logging.ERROR)
217
185
 
@@ -225,42 +193,27 @@ class VannaFlaskApp:
225
193
 
226
194
  self.vn.log = log
227
195
 
228
- @self.flask_app.route("/auth/login", methods=["POST"])
229
- def login():
230
- return self.auth.login_handler(flask.request)
231
-
232
- @self.flask_app.route("/auth/callback", methods=["GET"])
233
- def callback():
234
- return self.auth.callback_handler(flask.request)
235
-
236
- @self.flask_app.route("/auth/logout", methods=["GET"])
237
- def logout():
238
- return self.auth.logout_handler(flask.request)
239
-
240
196
  @self.flask_app.route("/api/v0/get_config", methods=["GET"])
241
197
  @self.requires_auth
242
198
  def get_config(user: any):
243
- config = {
244
- "debug": self.debug,
245
- "logo": self.logo,
246
- "title": self.title,
247
- "subtitle": self.subtitle,
248
- "show_training_data": self.show_training_data,
249
- "suggested_questions": self.suggested_questions,
250
- "sql": self.sql,
251
- "table": self.table,
252
- "csv_download": self.csv_download,
253
- "chart": self.chart,
254
- "redraw_chart": self.redraw_chart,
255
- "auto_fix_sql": self.auto_fix_sql,
256
- "ask_results_correct": self.ask_results_correct,
257
- "followup_questions": self.followup_questions,
258
- "summarization": self.summarization,
259
- "function_generation": self.function_generation,
260
- }
261
-
262
- config = self.auth.override_config_for_user(user, config)
263
-
199
+ """
200
+ Get the configuration for a user
201
+ ---
202
+ parameters:
203
+ - name: user
204
+ in: query
205
+ responses:
206
+ 200:
207
+ schema:
208
+ type: object
209
+ properties:
210
+ type:
211
+ type: string
212
+ default: config
213
+ config:
214
+ type: object
215
+ """
216
+ config = self.auth.override_config_for_user(user, self.config)
264
217
  return jsonify(
265
218
  {
266
219
  "type": "config",
@@ -271,6 +224,28 @@ class VannaFlaskApp:
271
224
  @self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
272
225
  @self.requires_auth
273
226
  def generate_questions(user: any):
227
+ """
228
+ Generate questions
229
+ ---
230
+ parameters:
231
+ - name: user
232
+ in: query
233
+ responses:
234
+ 200:
235
+ schema:
236
+ type: object
237
+ properties:
238
+ type:
239
+ type: string
240
+ default: question_list
241
+ questions:
242
+ type: array
243
+ items:
244
+ type: string
245
+ header:
246
+ type: string
247
+ default: Here are some questions you can ask
248
+ """
274
249
  # If self has an _model attribute and model=='chinook'
275
250
  if hasattr(self.vn, "_model") and self.vn._model == "chinook":
276
251
  return jsonify(
@@ -327,6 +302,29 @@ class VannaFlaskApp:
327
302
  @self.flask_app.route("/api/v0/generate_sql", methods=["GET"])
328
303
  @self.requires_auth
329
304
  def generate_sql(user: any):
305
+ """
306
+ Generate SQL from a question
307
+ ---
308
+ parameters:
309
+ - name: user
310
+ in: query
311
+ - name: question
312
+ in: query
313
+ type: string
314
+ required: true
315
+ responses:
316
+ 200:
317
+ schema:
318
+ type: object
319
+ properties:
320
+ type:
321
+ type: string
322
+ default: sql
323
+ id:
324
+ type: string
325
+ text:
326
+ type: string
327
+ """
330
328
  question = flask.request.args.get("question")
331
329
 
332
330
  if question is None:
@@ -358,6 +356,29 @@ class VannaFlaskApp:
358
356
  @self.flask_app.route("/api/v0/get_function", methods=["GET"])
359
357
  @self.requires_auth
360
358
  def get_function(user: any):
359
+ """
360
+ Get a function from a question
361
+ ---
362
+ parameters:
363
+ - name: user
364
+ in: query
365
+ - name: question
366
+ in: query
367
+ type: string
368
+ required: true
369
+ responses:
370
+ 200:
371
+ schema:
372
+ type: object
373
+ properties:
374
+ type:
375
+ type: string
376
+ default: function
377
+ id:
378
+ type: object
379
+ function:
380
+ type: string
381
+ """
361
382
  question = flask.request.args.get("question")
362
383
 
363
384
  if question is None:
@@ -393,6 +414,23 @@ class VannaFlaskApp:
393
414
  @self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
394
415
  @self.requires_auth
395
416
  def get_all_functions(user: any):
417
+ """
418
+ Get all the functions
419
+ ---
420
+ parameters:
421
+ - name: user
422
+ in: query
423
+ responses:
424
+ 200:
425
+ schema:
426
+ type: object
427
+ properties:
428
+ type:
429
+ type: string
430
+ default: functions
431
+ functions:
432
+ type: array
433
+ """
396
434
  if not hasattr(vn, "get_all_functions"):
397
435
  return jsonify({"type": "error", "error": "This setup does not support function generation."})
398
436
 
@@ -409,6 +447,31 @@ class VannaFlaskApp:
409
447
  @self.requires_auth
410
448
  @self.requires_cache(["sql"])
411
449
  def run_sql(user: any, id: str, sql: str):
450
+ """
451
+ Run SQL
452
+ ---
453
+ parameters:
454
+ - name: user
455
+ in: query
456
+ - name: id
457
+ in: query|body
458
+ type: string
459
+ required: true
460
+ responses:
461
+ 200:
462
+ schema:
463
+ type: object
464
+ properties:
465
+ type:
466
+ type: string
467
+ default: df
468
+ id:
469
+ type: string
470
+ df:
471
+ type: object
472
+ should_generate_chart:
473
+ type: boolean
474
+ """
412
475
  try:
413
476
  if not vn.run_sql_is_set:
414
477
  return jsonify(
@@ -437,7 +500,34 @@ class VannaFlaskApp:
437
500
  @self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
438
501
  @self.requires_auth
439
502
  @self.requires_cache(["question", "sql"])
440
- def fix_sql(user: any, id: str, question:str, sql: str):
503
+ def fix_sql(user: any, id: str, question: str, sql: str):
504
+ """
505
+ Fix SQL
506
+ ---
507
+ parameters:
508
+ - name: user
509
+ in: query
510
+ - name: id
511
+ in: query|body
512
+ type: string
513
+ required: true
514
+ - name: error
515
+ in: body
516
+ type: string
517
+ required: true
518
+ responses:
519
+ 200:
520
+ schema:
521
+ type: object
522
+ properties:
523
+ type:
524
+ type: string
525
+ default: sql
526
+ id:
527
+ type: string
528
+ text:
529
+ type: string
530
+ """
441
531
  error = flask.request.json.get("error")
442
532
 
443
533
  if error is None:
@@ -462,6 +552,33 @@ class VannaFlaskApp:
462
552
  @self.requires_auth
463
553
  @self.requires_cache([])
464
554
  def update_sql(user: any, id: str):
555
+ """
556
+ Update SQL
557
+ ---
558
+ parameters:
559
+ - name: user
560
+ in: query
561
+ - name: id
562
+ in: query|body
563
+ type: string
564
+ required: true
565
+ - name: sql
566
+ in: body
567
+ type: string
568
+ required: true
569
+ responses:
570
+ 200:
571
+ schema:
572
+ type: object
573
+ properties:
574
+ type:
575
+ type: string
576
+ default: sql
577
+ id:
578
+ type: string
579
+ text:
580
+ type: string
581
+ """
465
582
  sql = flask.request.json.get('sql')
466
583
 
467
584
  if sql is None:
@@ -480,6 +597,20 @@ class VannaFlaskApp:
480
597
  @self.requires_auth
481
598
  @self.requires_cache(["df"])
482
599
  def download_csv(user: any, id: str, df):
600
+ """
601
+ Download CSV
602
+ ---
603
+ parameters:
604
+ - name: user
605
+ in: query
606
+ - name: id
607
+ in: query|body
608
+ type: string
609
+ required: true
610
+ responses:
611
+ 200:
612
+ description: download CSV
613
+ """
483
614
  csv = df.to_csv()
484
615
 
485
616
  return Response(
@@ -492,6 +623,32 @@ class VannaFlaskApp:
492
623
  @self.requires_auth
493
624
  @self.requires_cache(["df", "question", "sql"])
494
625
  def generate_plotly_figure(user: any, id: str, df, question, sql):
626
+ """
627
+ Generate plotly figure
628
+ ---
629
+ parameters:
630
+ - name: user
631
+ in: query
632
+ - name: id
633
+ in: query|body
634
+ type: string
635
+ required: true
636
+ - name: chart_instructions
637
+ in: body
638
+ type: string
639
+ responses:
640
+ 200:
641
+ schema:
642
+ type: object
643
+ properties:
644
+ type:
645
+ type: string
646
+ default: plotly_figure
647
+ id:
648
+ type: string
649
+ fig:
650
+ type: object
651
+ """
495
652
  chart_instructions = flask.request.args.get('chart_instructions')
496
653
 
497
654
  try:
@@ -530,6 +687,26 @@ class VannaFlaskApp:
530
687
  @self.flask_app.route("/api/v0/get_training_data", methods=["GET"])
531
688
  @self.requires_auth
532
689
  def get_training_data(user: any):
690
+ """
691
+ Get all training data
692
+ ---
693
+ parameters:
694
+ - name: user
695
+ in: query
696
+ responses:
697
+ 200:
698
+ schema:
699
+ type: object
700
+ properties:
701
+ type:
702
+ type: string
703
+ default: df
704
+ id:
705
+ type: string
706
+ default: training_data
707
+ df:
708
+ type: object
709
+ """
533
710
  df = vn.get_training_data()
534
711
 
535
712
  if df is None or len(df) == 0:
@@ -551,6 +728,24 @@ class VannaFlaskApp:
551
728
  @self.flask_app.route("/api/v0/remove_training_data", methods=["POST"])
552
729
  @self.requires_auth
553
730
  def remove_training_data(user: any):
731
+ """
732
+ Remove training data
733
+ ---
734
+ parameters:
735
+ - name: user
736
+ in: query
737
+ - name: id
738
+ in: body
739
+ type: string
740
+ required: true
741
+ responses:
742
+ 200:
743
+ schema:
744
+ type: object
745
+ properties:
746
+ success:
747
+ type: boolean
748
+ """
554
749
  # Get id from the JSON body
555
750
  id = flask.request.json.get("id")
556
751
 
@@ -567,6 +762,32 @@ class VannaFlaskApp:
567
762
  @self.flask_app.route("/api/v0/train", methods=["POST"])
568
763
  @self.requires_auth
569
764
  def add_training_data(user: any):
765
+ """
766
+ Add training data
767
+ ---
768
+ parameters:
769
+ - name: user
770
+ in: query
771
+ - name: question
772
+ in: body
773
+ type: string
774
+ - name: sql
775
+ in: body
776
+ type: string
777
+ - name: ddl
778
+ in: body
779
+ type: string
780
+ - name: documentation
781
+ in: body
782
+ type: string
783
+ responses:
784
+ 200:
785
+ schema:
786
+ type: object
787
+ properties:
788
+ id:
789
+ type: string
790
+ """
570
791
  question = flask.request.json.get("question")
571
792
  sql = flask.request.json.get("sql")
572
793
  ddl = flask.request.json.get("ddl")
@@ -586,6 +807,29 @@ class VannaFlaskApp:
586
807
  @self.requires_auth
587
808
  @self.requires_cache(["question", "sql"])
588
809
  def create_function(user: any, id: str, question: str, sql: str):
810
+ """
811
+ Create function
812
+ ---
813
+ parameters:
814
+ - name: user
815
+ in: query
816
+ - name: id
817
+ in: query|body
818
+ type: string
819
+ required: true
820
+ responses:
821
+ 200:
822
+ schema:
823
+ type: object
824
+ properties:
825
+ type:
826
+ type: string
827
+ default: function_template
828
+ id:
829
+ type: string
830
+ function_template:
831
+ type: object
832
+ """
589
833
  plotly_code = self.cache.get(id=id, field="plotly_code")
590
834
 
591
835
  if plotly_code is None:
@@ -604,6 +848,28 @@ class VannaFlaskApp:
604
848
  @self.flask_app.route("/api/v0/update_function", methods=["POST"])
605
849
  @self.requires_auth
606
850
  def update_function(user: any):
851
+ """
852
+ Update function
853
+ ---
854
+ parameters:
855
+ - name: user
856
+ in: query
857
+ - name: old_function_name
858
+ in: body
859
+ type: string
860
+ required: true
861
+ - name: updated_function
862
+ in: body
863
+ type: object
864
+ required: true
865
+ responses:
866
+ 200:
867
+ schema:
868
+ type: object
869
+ properties:
870
+ success:
871
+ type: boolean
872
+ """
607
873
  old_function_name = flask.request.json.get("old_function_name")
608
874
  updated_function = flask.request.json.get("updated_function")
609
875
 
@@ -617,15 +883,57 @@ class VannaFlaskApp:
617
883
  @self.flask_app.route("/api/v0/delete_function", methods=["POST"])
618
884
  @self.requires_auth
619
885
  def delete_function(user: any):
886
+ """
887
+ Delete function
888
+ ---
889
+ parameters:
890
+ - name: user
891
+ in: query
892
+ - name: function_name
893
+ in: body
894
+ type: string
895
+ required: true
896
+ responses:
897
+ 200:
898
+ schema:
899
+ type: object
900
+ properties:
901
+ success:
902
+ type: boolean
903
+ """
620
904
  function_name = flask.request.json.get("function_name")
621
905
 
622
906
  return jsonify({"success": vn.delete_function(function_name=function_name)})
623
907
 
624
-
625
908
  @self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
626
909
  @self.requires_auth
627
910
  @self.requires_cache(["df", "question", "sql"])
628
911
  def generate_followup_questions(user: any, id: str, df, question, sql):
912
+ """
913
+ Generate followup questions
914
+ ---
915
+ parameters:
916
+ - name: user
917
+ in: query
918
+ - name: id
919
+ in: query|body
920
+ type: string
921
+ required: true
922
+ responses:
923
+ 200:
924
+ schema:
925
+ type: object
926
+ properties:
927
+ type:
928
+ type: string
929
+ default: question_list
930
+ questions:
931
+ type: array
932
+ items:
933
+ type: string
934
+ header:
935
+ type: string
936
+ """
629
937
  if self.allow_llm_to_see_data:
630
938
  followup_questions = vn.generate_followup_questions(
631
939
  question=question, sql=sql, df=df
@@ -658,6 +966,29 @@ class VannaFlaskApp:
658
966
  @self.requires_auth
659
967
  @self.requires_cache(["df", "question"])
660
968
  def generate_summary(user: any, id: str, df, question):
969
+ """
970
+ Generate summary
971
+ ---
972
+ parameters:
973
+ - name: user
974
+ in: query
975
+ - name: id
976
+ in: query|body
977
+ type: string
978
+ required: true
979
+ responses:
980
+ 200:
981
+ schema:
982
+ type: object
983
+ properties:
984
+ type:
985
+ type: string
986
+ default: text
987
+ id:
988
+ type: string
989
+ text:
990
+ type: string
991
+ """
661
992
  if self.allow_llm_to_see_data:
662
993
  summary = vn.generate_summary(question=question, df=df)
663
994
 
@@ -686,6 +1017,37 @@ class VannaFlaskApp:
686
1017
  optional_fields=["summary", "fig_json"]
687
1018
  )
688
1019
  def load_question(user: any, id: str, question, sql, df, fig_json, summary):
1020
+ """
1021
+ Load question
1022
+ ---
1023
+ parameters:
1024
+ - name: user
1025
+ in: query
1026
+ - name: id
1027
+ in: query|body
1028
+ type: string
1029
+ required: true
1030
+ responses:
1031
+ 200:
1032
+ schema:
1033
+ type: object
1034
+ properties:
1035
+ type:
1036
+ type: string
1037
+ default: question_cache
1038
+ id:
1039
+ type: string
1040
+ question:
1041
+ type: string
1042
+ sql:
1043
+ type: string
1044
+ df:
1045
+ type: object
1046
+ fig:
1047
+ type: object
1048
+ summary:
1049
+ type: string
1050
+ """
689
1051
  try:
690
1052
  return jsonify(
691
1053
  {
@@ -705,6 +1067,25 @@ class VannaFlaskApp:
705
1067
  @self.flask_app.route("/api/v0/get_question_history", methods=["GET"])
706
1068
  @self.requires_auth
707
1069
  def get_question_history(user: any):
1070
+ """
1071
+ Get question history
1072
+ ---
1073
+ parameters:
1074
+ - name: user
1075
+ in: query
1076
+ responses:
1077
+ 200:
1078
+ schema:
1079
+ type: object
1080
+ properties:
1081
+ type:
1082
+ type: string
1083
+ default: question_history
1084
+ questions:
1085
+ type: array
1086
+ items:
1087
+ type: string
1088
+ """
708
1089
  return jsonify(
709
1090
  {
710
1091
  "type": "question_history",
@@ -718,6 +1099,136 @@ class VannaFlaskApp:
718
1099
  {"type": "error", "error": "The rest of the API is not ported yet."}
719
1100
  )
720
1101
 
1102
+ if self.debug:
1103
+ @self.sock.route("/api/v0/log")
1104
+ def sock_log(ws):
1105
+ self.ws_clients.append(ws)
1106
+
1107
+ try:
1108
+ while True:
1109
+ message = ws.receive() # This example just reads and ignores to keep the socket open
1110
+ finally:
1111
+ self.ws_clients.remove(ws)
1112
+
1113
+ def run(self, *args, **kwargs):
1114
+ """
1115
+ Run the Flask app.
1116
+
1117
+ Args:
1118
+ *args: Arguments to pass to Flask's run method.
1119
+ **kwargs: Keyword arguments to pass to Flask's run method.
1120
+
1121
+ Returns:
1122
+ None
1123
+ """
1124
+ if args or kwargs:
1125
+ self.flask_app.run(*args, **kwargs)
1126
+
1127
+ else:
1128
+ try:
1129
+ from google.colab import output
1130
+
1131
+ output.serve_kernel_port_as_window(8084)
1132
+ from google.colab.output import eval_js
1133
+
1134
+ print("Your app is running at:")
1135
+ print(eval_js("google.colab.kernel.proxyPort(8084)"))
1136
+ except:
1137
+ print("Your app is running at:")
1138
+ print("http://localhost:8084")
1139
+
1140
+ self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
1141
+
1142
+
1143
+ class VannaFlaskApp(VannaFlaskAPI):
1144
+ def __init__(
1145
+ self,
1146
+ vn: VannaBase,
1147
+ cache: Cache = MemoryCache(),
1148
+ auth: AuthInterface = NoAuth(),
1149
+ debug=True,
1150
+ allow_llm_to_see_data=False,
1151
+ logo="https://img.vanna.ai/vanna-flask.svg",
1152
+ title="Welcome to Vanna.AI",
1153
+ subtitle="Your AI-powered copilot for SQL queries.",
1154
+ show_training_data=True,
1155
+ suggested_questions=True,
1156
+ sql=True,
1157
+ table=True,
1158
+ csv_download=True,
1159
+ chart=True,
1160
+ redraw_chart=True,
1161
+ auto_fix_sql=True,
1162
+ ask_results_correct=True,
1163
+ followup_questions=True,
1164
+ summarization=True,
1165
+ function_generation=True,
1166
+ index_html_path=None,
1167
+ assets_folder=None,
1168
+ ):
1169
+ """
1170
+ Expose a Flask app that can be used to interact with a Vanna instance.
1171
+
1172
+ Args:
1173
+ vn: The Vanna instance to interact with.
1174
+ cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
1175
+ auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
1176
+ debug: Show the debug console. Defaults to True.
1177
+ allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
1178
+ logo: The logo to display in the UI. Defaults to the Vanna logo.
1179
+ title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
1180
+ subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
1181
+ show_training_data: Whether to show the training data in the UI. Defaults to True.
1182
+ suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
1183
+ sql: Whether to show the SQL input in the UI. Defaults to True.
1184
+ table: Whether to show the table output in the UI. Defaults to True.
1185
+ csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
1186
+ chart: Whether to show the chart output in the UI. Defaults to True.
1187
+ redraw_chart: Whether to allow redrawing the chart. Defaults to True.
1188
+ auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
1189
+ ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
1190
+ followup_questions: Whether to show followup questions. Defaults to True.
1191
+ summarization: Whether to show summarization. Defaults to True.
1192
+ index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
1193
+ assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
1194
+
1195
+ Returns:
1196
+ None
1197
+ """
1198
+ super().__init__(vn, cache, auth, debug, allow_llm_to_see_data, chart)
1199
+
1200
+ self.config["logo"] = logo
1201
+ self.config["title"] = title
1202
+ self.config["subtitle"] = subtitle
1203
+ self.config["show_training_data"] = show_training_data
1204
+ self.config["suggested_questions"] = suggested_questions
1205
+ self.config["sql"] = sql
1206
+ self.config["table"] = table
1207
+ self.config["csv_download"] = csv_download
1208
+ self.config["chart"] = chart
1209
+ self.config["redraw_chart"] = redraw_chart
1210
+ self.config["auto_fix_sql"] = auto_fix_sql
1211
+ self.config["ask_results_correct"] = ask_results_correct
1212
+ self.config["followup_questions"] = followup_questions
1213
+ self.config["summarization"] = summarization
1214
+ self.config["function_generation"] = function_generation
1215
+
1216
+ self.index_html_path = index_html_path
1217
+ self.assets_folder = assets_folder
1218
+
1219
+ @self.flask_app.route("/auth/login", methods=["POST"])
1220
+ def login():
1221
+ return self.auth.login_handler(flask.request)
1222
+
1223
+ @self.flask_app.route("/auth/callback", methods=["GET"])
1224
+ def callback():
1225
+ return self.auth.callback_handler(flask.request)
1226
+
1227
+ @self.flask_app.route("/auth/logout", methods=["GET"])
1228
+ def logout():
1229
+ return self.auth.logout_handler(flask.request)
1230
+
1231
+
721
1232
  @self.flask_app.route("/assets/<path:filename>")
722
1233
  def proxy_assets(filename):
723
1234
  if self.assets_folder:
@@ -755,18 +1266,6 @@ class VannaFlaskApp:
755
1266
  else:
756
1267
  return "Error fetching file from remote server", response.status_code
757
1268
 
758
- if self.debug:
759
- @self.sock.route("/api/v0/log")
760
- def sock_log(ws):
761
- self.ws_clients.append(ws)
762
-
763
- try:
764
- while True:
765
- message = ws.receive() # This example just reads and ignores to keep the socket open
766
- finally:
767
- self.ws_clients.remove(ws)
768
-
769
-
770
1269
  @self.flask_app.route("/", defaults={"path": ""})
771
1270
  @self.flask_app.route("/<path:path>")
772
1271
  def hello(path: str):
@@ -775,32 +1274,3 @@ class VannaFlaskApp:
775
1274
  filename = os.path.basename(self.index_html_path)
776
1275
  return send_from_directory(directory=directory, path=filename)
777
1276
  return html_content
778
-
779
- def run(self, *args, **kwargs):
780
- """
781
- Run the Flask app.
782
-
783
- Args:
784
- *args: Arguments to pass to Flask's run method.
785
- **kwargs: Keyword arguments to pass to Flask's run method.
786
-
787
- Returns:
788
- None
789
- """
790
- if args or kwargs:
791
- self.flask_app.run(*args, **kwargs)
792
-
793
- else:
794
- try:
795
- from google.colab import output
796
-
797
- output.serve_kernel_port_as_window(8084)
798
- from google.colab.output import eval_js
799
-
800
- print("Your app is running at:")
801
- print(eval_js("google.colab.kernel.proxyPort(8084)"))
802
- except:
803
- print("Your app is running at:")
804
- print("http://localhost:8084")
805
-
806
- self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)