vanna 0.4.3__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/base/base.py CHANGED
@@ -62,6 +62,7 @@ import plotly
62
62
  import plotly.express as px
63
63
  import plotly.graph_objects as go
64
64
  import requests
65
+ import sqlparse
65
66
 
66
67
  from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
67
68
  from ..types import TrainingPlan, TrainingPlanItem
@@ -70,14 +71,25 @@ from ..utils import validate_config_path
70
71
 
71
72
  class VannaBase(ABC):
72
73
  def __init__(self, config=None):
74
+ if config is None:
75
+ config = {}
76
+
73
77
  self.config = config
74
78
  self.run_sql_is_set = False
75
79
  self.static_documentation = ""
80
+ self.dialect = self.config.get("dialect", "SQL")
81
+ self.language = self.config.get("language", None)
76
82
 
77
- def log(self, message: str):
83
+ def log(self, message: str, title: str = "Info"):
78
84
  print(message)
79
85
 
80
- def generate_sql(self, question: str, **kwargs) -> str:
86
+ def _response_language(self) -> str:
87
+ if self.language is None:
88
+ return ""
89
+
90
+ return f"Respond in the {self.language} language."
91
+
92
+ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
81
93
  """
82
94
  Example:
83
95
  ```python
@@ -99,6 +111,7 @@ class VannaBase(ABC):
99
111
 
100
112
  Args:
101
113
  question (str): The question to generate a SQL query for.
114
+ allow_llm_to_see_data (bool): Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).
102
115
 
103
116
  Returns:
104
117
  str: The SQL query that answers the question.
@@ -118,45 +131,129 @@ class VannaBase(ABC):
118
131
  doc_list=doc_list,
119
132
  **kwargs,
120
133
  )
121
- self.log(prompt)
134
+ self.log(title="SQL Prompt", message=prompt)
122
135
  llm_response = self.submit_prompt(prompt, **kwargs)
123
- self.log(llm_response)
136
+ self.log(title="LLM Response", message=llm_response)
137
+
138
+ if 'intermediate_sql' in llm_response:
139
+ if not allow_llm_to_see_data:
140
+ return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
141
+
142
+ if allow_llm_to_see_data:
143
+ intermediate_sql = self.extract_sql(llm_response)
144
+
145
+ try:
146
+ self.log(title="Running Intermediate SQL", message=intermediate_sql)
147
+ df = self.run_sql(intermediate_sql)
148
+
149
+ prompt = self.get_sql_prompt(
150
+ initial_prompt=initial_prompt,
151
+ question=question,
152
+ question_sql_list=question_sql_list,
153
+ ddl_list=ddl_list,
154
+ doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
155
+ **kwargs,
156
+ )
157
+ self.log(title="Final SQL Prompt", message=prompt)
158
+ llm_response = self.submit_prompt(prompt, **kwargs)
159
+ self.log(title="LLM Response", message=llm_response)
160
+ except Exception as e:
161
+ return f"Error running intermediate SQL: {e}"
162
+
163
+
124
164
  return self.extract_sql(llm_response)
125
165
 
126
166
  def extract_sql(self, llm_response: str) -> str:
127
- # If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ;
128
- sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
129
- if sql:
130
- self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
131
- return sql.group(0)
132
- # If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
133
- sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
134
- if sql:
135
- self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
136
- )
137
- return sql.group(0)
167
+ """
168
+ Example:
169
+ ```python
170
+ vn.extract_sql("Here's the SQL query in a code block: ```sql\nSELECT * FROM customers\n```")
171
+ ```
138
172
 
139
- # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
140
- sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
141
- if sql:
142
- self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
143
- return sql.group(1)
173
+ Extracts the SQL query from the LLM response. This is useful in case the LLM response contains other information besides the SQL query.
174
+ Override this function if your LLM responses need custom extraction logic.
144
175
 
145
- sql = re.search(r"```(.*)```", llm_response, re.DOTALL)
146
- if sql:
147
- self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
148
- return sql.group(1)
176
+ Args:
177
+ llm_response (str): The LLM response.
178
+
179
+ Returns:
180
+ str: The extracted SQL query.
181
+ """
182
+
183
+ # If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
184
+ sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)
185
+ if sqls:
186
+ sql = sqls[-1]
187
+ self.log(title="Extracted SQL", message=f"{sql}")
188
+ return sql
189
+
190
+ # If the llm_response is not markdown formatted, extract last sql by finding select and ; in the response
191
+ sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
192
+ if sqls:
193
+ sql = sqls[-1]
194
+ self.log(title="Extracted SQL", message=f"{sql}")
195
+ return sql
196
+
197
+ # If the llm_response contains a markdown code block, with or without the sql tag, extract the last sql from it
198
+ sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
199
+ if sqls:
200
+ sql = sqls[-1]
201
+ self.log(title="Extracted SQL", message=f"{sql}")
202
+ return sql
203
+
204
+ sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
205
+ if sqls:
206
+ sql = sqls[-1]
207
+ self.log(title="Extracted SQL", message=f"{sql}")
208
+ return sql
149
209
 
150
210
  return llm_response
151
211
 
152
212
  def is_sql_valid(self, sql: str) -> bool:
153
- # This is a check to see the SQL is valid and should be run
154
- # This simple function just checks if the SQL contains a SELECT statement
213
+ """
214
+ Example:
215
+ ```python
216
+ vn.is_sql_valid("SELECT * FROM customers")
217
+ ```
218
+ Checks if the SQL query is valid. This is usually used to check if we should run the SQL query or not.
219
+ By default it checks if the SQL query is a SELECT statement. You can override this method to enable running other types of SQL queries.
220
+
221
+ Args:
222
+ sql (str): The SQL query to check.
223
+
224
+ Returns:
225
+ bool: True if the SQL query is valid, False otherwise.
226
+ """
227
+
228
+ parsed = sqlparse.parse(sql)
229
+
230
+ for statement in parsed:
231
+ if statement.get_type() == 'SELECT':
232
+ return True
155
233
 
156
- if "SELECT" in sql.upper():
234
+ return False
235
+
236
+ def should_generate_chart(self, df: pd.DataFrame) -> bool:
237
+ """
238
+ Example:
239
+ ```python
240
+ vn.should_generate_chart(df)
241
+ ```
242
+
243
+ Checks if a chart should be generated for the given DataFrame. By default, it checks if the DataFrame has more than one row and has numerical columns.
244
+ You can override this method to customize the logic for generating charts.
245
+
246
+ Args:
247
+ df (pd.DataFrame): The DataFrame to check.
248
+
249
+ Returns:
250
+ bool: True if a chart should be generated, False otherwise.
251
+ """
252
+
253
+ if len(df) > 1 and df.select_dtypes(include=['number']).shape[1] > 0:
157
254
  return True
158
- else:
159
- return False
255
+
256
+ return False
160
257
 
161
258
  def generate_followup_questions(
162
259
  self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
@@ -184,7 +281,8 @@ class VannaBase(ABC):
184
281
  f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
185
282
  ),
186
283
  self.user_message(
187
- f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
284
+ f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
285
+ self._response_language()
188
286
  ),
189
287
  ]
190
288
 
@@ -228,7 +326,8 @@ class VannaBase(ABC):
228
326
  f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
229
327
  ),
230
328
  self.user_message(
231
- "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
329
+ "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." +
330
+ self._response_language()
232
331
  ),
233
332
  ]
234
333
 
@@ -375,7 +474,7 @@ class VannaBase(ABC):
375
474
  self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
376
475
  ) -> str:
377
476
  if len(ddl_list) > 0:
378
- initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
477
+ initial_prompt += "\n===Tables \n"
379
478
 
380
479
  for ddl in ddl_list:
381
480
  if (
@@ -394,7 +493,7 @@ class VannaBase(ABC):
394
493
  max_tokens: int = 14000,
395
494
  ) -> str:
396
495
  if len(documentation_list) > 0:
397
- initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
496
+ initial_prompt += "\n===Additional Context \n\n"
398
497
 
399
498
  for documentation in documentation_list:
400
499
  if (
@@ -410,7 +509,7 @@ class VannaBase(ABC):
410
509
  self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
411
510
  ) -> str:
412
511
  if len(sql_list) > 0:
413
- initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
512
+ initial_prompt += "\n===Question-SQL Pairs\n\n"
414
513
 
415
514
  for question in sql_list:
416
515
  if (
@@ -456,7 +555,8 @@ class VannaBase(ABC):
456
555
  """
457
556
 
458
557
  if initial_prompt is None:
459
- initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
558
+ initial_prompt = f"You are a {self.dialect} expert. "
559
+ "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
460
560
 
461
561
  initial_prompt = self.add_ddl_to_prompt(
462
562
  initial_prompt, ddl_list, max_tokens=14000
@@ -469,6 +569,15 @@ class VannaBase(ABC):
469
569
  initial_prompt, doc_list, max_tokens=14000
470
570
  )
471
571
 
572
+ initial_prompt += (
573
+ "===Response Guidelines \n"
574
+ "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
575
+ "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
576
+ "3. If the provided context is insufficient, please explain why it can't be generated. \n"
577
+ "4. Please use the most relevant table(s). \n"
578
+ "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
579
+ )
580
+
472
581
  message_log = [self.system_message(initial_prompt)]
473
582
 
474
583
  for example in question_sql_list:
@@ -676,7 +785,7 @@ class VannaBase(ABC):
676
785
 
677
786
  return df
678
787
 
679
- self.static_documentation = "This is a Snowflake database"
788
+ self.dialect = "Snowflake SQL"
680
789
  self.run_sql = run_sql_snowflake
681
790
  self.run_sql_is_set = True
682
791
 
@@ -710,7 +819,7 @@ class VannaBase(ABC):
710
819
  def run_sql_sqlite(sql: str):
711
820
  return pd.read_sql_query(sql, conn)
712
821
 
713
- self.static_documentation = "This is a SQLite database"
822
+ self.dialect = "SQLite"
714
823
  self.run_sql = run_sql_sqlite
715
824
  self.run_sql_is_set = True
716
825
 
@@ -815,7 +924,7 @@ class VannaBase(ABC):
815
924
  conn.rollback()
816
925
  raise e
817
926
 
818
- self.static_documentation = "This is a Postgres database"
927
+ self.dialect = "PostgreSQL"
819
928
  self.run_sql_is_set = True
820
929
  self.run_sql = run_sql_postgres
821
930
 
@@ -1078,7 +1187,7 @@ class VannaBase(ABC):
1078
1187
  raise errors
1079
1188
  return None
1080
1189
 
1081
- self.static_documentation = "This is a BigQuery database"
1190
+ self.dialect = "BigQuery SQL"
1082
1191
  self.run_sql_is_set = True
1083
1192
  self.run_sql = run_sql_bigquery
1084
1193
 
@@ -1127,7 +1236,7 @@ class VannaBase(ABC):
1127
1236
  def run_sql_duckdb(sql: str):
1128
1237
  return conn.query(sql).to_df()
1129
1238
 
1130
- self.static_documentation = "This is a DuckDB database"
1239
+ self.dialect = "DuckDB SQL"
1131
1240
  self.run_sql = run_sql_duckdb
1132
1241
  self.run_sql_is_set = True
1133
1242
 
@@ -1174,7 +1283,7 @@ class VannaBase(ABC):
1174
1283
 
1175
1284
  raise Exception("Couldn't run sql")
1176
1285
 
1177
- self.static_documentation = "This is a Microsoft SQL Server database"
1286
+ self.dialect = "T-SQL / Microsoft SQL Server"
1178
1287
  self.run_sql = run_sql_mssql
1179
1288
  self.run_sql_is_set = True
1180
1289
 
vanna/flask/__init__.py CHANGED
@@ -1,4 +1,6 @@
1
+ import json
1
2
  import logging
3
+ import sys
2
4
  import uuid
3
5
  from abc import ABC, abstractmethod
4
6
  from functools import wraps
@@ -6,6 +8,7 @@ from functools import wraps
6
8
  import flask
7
9
  import requests
8
10
  from flask import Flask, Response, jsonify, request
11
+ from flask_sock import Sock
9
12
 
10
13
  from .assets import css_content, html_content, js_content
11
14
  from .auth import AuthInterface, NoAuth
@@ -133,6 +136,7 @@ class VannaFlaskApp:
133
136
 
134
137
  def __init__(self, vn, cache: Cache = MemoryCache(),
135
138
  auth: AuthInterface = NoAuth(),
139
+ debug=True,
136
140
  allow_llm_to_see_data=False,
137
141
  logo="https://img.vanna.ai/vanna-flask.svg",
138
142
  title="Welcome to Vanna.AI",
@@ -156,6 +160,7 @@ class VannaFlaskApp:
156
160
  vn: The Vanna instance to interact with.
157
161
  cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
158
162
  auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
163
+ debug: Show the debug console. Defaults to True.
159
164
  allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
160
165
  logo: The logo to display in the UI. Defaults to the Vanna logo.
161
166
  title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
@@ -176,7 +181,10 @@ class VannaFlaskApp:
176
181
  None
177
182
  """
178
183
  self.flask_app = Flask(__name__)
184
+ self.sock = Sock(self.flask_app)
185
+ self.ws_clients = []
179
186
  self.vn = vn
187
+ self.debug = debug
180
188
  self.auth = auth
181
189
  self.cache = cache
182
190
  self.allow_llm_to_see_data = allow_llm_to_see_data
@@ -198,6 +206,16 @@ class VannaFlaskApp:
198
206
  log = logging.getLogger("werkzeug")
199
207
  log.setLevel(logging.ERROR)
200
208
 
209
+ if "google.colab" in sys.modules:
210
+ self.debug = False
211
+ print("Google Colab doesn't support running websocket servers. Disabling debug mode.")
212
+
213
+ if self.debug:
214
+ def log(message, title="Info"):
215
+ [ws.send(json.dumps({'message': message, 'title': title})) for ws in self.ws_clients]
216
+
217
+ self.vn.log = log
218
+
201
219
  @self.flask_app.route("/auth/login", methods=["POST"])
202
220
  def login():
203
221
  return self.auth.login_handler(flask.request)
@@ -214,6 +232,7 @@ class VannaFlaskApp:
214
232
  @self.requires_auth
215
233
  def get_config(user: any):
216
234
  config = {
235
+ "debug": self.debug,
217
236
  "logo": self.logo,
218
237
  "title": self.title,
219
238
  "subtitle": self.subtitle,
@@ -304,18 +323,27 @@ class VannaFlaskApp:
304
323
  return jsonify({"type": "error", "error": "No question provided"})
305
324
 
306
325
  id = self.cache.generate_id(question=question)
307
- sql = vn.generate_sql(question=question)
326
+ sql = vn.generate_sql(question=question, allow_llm_to_see_data=self.allow_llm_to_see_data)
308
327
 
309
328
  self.cache.set(id=id, field="question", value=question)
310
329
  self.cache.set(id=id, field="sql", value=sql)
311
330
 
312
- return jsonify(
313
- {
314
- "type": "sql",
315
- "id": id,
316
- "text": sql,
317
- }
318
- )
331
+ if vn.is_sql_valid(sql=sql):
332
+ return jsonify(
333
+ {
334
+ "type": "sql",
335
+ "id": id,
336
+ "text": sql,
337
+ }
338
+ )
339
+ else:
340
+ return jsonify(
341
+ {
342
+ "type": "text",
343
+ "id": id,
344
+ "text": sql,
345
+ }
346
+ )
319
347
 
320
348
  @self.flask_app.route("/api/v0/run_sql", methods=["GET"])
321
349
  @self.requires_auth
@@ -339,6 +367,7 @@ class VannaFlaskApp:
339
367
  "type": "df",
340
368
  "id": id,
341
369
  "df": df.head(10).to_json(orient='records', date_format='iso'),
370
+ "should_generate_chart": self.chart and vn.should_generate_chart(df),
342
371
  }
343
372
  )
344
373
 
@@ -619,6 +648,18 @@ class VannaFlaskApp:
619
648
  else:
620
649
  return "Error fetching file from remote server", response.status_code
621
650
 
651
+ if self.debug:
652
+ @self.sock.route("/api/v0/log")
653
+ def sock_log(ws):
654
+ self.ws_clients.append(ws)
655
+
656
+ try:
657
+ while True:
658
+ message = ws.receive() # This example just reads and ignores to keep the socket open
659
+ finally:
660
+ self.ws_clients.remove(ws)
661
+
662
+
622
663
  @self.flask_app.route("/", defaults={"path": ""})
623
664
  @self.flask_app.route("/<path:path>")
624
665
  def hello(path: str):
@@ -651,4 +692,4 @@ class VannaFlaskApp:
651
692
  print("Your app is running at:")
652
693
  print("http://localhost:8084")
653
694
 
654
- self.flask_app.run(host="0.0.0.0", port=8084, debug=False)
695
+ self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)