vanna 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,26 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class VannaAdvanced(ABC):
5
+ def __init__(self, config=None):
6
+ self.config = config
7
+
8
+ @abstractmethod
9
+ def get_function(self, question: str, additional_data: dict = {}) -> dict:
10
+ pass
11
+
12
+ @abstractmethod
13
+ def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
14
+ pass
15
+
16
+ @abstractmethod
17
+ def update_function(self, old_function_name: str, updated_function: dict) -> bool:
18
+ pass
19
+
20
+ @abstractmethod
21
+ def delete_function(self, function_name: str) -> bool:
22
+ pass
23
+
24
+ @abstractmethod
25
+ def get_all_functions(self) -> list:
26
+ pass
vanna/base/base.py CHANGED
@@ -79,9 +79,10 @@ class VannaBase(ABC):
79
79
  self.static_documentation = ""
80
80
  self.dialect = self.config.get("dialect", "SQL")
81
81
  self.language = self.config.get("language", None)
82
+ self.max_tokens = self.config.get("max_tokens", 14000)
82
83
 
83
84
  def log(self, message: str, title: str = "Info"):
84
- print(message)
85
+ print(f"{title}: {message}")
85
86
 
86
87
  def _response_language(self) -> str:
87
88
  if self.language is None:
@@ -555,18 +556,18 @@ class VannaBase(ABC):
555
556
  """
556
557
 
557
558
  if initial_prompt is None:
558
- initial_prompt = f"You are a {self.dialect} expert. "
559
+ initial_prompt = f"You are a {self.dialect} expert. " + \
559
560
  "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
560
561
 
561
562
  initial_prompt = self.add_ddl_to_prompt(
562
- initial_prompt, ddl_list, max_tokens=14000
563
+ initial_prompt, ddl_list, max_tokens=self.max_tokens
563
564
  )
564
565
 
565
566
  if self.static_documentation != "":
566
567
  doc_list.append(self.static_documentation)
567
568
 
568
569
  initial_prompt = self.add_documentation_to_prompt(
569
- initial_prompt, doc_list, max_tokens=14000
570
+ initial_prompt, doc_list, max_tokens=self.max_tokens
570
571
  )
571
572
 
572
573
  initial_prompt += (
@@ -603,15 +604,15 @@ class VannaBase(ABC):
603
604
  initial_prompt = f"The user initially asked the question: '{question}': \n\n"
604
605
 
605
606
  initial_prompt = self.add_ddl_to_prompt(
606
- initial_prompt, ddl_list, max_tokens=14000
607
+ initial_prompt, ddl_list, max_tokens=self.max_tokens
607
608
  )
608
609
 
609
610
  initial_prompt = self.add_documentation_to_prompt(
610
- initial_prompt, doc_list, max_tokens=14000
611
+ initial_prompt, doc_list, max_tokens=self.max_tokens
611
612
  )
612
613
 
613
614
  initial_prompt = self.add_sql_to_prompt(
614
- initial_prompt, question_sql_list, max_tokens=14000
615
+ initial_prompt, question_sql_list, max_tokens=self.max_tokens
615
616
  )
616
617
 
617
618
  message_log = [self.system_message(initial_prompt)]
@@ -1022,11 +1023,11 @@ class VannaBase(ABC):
1022
1023
  ):
1023
1024
 
1024
1025
  try:
1025
- from clickhouse_driver import connect
1026
+ import clickhouse_connect
1026
1027
  except ImportError:
1027
1028
  raise DependencyError(
1028
1029
  "You need to install required dependencies to execute this method,"
1029
- " run command: \npip install clickhouse-driver"
1030
+ " run command: \npip install clickhouse_connect"
1030
1031
  )
1031
1032
 
1032
1033
  if not host:
@@ -1062,12 +1063,13 @@ class VannaBase(ABC):
1062
1063
  conn = None
1063
1064
 
1064
1065
  try:
1065
- conn = connect(host=host,
1066
- user=user,
1067
- password=password,
1068
- database=dbname,
1069
- port=port,
1070
- )
1066
+ conn = clickhouse_connect.get_client(
1067
+ host=host,
1068
+ port=port,
1069
+ username=user,
1070
+ password=password,
1071
+ database=dbname,
1072
+ )
1071
1073
  print(conn)
1072
1074
  except Exception as e:
1073
1075
  raise ValidationError(e)
@@ -1075,19 +1077,16 @@ class VannaBase(ABC):
1075
1077
  def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
1076
1078
  if conn:
1077
1079
  try:
1078
- cs = conn.cursor()
1079
- cs.execute(sql)
1080
- results = cs.fetchall()
1080
+ result = conn.query(sql)
1081
+ results = result.result_rows
1081
1082
 
1082
- # Create a pandas dataframe from the results
1083
- df = pd.DataFrame(
1084
- results, columns=[desc[0] for desc in cs.description]
1085
- )
1086
- return df
1083
+ # Create a pandas dataframe from the results
1084
+ df = pd.DataFrame(results, columns=result.column_names)
1085
+ return df
1087
1086
 
1088
1087
  except Exception as e:
1089
1088
  raise e
1090
-
1089
+
1091
1090
  self.run_sql_is_set = True
1092
1091
  self.run_sql = run_sql_clickhouse
1093
1092
 
@@ -1597,6 +1596,7 @@ class VannaBase(ABC):
1597
1596
  print_results: bool = True,
1598
1597
  auto_train: bool = True,
1599
1598
  visualize: bool = True, # if False, will not generate plotly code
1599
+ allow_llm_to_see_data: bool = False,
1600
1600
  ) -> Union[
1601
1601
  Tuple[
1602
1602
  Union[str, None],
@@ -1627,7 +1627,7 @@ class VannaBase(ABC):
1627
1627
  question = input("Enter a question: ")
1628
1628
 
1629
1629
  try:
1630
- sql = self.generate_sql(question=question)
1630
+ sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
1631
1631
  except Exception as e:
1632
1632
  print(e)
1633
1633
  return None, None, None
@@ -1701,7 +1701,7 @@ class VannaBase(ABC):
1701
1701
  return None
1702
1702
  else:
1703
1703
  return sql, None, None
1704
- return sql, df, None
1704
+ return sql, df, fig
1705
1705
 
1706
1706
  def train(
1707
1707
  self,
vanna/flask/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import logging
3
+ import os
3
4
  import sys
4
5
  import uuid
5
6
  from abc import ABC, abstractmethod
@@ -7,7 +8,7 @@ from functools import wraps
7
8
 
8
9
  import flask
9
10
  import requests
10
- from flask import Flask, Response, jsonify, request
11
+ from flask import Flask, Response, jsonify, request, send_from_directory
11
12
  from flask_sock import Sock
12
13
 
13
14
  from .assets import css_content, html_content, js_content
@@ -151,7 +152,10 @@ class VannaFlaskApp:
151
152
  auto_fix_sql=True,
152
153
  ask_results_correct=True,
153
154
  followup_questions=True,
154
- summarization=True
155
+ summarization=True,
156
+ function_generation=True,
157
+ index_html_path=None,
158
+ assets_folder=None,
155
159
  ):
156
160
  """
157
161
  Expose a Flask app that can be used to interact with a Vanna instance.
@@ -176,6 +180,8 @@ class VannaFlaskApp:
176
180
  ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
177
181
  followup_questions: Whether to show followup questions. Defaults to True.
178
182
  summarization: Whether to show summarization. Defaults to True.
183
+ index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
184
+ assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
179
185
 
180
186
  Returns:
181
187
  None
@@ -202,6 +208,9 @@ class VannaFlaskApp:
202
208
  self.ask_results_correct = ask_results_correct
203
209
  self.followup_questions = followup_questions
204
210
  self.summarization = summarization
211
+ self.function_generation = function_generation and hasattr(vn, "get_function")
212
+ self.index_html_path = index_html_path
213
+ self.assets_folder = assets_folder
205
214
 
206
215
  log = logging.getLogger("werkzeug")
207
216
  log.setLevel(logging.ERROR)
@@ -247,6 +256,7 @@ class VannaFlaskApp:
247
256
  "ask_results_correct": self.ask_results_correct,
248
257
  "followup_questions": self.followup_questions,
249
258
  "summarization": self.summarization,
259
+ "function_generation": self.function_generation,
250
260
  }
251
261
 
252
262
  config = self.auth.override_config_for_user(user, config)
@@ -345,6 +355,56 @@ class VannaFlaskApp:
345
355
  }
346
356
  )
347
357
 
358
+ @self.flask_app.route("/api/v0/get_function", methods=["GET"])
359
+ @self.requires_auth
360
+ def get_function(user: any):
361
+ question = flask.request.args.get("question")
362
+
363
+ if question is None:
364
+ return jsonify({"type": "error", "error": "No question provided"})
365
+
366
+ if not hasattr(vn, "get_function"):
367
+ return jsonify({"type": "error", "error": "This setup does not support function generation."})
368
+
369
+ id = self.cache.generate_id(question=question)
370
+ function = vn.get_function(question=question)
371
+
372
+ if function is None:
373
+ return jsonify({"type": "error", "error": "No function found"})
374
+
375
+ if 'instantiated_sql' not in function:
376
+ self.vn.log(f"No instantiated SQL found for {question} in {function}")
377
+ return jsonify({"type": "error", "error": "No instantiated SQL found"})
378
+
379
+ self.cache.set(id=id, field="question", value=question)
380
+ self.cache.set(id=id, field="sql", value=function['instantiated_sql'])
381
+
382
+ if 'instantiated_post_processing_code' in function and function['instantiated_post_processing_code'] is not None and len(function['instantiated_post_processing_code']) > 0:
383
+ self.cache.set(id=id, field="plotly_code", value=function['instantiated_post_processing_code'])
384
+
385
+ return jsonify(
386
+ {
387
+ "type": "function",
388
+ "id": id,
389
+ "function": function,
390
+ }
391
+ )
392
+
393
+ @self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
394
+ @self.requires_auth
395
+ def get_all_functions(user: any):
396
+ if not hasattr(vn, "get_all_functions"):
397
+ return jsonify({"type": "error", "error": "This setup does not support function generation."})
398
+
399
+ functions = vn.get_all_functions()
400
+
401
+ return jsonify(
402
+ {
403
+ "type": "functions",
404
+ "functions": functions,
405
+ }
406
+ )
407
+
348
408
  @self.flask_app.route("/api/v0/run_sql", methods=["GET"])
349
409
  @self.requires_auth
350
410
  @self.requires_cache(["sql"])
@@ -438,11 +498,18 @@ class VannaFlaskApp:
438
498
  question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
439
499
 
440
500
  try:
441
- code = vn.generate_plotly_code(
442
- question=question,
443
- sql=sql,
444
- df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
445
- )
501
+ # If chart_instructions is not set then attempt to retrieve the code from the cache
502
+ if chart_instructions is None or len(chart_instructions) == 0:
503
+ code = self.cache.get(id=id, field="plotly_code")
504
+
505
+ if code is None:
506
+ code = vn.generate_plotly_code(
507
+ question=question,
508
+ sql=sql,
509
+ df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
510
+ )
511
+ self.cache.set(id=id, field="plotly_code", value=code)
512
+
446
513
  fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
447
514
  fig_json = fig.to_json()
448
515
 
@@ -518,6 +585,46 @@ class VannaFlaskApp:
518
585
  print("TRAINING ERROR", e)
519
586
  return jsonify({"type": "error", "error": str(e)})
520
587
 
588
+ @self.flask_app.route("/api/v0/create_function", methods=["GET"])
589
+ @self.requires_auth
590
+ @self.requires_cache(["question", "sql"])
591
+ def create_function(user: any, id: str, question: str, sql: str):
592
+ plotly_code = self.cache.get(id=id, field="plotly_code")
593
+
594
+ if plotly_code is None:
595
+ plotly_code = ""
596
+
597
+ function_data = self.vn.create_function(question=question, sql=sql, plotly_code=plotly_code)
598
+
599
+ return jsonify(
600
+ {
601
+ "type": "function_template",
602
+ "id": id,
603
+ "function_template": function_data,
604
+ }
605
+ )
606
+
607
+ @self.flask_app.route("/api/v0/update_function", methods=["POST"])
608
+ @self.requires_auth
609
+ def update_function(user: any):
610
+ old_function_name = flask.request.json.get("old_function_name")
611
+ updated_function = flask.request.json.get("updated_function")
612
+
613
+ print("old_function_name", old_function_name)
614
+ print("updated_function", updated_function)
615
+
616
+ updated = vn.update_function(old_function_name=old_function_name, updated_function=updated_function)
617
+
618
+ return jsonify({"success": updated})
619
+
620
+ @self.flask_app.route("/api/v0/delete_function", methods=["POST"])
621
+ @self.requires_auth
622
+ def delete_function(user: any):
623
+ function_name = flask.request.json.get("function_name")
624
+
625
+ return jsonify({"success": vn.delete_function(function_name=function_name)})
626
+
627
+
521
628
  @self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
522
629
  @self.requires_auth
523
630
  @self.requires_cache(["df", "question", "sql"])
@@ -578,8 +685,8 @@ class VannaFlaskApp:
578
685
  @self.flask_app.route("/api/v0/load_question", methods=["GET"])
579
686
  @self.requires_auth
580
687
  @self.requires_cache(
581
- ["question", "sql", "df", "fig_json"],
582
- optional_fields=["summary"]
688
+ ["question", "sql", "df"],
689
+ optional_fields=["summary", "fig_json"]
583
690
  )
584
691
  def load_question(user: any, id: str, question, sql, df, fig_json, summary):
585
692
  try:
@@ -616,6 +723,9 @@ class VannaFlaskApp:
616
723
 
617
724
  @self.flask_app.route("/assets/<path:filename>")
618
725
  def proxy_assets(filename):
726
+ if self.assets_folder:
727
+ return send_from_directory(self.assets_folder, filename)
728
+
619
729
  if ".css" in filename:
620
730
  return Response(css_content, mimetype="text/css")
621
731
 
@@ -663,6 +773,10 @@ class VannaFlaskApp:
663
773
  @self.flask_app.route("/", defaults={"path": ""})
664
774
  @self.flask_app.route("/<path:path>")
665
775
  def hello(path: str):
776
+ if self.index_html_path:
777
+ directory = os.path.dirname(self.index_html_path)
778
+ filename = os.path.basename(self.index_html_path)
779
+ return send_from_directory(directory=directory, path=filename)
666
780
  return html_content
667
781
 
668
782
  def run(self, *args, **kwargs):
@@ -692,4 +806,4 @@ class VannaFlaskApp:
692
806
  print("Your app is running at:")
693
807
  print("http://localhost:8084")
694
808
 
695
- self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)
809
+ self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)