vanna 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@ class ZhipuAI_Chat(VannaBase):
40
40
  initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
41
41
  ) -> str:
42
42
  if len(ddl_list) > 0:
43
- initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
43
+ initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
44
44
 
45
45
  for ddl in ddl_list:
46
46
  if (
@@ -57,7 +57,7 @@ class ZhipuAI_Chat(VannaBase):
57
57
  initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
58
58
  ) -> str:
59
59
  if len(documentation_List) > 0:
60
- initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
60
+ initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
61
61
 
62
62
  for documentation in documentation_List:
63
63
  if (
@@ -74,7 +74,7 @@ class ZhipuAI_Chat(VannaBase):
74
74
  initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
75
75
  ) -> str:
76
76
  if len(sql_List) > 0:
77
- initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
77
+ initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
78
78
 
79
79
  for question in sql_List:
80
80
  if (
vanna/base/base.py CHANGED
@@ -124,6 +124,18 @@ class VannaBase(ABC):
124
124
  return self.extract_sql(llm_response)
125
125
 
126
126
  def extract_sql(self, llm_response: str) -> str:
127
+ # If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
128
+ sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
129
+ if sql:
130
+ self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
131
+ )
132
+ return sql.group(0)
133
+
134
+ # If the llm_response contains a CTE (with clause), extract the sql between WITH and ;
135
+ sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
136
+ if sql:
137
+ self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
138
+ return sql.group(0)
127
139
  # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
128
140
  sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
129
141
  if sql:
@@ -363,7 +375,7 @@ class VannaBase(ABC):
363
375
  self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
364
376
  ) -> str:
365
377
  if len(ddl_list) > 0:
366
- initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
378
+ initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
367
379
 
368
380
  for ddl in ddl_list:
369
381
  if (
@@ -382,7 +394,7 @@ class VannaBase(ABC):
382
394
  max_tokens: int = 14000,
383
395
  ) -> str:
384
396
  if len(documentation_list) > 0:
385
- initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
397
+ initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
386
398
 
387
399
  for documentation in documentation_list:
388
400
  if (
@@ -398,7 +410,7 @@ class VannaBase(ABC):
398
410
  self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
399
411
  ) -> str:
400
412
  if len(sql_list) > 0:
401
- initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
413
+ initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
402
414
 
403
415
  for question in sql_list:
404
416
  if (
@@ -642,6 +654,7 @@ class VannaBase(ABC):
642
654
  password=password,
643
655
  account=account,
644
656
  database=database,
657
+ client_session_keep_alive=True
645
658
  )
646
659
 
647
660
  def run_sql_snowflake(sql: str) -> pd.DataFrame:
@@ -890,6 +903,94 @@ class VannaBase(ABC):
890
903
  self.run_sql_is_set = True
891
904
  self.run_sql = run_sql_mysql
892
905
 
906
+ def connect_to_oracle(
907
+ self,
908
+ user: str = None,
909
+ password: str = None,
910
+ dsn: str = None,
911
+ ):
912
+
913
+ """
914
+ Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
915
+ **Example:**
916
+ ```python
917
+ vn.connect_to_oracle(
918
+ user="username",
919
+ password="password",
920
+ dsn="host:port/sid",
921
+ )
922
+ ```
923
+ Args:
924
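+ user (str): Oracle db user name. Falls back to the USER environment variable.
925
+ password (str): Oracle db user password. Falls back to the PASSWORD environment variable.
926
+ dsn (str): Oracle db connection string in the form host:port/sid. Falls back to the DSN environment variable.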
927
+ """
928
+
929
+ try:
930
+ import oracledb
931
+ except ImportError:
932
+
933
+ raise DependencyError(
934
+ "You need to install required dependencies to execute this method,"
935
+ " run command: \npip install oracledb"
936
+ )
937
+
938
+ if not dsn:
939
+ dsn = os.getenv("DSN")
940
+
941
+ if not dsn:
942
+ raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
943
+
944
+ if not user:
945
+ user = os.getenv("USER")
946
+
947
+ if not user:
948
+ raise ImproperlyConfigured("Please set your Oracle db user")
949
+
950
+ if not password:
951
+ password = os.getenv("PASSWORD")
952
+
953
+ if not password:
954
+ raise ImproperlyConfigured("Please set your Oracle db password")
955
+
956
+ conn = None
957
+
958
+ try:
959
+ conn = oracledb.connect(
960
+ user=user,
961
+ password=password,
962
+ dsn=dsn,
963
+ )
964
+ except oracledb.Error as e:
965
+ raise ValidationError(e)
966
+
967
+ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
968
+ if conn:
969
+ try:
970
+ sql = sql.rstrip()
971
+ if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
972
+ sql = sql[:-1]
973
+
974
+ cs = conn.cursor()
975
+ cs.execute(sql)
976
+ results = cs.fetchall()
977
+
978
+ # Create a pandas dataframe from the results
979
+ df = pd.DataFrame(
980
+ results, columns=[desc[0] for desc in cs.description]
981
+ )
982
+ return df
983
+
984
+ except oracledb.Error as e:
985
+ conn.rollback()
986
+ raise ValidationError(e)
987
+
988
+ except Exception as e:
989
+ conn.rollback()
990
+ raise e
991
+
992
+ self.run_sql_is_set = True
993
+ self.run_sql = run_sql_oracle
893
994
 
894
995
  def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
895
996
  """
@@ -1238,7 +1339,7 @@ class VannaBase(ABC):
1238
1339
  """
1239
1340
 
1240
1341
  if question and not sql:
1241
- raise ValidationError(f"Please also provide a SQL query")
1342
+ raise ValidationError("Please also provide a SQL query")
1242
1343
 
1243
1344
  if documentation:
1244
1345
  print("Adding documentation....")
@@ -1,5 +1,4 @@
1
1
  import json
2
- import uuid
3
2
  from typing import List
4
3
 
5
4
  import chromadb
@@ -16,17 +15,16 @@ default_ef = embedding_functions.DefaultEmbeddingFunction()
16
15
  class ChromaDB_VectorStore(VannaBase):
17
16
  def __init__(self, config=None):
18
17
  VannaBase.__init__(self, config=config)
18
+ if config is None:
19
+ config = {}
19
20
 
20
- if config is not None:
21
- path = config.get("path", ".")
22
- self.embedding_function = config.get("embedding_function", default_ef)
23
- curr_client = config.get("client", "persistent")
24
- self.n_results = config.get("n_results", 10)
25
- else:
26
- path = "."
27
- self.embedding_function = default_ef
28
- curr_client = "persistent" # defaults to persistent storage
29
- self.n_results = 10 # defaults to 10 documents
21
+ path = config.get("path", ".")
22
+ self.embedding_function = config.get("embedding_function", default_ef)
23
+ curr_client = config.get("client", "persistent")
24
+ collection_metadata = config.get("collection_metadata", None)
25
+ self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
26
+ self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
27
+ self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
30
28
 
31
29
  if curr_client == "persistent":
32
30
  self.chroma_client = chromadb.PersistentClient(
@@ -43,13 +41,19 @@ class ChromaDB_VectorStore(VannaBase):
43
41
  raise ValueError(f"Unsupported client was set in config: {curr_client}")
44
42
 
45
43
  self.documentation_collection = self.chroma_client.get_or_create_collection(
46
- name="documentation", embedding_function=self.embedding_function
44
+ name="documentation",
45
+ embedding_function=self.embedding_function,
46
+ metadata=collection_metadata,
47
47
  )
48
48
  self.ddl_collection = self.chroma_client.get_or_create_collection(
49
- name="ddl", embedding_function=self.embedding_function
49
+ name="ddl",
50
+ embedding_function=self.embedding_function,
51
+ metadata=collection_metadata,
50
52
  )
51
53
  self.sql_collection = self.chroma_client.get_or_create_collection(
52
- name="sql", embedding_function=self.embedding_function
54
+ name="sql",
55
+ embedding_function=self.embedding_function,
56
+ metadata=collection_metadata,
53
57
  )
54
58
 
55
59
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
@@ -232,7 +236,7 @@ class ChromaDB_VectorStore(VannaBase):
232
236
  return ChromaDB_VectorStore._extract_documents(
233
237
  self.sql_collection.query(
234
238
  query_texts=[question],
235
- n_results=self.n_results,
239
+ n_results=self.n_results_sql,
236
240
  )
237
241
  )
238
242
 
@@ -240,6 +244,7 @@ class ChromaDB_VectorStore(VannaBase):
240
244
  return ChromaDB_VectorStore._extract_documents(
241
245
  self.ddl_collection.query(
242
246
  query_texts=[question],
247
+ n_results=self.n_results_ddl,
243
248
  )
244
249
  )
245
250
 
@@ -247,5 +252,6 @@ class ChromaDB_VectorStore(VannaBase):
247
252
  return ChromaDB_VectorStore._extract_documents(
248
253
  self.documentation_collection.query(
249
254
  query_texts=[question],
255
+ n_results=self.n_results_documentation,
250
256
  )
251
257
  )
vanna/flask/__init__.py CHANGED
@@ -8,27 +8,47 @@ import requests
8
8
  from flask import Flask, Response, jsonify, request
9
9
 
10
10
  from .assets import css_content, html_content, js_content
11
+ from .auth import AuthInterface, NoAuth
11
12
 
12
13
 
13
14
  class Cache(ABC):
15
+ """
16
+ Define the interface for a cache that can be used to store data in a Flask app.
17
+ """
18
+
14
19
  @abstractmethod
15
20
  def generate_id(self, *args, **kwargs):
21
+ """
22
+ Generate a unique ID for the cache.
23
+ """
16
24
  pass
17
25
 
18
26
  @abstractmethod
19
27
  def get(self, id, field):
28
+ """
29
+ Get a value from the cache.
30
+ """
20
31
  pass
21
32
 
22
33
  @abstractmethod
23
34
  def get_all(self, field_list) -> list:
35
+ """
36
+ Get all values from the cache.
37
+ """
24
38
  pass
25
39
 
26
40
  @abstractmethod
27
41
  def set(self, id, field, value):
42
+ """
43
+ Set a value in the cache.
44
+ """
28
45
  pass
29
46
 
30
47
  @abstractmethod
31
48
  def delete(self, id):
49
+ """
50
+ Delete a value from the cache.
51
+ """
32
52
  pass
33
53
 
34
54
 
@@ -64,11 +84,10 @@ class MemoryCache(Cache):
64
84
  if id in self.cache:
65
85
  del self.cache[id]
66
86
 
67
-
68
87
  class VannaFlaskApp:
69
88
  flask_app = None
70
89
 
71
- def requires_cache(self, fields):
90
+ def requires_cache(self, required_fields, optional_fields=[]):
72
91
  def decorator(f):
73
92
  @wraps(f)
74
93
  def decorated(*args, **kwargs):
@@ -79,14 +98,17 @@ class VannaFlaskApp:
79
98
  if id is None:
80
99
  return jsonify({"type": "error", "error": "No id provided"})
81
100
 
82
- for field in fields:
101
+ for field in required_fields:
83
102
  if self.cache.get(id=id, field=field) is None:
84
103
  return jsonify({"type": "error", "error": f"No {field} found"})
85
104
 
86
105
  field_values = {
87
- field: self.cache.get(id=id, field=field) for field in fields
106
+ field: self.cache.get(id=id, field=field) for field in required_fields
88
107
  }
89
108
 
109
+ for field in optional_fields:
110
+ field_values[field] = self.cache.get(id=id, field=field)
111
+
90
112
  # Add the id to the field_values
91
113
  field_values["id"] = id
92
114
 
@@ -96,7 +118,21 @@ class VannaFlaskApp:
96
118
 
97
119
  return decorator
98
120
 
121
+ def requires_auth(self, f):
122
+ @wraps(f)
123
+ def decorated(*args, **kwargs):
124
+ user = self.auth.get_user(flask.request)
125
+
126
+ if not self.auth.is_logged_in(user):
127
+ return jsonify({"type": "not_logged_in", "html": self.auth.login_form()})
128
+
129
+ # Pass the user to the function
130
+ return f(*args, user=user, **kwargs)
131
+
132
+ return decorated
133
+
99
134
  def __init__(self, vn, cache: Cache = MemoryCache(),
135
+ auth: AuthInterface = NoAuth(),
100
136
  allow_llm_to_see_data=False,
101
137
  logo="https://img.vanna.ai/vanna-flask.svg",
102
138
  title="Welcome to Vanna.AI",
@@ -119,6 +155,7 @@ class VannaFlaskApp:
119
155
  Args:
120
156
  vn: The Vanna instance to interact with.
121
157
  cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
158
+ auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
122
159
  allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
123
160
  logo: The logo to display in the UI. Defaults to the Vanna logo.
124
161
  title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
@@ -140,6 +177,7 @@ class VannaFlaskApp:
140
177
  """
141
178
  self.flask_app = Flask(__name__)
142
179
  self.vn = vn
180
+ self.auth = auth
143
181
  self.cache = cache
144
182
  self.allow_llm_to_see_data = allow_llm_to_see_data
145
183
  self.logo = logo
@@ -160,32 +198,50 @@ class VannaFlaskApp:
160
198
  log = logging.getLogger("werkzeug")
161
199
  log.setLevel(logging.ERROR)
162
200
 
201
+ @self.flask_app.route("/auth/login", methods=["POST"])
202
+ def login():
203
+ return self.auth.login_handler(flask.request)
204
+
205
+ @self.flask_app.route("/auth/callback", methods=["GET"])
206
+ def callback():
207
+ return self.auth.callback_handler(flask.request)
208
+
209
+ @self.flask_app.route("/auth/logout", methods=["GET"])
210
+ def logout():
211
+ return self.auth.logout_handler(flask.request)
212
+
163
213
  @self.flask_app.route("/api/v0/get_config", methods=["GET"])
164
- def get_config():
214
+ @self.requires_auth
215
+ def get_config(user: any):
216
+ config = {
217
+ "logo": self.logo,
218
+ "title": self.title,
219
+ "subtitle": self.subtitle,
220
+ "show_training_data": self.show_training_data,
221
+ "suggested_questions": self.suggested_questions,
222
+ "sql": self.sql,
223
+ "table": self.table,
224
+ "csv_download": self.csv_download,
225
+ "chart": self.chart,
226
+ "redraw_chart": self.redraw_chart,
227
+ "auto_fix_sql": self.auto_fix_sql,
228
+ "ask_results_correct": self.ask_results_correct,
229
+ "followup_questions": self.followup_questions,
230
+ "summarization": self.summarization,
231
+ }
232
+
233
+ config = self.auth.override_config_for_user(user, config)
234
+
165
235
  return jsonify(
166
236
  {
167
237
  "type": "config",
168
- "config": {
169
- "logo": self.logo,
170
- "title": self.title,
171
- "subtitle": self.subtitle,
172
- "show_training_data": self.show_training_data,
173
- "suggested_questions": self.suggested_questions,
174
- "sql": self.sql,
175
- "table": self.table,
176
- "csv_download": self.csv_download,
177
- "chart": self.chart,
178
- "redraw_chart": self.redraw_chart,
179
- "auto_fix_sql": self.auto_fix_sql,
180
- "ask_results_correct": self.ask_results_correct,
181
- "followup_questions": self.followup_questions,
182
- "summarization": self.summarization,
183
- },
238
+ "config": config
184
239
  }
185
240
  )
186
241
 
187
242
  @self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
188
- def generate_questions():
243
+ @self.requires_auth
244
+ def generate_questions(user: any):
189
245
  # If self has an _model attribute and model=='chinook'
190
246
  if hasattr(self.vn, "_model") and self.vn._model == "chinook":
191
247
  return jsonify(
@@ -240,7 +296,8 @@ class VannaFlaskApp:
240
296
  )
241
297
 
242
298
  @self.flask_app.route("/api/v0/generate_sql", methods=["GET"])
243
- def generate_sql():
299
+ @self.requires_auth
300
+ def generate_sql(user: any):
244
301
  question = flask.request.args.get("question")
245
302
 
246
303
  if question is None:
@@ -261,8 +318,9 @@ class VannaFlaskApp:
261
318
  )
262
319
 
263
320
  @self.flask_app.route("/api/v0/run_sql", methods=["GET"])
321
+ @self.requires_auth
264
322
  @self.requires_cache(["sql"])
265
- def run_sql(id: str, sql: str):
323
+ def run_sql(user: any, id: str, sql: str):
266
324
  try:
267
325
  if not vn.run_sql_is_set:
268
326
  return jsonify(
@@ -274,7 +332,7 @@ class VannaFlaskApp:
274
332
 
275
333
  df = vn.run_sql(sql=sql)
276
334
 
277
- cache.set(id=id, field="df", value=df)
335
+ self.cache.set(id=id, field="df", value=df)
278
336
 
279
337
  return jsonify(
280
338
  {
@@ -288,8 +346,9 @@ class VannaFlaskApp:
288
346
  return jsonify({"type": "sql_error", "error": str(e)})
289
347
 
290
348
  @self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
349
+ @self.requires_auth
291
350
  @self.requires_cache(["question", "sql"])
292
- def fix_sql(id: str, question:str, sql: str):
351
+ def fix_sql(user: any, id: str, question:str, sql: str):
293
352
  error = flask.request.json.get("error")
294
353
 
295
354
  if error is None:
@@ -311,14 +370,15 @@ class VannaFlaskApp:
311
370
 
312
371
 
313
372
  @self.flask_app.route('/api/v0/update_sql', methods=['POST'])
373
+ @self.requires_auth
314
374
  @self.requires_cache([])
315
- def update_sql(id: str):
375
+ def update_sql(user: any, id: str):
316
376
  sql = flask.request.json.get('sql')
317
377
 
318
378
  if sql is None:
319
379
  return jsonify({"type": "error", "error": "No sql provided"})
320
380
 
321
- cache.set(id=id, field='sql', value=sql)
381
+ self.cache.set(id=id, field='sql', value=sql)
322
382
 
323
383
  return jsonify(
324
384
  {
@@ -328,8 +388,9 @@ class VannaFlaskApp:
328
388
  })
329
389
 
330
390
  @self.flask_app.route("/api/v0/download_csv", methods=["GET"])
391
+ @self.requires_auth
331
392
  @self.requires_cache(["df"])
332
- def download_csv(id: str, df):
393
+ def download_csv(user: any, id: str, df):
333
394
  csv = df.to_csv()
334
395
 
335
396
  return Response(
@@ -339,8 +400,9 @@ class VannaFlaskApp:
339
400
  )
340
401
 
341
402
  @self.flask_app.route("/api/v0/generate_plotly_figure", methods=["GET"])
403
+ @self.requires_auth
342
404
  @self.requires_cache(["df", "question", "sql"])
343
- def generate_plotly_figure(id: str, df, question, sql):
405
+ def generate_plotly_figure(user: any, id: str, df, question, sql):
344
406
  chart_instructions = flask.request.args.get('chart_instructions')
345
407
 
346
408
  if chart_instructions is not None:
@@ -355,7 +417,7 @@ class VannaFlaskApp:
355
417
  fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
356
418
  fig_json = fig.to_json()
357
419
 
358
- cache.set(id=id, field="fig_json", value=fig_json)
420
+ self.cache.set(id=id, field="fig_json", value=fig_json)
359
421
 
360
422
  return jsonify(
361
423
  {
@@ -373,7 +435,8 @@ class VannaFlaskApp:
373
435
  return jsonify({"type": "error", "error": str(e)})
374
436
 
375
437
  @self.flask_app.route("/api/v0/get_training_data", methods=["GET"])
376
- def get_training_data():
438
+ @self.requires_auth
439
+ def get_training_data(user: any):
377
440
  df = vn.get_training_data()
378
441
 
379
442
  if df is None or len(df) == 0:
@@ -393,7 +456,8 @@ class VannaFlaskApp:
393
456
  )
394
457
 
395
458
  @self.flask_app.route("/api/v0/remove_training_data", methods=["POST"])
396
- def remove_training_data():
459
+ @self.requires_auth
460
+ def remove_training_data(user: any):
397
461
  # Get id from the JSON body
398
462
  id = flask.request.json.get("id")
399
463
 
@@ -408,7 +472,8 @@ class VannaFlaskApp:
408
472
  )
409
473
 
410
474
  @self.flask_app.route("/api/v0/train", methods=["POST"])
411
- def add_training_data():
475
+ @self.requires_auth
476
+ def add_training_data(user: any):
412
477
  question = flask.request.json.get("question")
413
478
  sql = flask.request.json.get("sql")
414
479
  ddl = flask.request.json.get("ddl")
@@ -425,8 +490,9 @@ class VannaFlaskApp:
425
490
  return jsonify({"type": "error", "error": str(e)})
426
491
 
427
492
  @self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
493
+ @self.requires_auth
428
494
  @self.requires_cache(["df", "question", "sql"])
429
- def generate_followup_questions(id: str, df, question, sql):
495
+ def generate_followup_questions(user: any, id: str, df, question, sql):
430
496
  if self.allow_llm_to_see_data:
431
497
  followup_questions = vn.generate_followup_questions(
432
498
  question=question, sql=sql, df=df
@@ -434,7 +500,7 @@ class VannaFlaskApp:
434
500
  if followup_questions is not None and len(followup_questions) > 5:
435
501
  followup_questions = followup_questions[:5]
436
502
 
437
- cache.set(id=id, field="followup_questions", value=followup_questions)
503
+ self.cache.set(id=id, field="followup_questions", value=followup_questions)
438
504
 
439
505
  return jsonify(
440
506
  {
@@ -445,7 +511,7 @@ class VannaFlaskApp:
445
511
  }
446
512
  )
447
513
  else:
448
- cache.set(id=id, field="followup_questions", value=[])
514
+ self.cache.set(id=id, field="followup_questions", value=[])
449
515
  return jsonify(
450
516
  {
451
517
  "type": "question_list",
@@ -456,10 +522,14 @@ class VannaFlaskApp:
456
522
  )
457
523
 
458
524
  @self.flask_app.route("/api/v0/generate_summary", methods=["GET"])
525
+ @self.requires_auth
459
526
  @self.requires_cache(["df", "question"])
460
- def generate_summary(id: str, df, question):
527
+ def generate_summary(user: any, id: str, df, question):
461
528
  if self.allow_llm_to_see_data:
462
529
  summary = vn.generate_summary(question=question, df=df)
530
+
531
+ self.cache.set(id=id, field="summary", value=summary)
532
+
463
533
  return jsonify(
464
534
  {
465
535
  "type": "text",
@@ -477,10 +547,12 @@ class VannaFlaskApp:
477
547
  )
478
548
 
479
549
  @self.flask_app.route("/api/v0/load_question", methods=["GET"])
550
+ @self.requires_auth
480
551
  @self.requires_cache(
481
- ["question", "sql", "df", "fig_json"]
552
+ ["question", "sql", "df", "fig_json"],
553
+ optional_fields=["summary"]
482
554
  )
483
- def load_question(id: str, question, sql, df, fig_json):
555
+ def load_question(user: any, id: str, question, sql, df, fig_json, summary):
484
556
  try:
485
557
  return jsonify(
486
558
  {
@@ -488,8 +560,9 @@ class VannaFlaskApp:
488
560
  "id": id,
489
561
  "question": question,
490
562
  "sql": sql,
491
- "df": df.head(10).to_json(orient="records"),
563
+ "df": df.head(10).to_json(orient="records", date_format="iso"),
492
564
  "fig": fig_json,
565
+ "summary": summary,
493
566
  }
494
567
  )
495
568
 
@@ -497,7 +570,8 @@ class VannaFlaskApp:
497
570
  return jsonify({"type": "error", "error": str(e)})
498
571
 
499
572
  @self.flask_app.route("/api/v0/get_question_history", methods=["GET"])
500
- def get_question_history():
573
+ @self.requires_auth
574
+ def get_question_history(user: any):
501
575
  return jsonify(
502
576
  {
503
577
  "type": "question_history",
@@ -525,7 +599,7 @@ class VannaFlaskApp:
525
599
  # Proxy the /vanna.svg file to the remote server
526
600
  @self.flask_app.route("/vanna.svg")
527
601
  def proxy_vanna_svg():
528
- remote_url = f"https://vanna.ai/img/vanna.svg"
602
+ remote_url = "https://vanna.ai/img/vanna.svg"
529
603
  response = requests.get(remote_url, stream=True)
530
604
 
531
605
  # Check if the request to the remote URL was successful