vanna 0.2.1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
1
1
  import re
2
2
  from typing import List
3
+
3
4
  import pandas as pd
4
5
  from zhipuai import ZhipuAI
6
+
5
7
  from ..base import VannaBase
6
- import re
7
- from typing import List
8
- import pandas as pd
8
+
9
9
 
10
10
  class ZhipuAI_Chat(VannaBase):
11
11
  def __init__(self, config=None):
vanna/ZhipuAI/__init__.py CHANGED
@@ -0,0 +1,2 @@
1
+ from .ZhipuAI_Chat import ZhipuAI_Chat
2
+ from .ZhipuAI_embeddings import ZhipuAI_Embeddings, ZhipuAIEmbeddingFunction
@@ -0,0 +1 @@
1
+ from .anthropic_chat import Anthropic_Chat
@@ -0,0 +1,78 @@
1
+ import os
2
+
3
+ import anthropic
4
+
5
+ from ..base import VannaBase
6
+
7
+
8
class Anthropic_Chat(VannaBase):
    """Chat backend that submits prompts to Anthropic's Messages API.

    Sampling settings default to ``temperature=0.7`` and ``max_tokens=500``
    and can be overridden through ``config``. The model name is read from
    ``self.config["model"]`` each time a prompt is submitted.
    """

    def __init__(self, client=None, config=None):
        """Set up sampling defaults and the Anthropic client.

        Args:
            client: A ready-made client object. When given it is used as-is
                and no client is constructed here.
            config: Optional dict. Keys read here: ``temperature``,
                ``max_tokens`` and ``api_key``. ``submit_prompt`` later reads
                ``model`` from ``self.config`` (presumably stored by
                ``VannaBase.__init__`` -- confirm against the base class).
        """
        VannaBase.__init__(self, config=config)

        # Default parameters - can be overridden using config. They are
        # assigned before any early return so that submit_prompt always finds
        # them, whichever way the client was supplied.
        self.temperature = 0.7
        self.max_tokens = 500

        if config is not None:
            if "temperature" in config:
                self.temperature = config["temperature"]

            if "max_tokens" in config:
                self.max_tokens = config["max_tokens"]

        if client is not None:
            self.client = client
            return

        if config is not None and "api_key" in config:
            self.client = anthropic.Anthropic(api_key=config["api_key"])
        else:
            # No explicit key: fall back to the environment so that
            # self.client is never left undefined.
            self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat entry with role ``system``."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat entry with role ``user``."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat entry with role ``assistant``."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a message log to the configured model and return its reply.

        Args:
            prompt: Sequence of ``{"role": ..., "content": ...}`` dicts, as
                built by ``system_message`` / ``user_message`` /
                ``assistant_message``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``text`` of the first content item in the response.

        Raises:
            Exception: If ``prompt`` is None or empty, or if no ``model`` is
                present in ``self.config``.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # The model name is mandatory for the API call below; fail with a
        # readable message instead of a TypeError/KeyError on the lookup.
        if self.config is None or "model" not in self.config:
            raise Exception("No model configured: set 'model' in the config")

        model = self.config["model"]

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)
        print(f"Using model {model} for {num_tokens} tokens (approx)")

        # Claude requires the system message to be a single field
        # https://docs.anthropic.com/claude/reference/messages_post
        # Every system entry is kept (joined by blank lines) rather than only
        # the last one; all other entries go to the API in their original order.
        system_parts = []
        chat_messages = []
        for prompt_message in prompt:
            role = prompt_message["role"]
            if role == "system":
                system_parts.append(prompt_message["content"])
            else:
                chat_messages.append(
                    {"role": role, "content": prompt_message["content"]}
                )

        response = self.client.messages.create(
            model=model,
            messages=chat_messages,
            system="\n\n".join(system_parts),
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )

        return response.content[0].text
vanna/base/base.py CHANGED
@@ -103,10 +103,15 @@ class VannaBase(ABC):
103
103
  Returns:
104
104
  str: The SQL query that answers the question.
105
105
  """
106
+ if self.config is not None:
107
+ initial_prompt = self.config.get("initial_prompt", None)
108
+ else:
109
+ initial_prompt = None
106
110
  question_sql_list = self.get_similar_question_sql(question, **kwargs)
107
111
  ddl_list = self.get_related_ddl(question, **kwargs)
108
112
  doc_list = self.get_related_documentation(question, **kwargs)
109
113
  prompt = self.get_sql_prompt(
114
+ initial_prompt=initial_prompt,
110
115
  question=question,
111
116
  question_sql_list=question_sql_list,
112
117
  ddl_list=ddl_list,
@@ -405,6 +410,7 @@ class VannaBase(ABC):
405
410
 
406
411
  def get_sql_prompt(
407
412
  self,
413
+ initial_prompt : str,
408
414
  question: str,
409
415
  question_sql_list: list,
410
416
  ddl_list: list,
@@ -434,7 +440,9 @@ class VannaBase(ABC):
434
440
  Returns:
435
441
  any: The prompt for the LLM to generate SQL.
436
442
  """
437
- initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
443
+
444
+ if initial_prompt is None:
445
+ initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"
438
446
 
439
447
  initial_prompt = self.add_ddl_to_prompt(
440
448
  initial_prompt, ddl_list, max_tokens=14000
@@ -796,6 +804,91 @@ class VannaBase(ABC):
796
804
  self.run_sql_is_set = True
797
805
  self.run_sql = run_sql_postgres
798
806
 
807
+
808
def connect_to_mysql(
    self,
    host: str = None,
    dbname: str = None,
    user: str = None,
    password: str = None,
    port: int = None,
):
    """Connect to a MySQL database and install a ``run_sql`` helper on self.

    Any argument left empty is read from the environment instead: ``HOST``,
    ``DATABASE``, ``USER``, ``PASSWORD`` and ``PORT`` respectively.

    Args:
        host: MySQL server host.
        dbname: Name of the database to use.
        user: User name to connect as.
        password: Password for ``user``.
        port: TCP port of the server. A numeric string (as obtained from the
            environment) is converted to ``int``.

    Raises:
        DependencyError: If PyMySQL is not installed.
        ImproperlyConfigured: If a setting is missing from both the arguments
            and the environment, or the port is not an integer.
        ValidationError: If the connection attempt fails.
    """
    try:
        import pymysql.cursors
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install PyMySQL"
        )

    if not host:
        host = os.getenv("HOST")

    if not host:
        raise ImproperlyConfigured("Please set your MySQL host")

    if not dbname:
        dbname = os.getenv("DATABASE")

    if not dbname:
        raise ImproperlyConfigured("Please set your MySQL database")

    if not user:
        user = os.getenv("USER")

    if not user:
        raise ImproperlyConfigured("Please set your MySQL user")

    if not password:
        password = os.getenv("PASSWORD")

    if not password:
        raise ImproperlyConfigured("Please set your MySQL password")

    if not port:
        port = os.getenv("PORT")

    if not port:
        raise ImproperlyConfigured("Please set your MySQL port")

    # os.getenv returns a string, but the driver needs an integer port.
    try:
        port = int(port)
    except (TypeError, ValueError) as e:
        raise ImproperlyConfigured("Your MySQL port must be an integer") from e

    try:
        conn = pymysql.connect(
            host=host,
            user=user,
            password=password,
            database=dbname,
            port=port,
            cursorclass=pymysql.cursors.DictCursor,
        )
    except pymysql.Error as e:
        raise ValidationError(e) from e

    def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
        """Execute ``sql`` on the open connection and return the rows as a DataFrame."""
        if not conn:
            return None

        try:
            # The context manager closes the cursor even when execution fails.
            with conn.cursor() as cs:
                cs.execute(sql)
                results = cs.fetchall()

                # Create a pandas dataframe from the results
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cs.description]
                )

        except pymysql.Error as e:
            conn.rollback()
            raise ValidationError(e) from e

        except Exception:
            conn.rollback()
            raise

    self.run_sql_is_set = True
    self.run_sql = run_sql_mysql
890
+
891
+
799
892
  def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
800
893
  """
801
894
  Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
@@ -0,0 +1 @@
1
+ from .chromadb_vector import ChromaDB_VectorStore
vanna/flask/__init__.py CHANGED
@@ -75,7 +75,9 @@ class VannaFlaskApp:
75
75
  id = request.args.get("id")
76
76
 
77
77
  if id is None:
78
- return jsonify({"type": "error", "error": "No id provided"})
78
+ id = request.json.get("id")
79
+ if id is None:
80
+ return jsonify({"type": "error", "error": "No id provided"})
79
81
 
80
82
  for field in fields:
81
83
  if self.cache.get(id=id, field=field) is None:
@@ -94,15 +96,94 @@ class VannaFlaskApp:
94
96
 
95
97
  return decorator
96
98
 
97
- def __init__(self, vn, cache: Cache = MemoryCache(), allow_llm_to_see_data=False):
99
+ def __init__(self, vn, cache: Cache = MemoryCache(),
100
+ allow_llm_to_see_data=False,
101
+ logo="https://img.vanna.ai/vanna-flask.svg",
102
+ title="Welcome to Vanna.AI",
103
+ subtitle="Your AI-powered copilot for SQL queries.",
104
+ show_training_data=True,
105
+ suggested_questions=True,
106
+ sql=True,
107
+ table=True,
108
+ csv_download=True,
109
+ chart=True,
110
+ redraw_chart=True,
111
+ auto_fix_sql=True,
112
+ ask_results_correct=True,
113
+ followup_questions=True,
114
+ summarization=True
115
+ ):
116
+ """
117
+ Expose a Flask app that can be used to interact with a Vanna instance.
118
+
119
+ Args:
120
+ vn: The Vanna instance to interact with.
121
+ cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
122
+ allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
123
+ logo: The logo to display in the UI. Defaults to the Vanna logo.
124
+ title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
125
+ subtitle: The subtitle to display in the UI. Defaults to "Your AI-powered copilot for SQL queries.".
126
+ show_training_data: Whether to show the training data in the UI. Defaults to True.
127
+ suggested_questions: Whether to show suggested questions in the UI. Defaults to True.
128
+ sql: Whether to show the SQL input in the UI. Defaults to True.
129
+ table: Whether to show the table output in the UI. Defaults to True.
130
+ csv_download: Whether to allow downloading the table output as a CSV file. Defaults to True.
131
+ chart: Whether to show the chart output in the UI. Defaults to True.
132
+ redraw_chart: Whether to allow redrawing the chart. Defaults to True.
133
+ auto_fix_sql: Whether to allow auto-fixing SQL errors. Defaults to True.
134
+ ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
135
+ followup_questions: Whether to show followup questions. Defaults to True.
136
+ summarization: Whether to show summarization. Defaults to True.
137
+
138
+ Returns:
139
+ None
140
+ """
98
141
  self.flask_app = Flask(__name__)
99
142
  self.vn = vn
100
143
  self.cache = cache
101
144
  self.allow_llm_to_see_data = allow_llm_to_see_data
145
+ self.logo = logo
146
+ self.title = title
147
+ self.subtitle = subtitle
148
+ self.show_training_data = show_training_data
149
+ self.suggested_questions = suggested_questions
150
+ self.sql = sql
151
+ self.table = table
152
+ self.csv_download = csv_download
153
+ self.chart = chart
154
+ self.redraw_chart = redraw_chart
155
+ self.auto_fix_sql = auto_fix_sql
156
+ self.ask_results_correct = ask_results_correct
157
+ self.followup_questions = followup_questions
158
+ self.summarization = summarization
102
159
 
103
160
  log = logging.getLogger("werkzeug")
104
161
  log.setLevel(logging.ERROR)
105
162
 
163
@self.flask_app.route("/api/v0/get_config", methods=["GET"])
def get_config():
    """Return the UI options stored on the app instance as a JSON payload."""
    # Attribute names double as the keys of the "config" object, in this order.
    option_names = (
        "logo",
        "title",
        "subtitle",
        "show_training_data",
        "suggested_questions",
        "sql",
        "table",
        "csv_download",
        "chart",
        "redraw_chart",
        "auto_fix_sql",
        "ask_results_correct",
        "followup_questions",
        "summarization",
    )
    ui_config = {name: getattr(self, name) for name in option_names}

    return jsonify({"type": "config", "config": ui_config})
186
+
106
187
  @self.flask_app.route("/api/v0/generate_questions", methods=["GET"])
107
188
  def generate_questions():
108
189
  # If self has an _model attribute and model=='chinook'
@@ -199,12 +280,52 @@ class VannaFlaskApp:
199
280
  {
200
281
  "type": "df",
201
282
  "id": id,
202
- "df": df.head(10).to_json(orient="records"),
283
+ "df": df.head(10).to_json(orient='records', date_format='iso'),
203
284
  }
204
285
  )
205
286
 
206
287
  except Exception as e:
207
- return jsonify({"type": "error", "error": str(e)})
288
+ return jsonify({"type": "sql_error", "error": str(e)})
289
+
290
@self.flask_app.route("/api/v0/fix_sql", methods=["POST"])
@self.requires_cache(["question", "sql"])
def fix_sql(id: str, question: str, sql: str):
    """Ask the LLM to rewrite a cached SQL statement that failed to run.

    The error text is read from the ``error`` key of the JSON request body.
    The rewritten SQL replaces the cached ``sql`` entry for ``id`` and is
    returned as a ``{"type": "sql", ...}`` payload. If no error text is
    supplied, a ``{"type": "error", ...}`` payload is returned instead.
    """
    # silent=True yields None for a missing or non-JSON body, so the handler
    # can answer with its own JSON error instead of failing on ``.get``.
    payload = flask.request.get_json(silent=True)
    error = payload.get("error") if isinstance(payload, dict) else None

    if error is None:
        return jsonify({"type": "error", "error": "No error provided"})

    question = f"I have an error: {error}\n\nHere is the SQL I tried to run: {sql}\n\nThis is the question I was trying to answer: {question}\n\nCan you rewrite the SQL to fix the error?"

    fixed_sql = self.vn.generate_sql(question=question)

    self.cache.set(id=id, field="sql", value=fixed_sql)

    return jsonify(
        {
            "type": "sql",
            "id": id,
            "text": fixed_sql,
        }
    )
311
+
312
+
313
@self.flask_app.route("/api/v0/update_sql", methods=["POST"])
@self.requires_cache([])
def update_sql(id: str):
    """Replace the cached SQL for ``id`` with the SQL sent by the client.

    The new statement is read from the ``sql`` key of the JSON request body,
    stored under the ``sql`` cache field and returned as a
    ``{"type": "sql", ...}`` payload. If no SQL is supplied, a
    ``{"type": "error", ...}`` payload is returned instead.
    """
    # silent=True yields None for a missing or non-JSON body, so the handler
    # can answer with its own JSON error instead of failing on ``.get``.
    payload = flask.request.get_json(silent=True)
    sql = payload.get("sql") if isinstance(payload, dict) else None

    if sql is None:
        return jsonify({"type": "error", "error": "No sql provided"})

    self.cache.set(id=id, field="sql", value=sql)

    return jsonify(
        {
            "type": "sql",
            "id": id,
            "text": sql,
        }
    )
208
329
 
209
330
  @self.flask_app.route("/api/v0/download_csv", methods=["GET"])
210
331
  @self.requires_cache(["df"])
@@ -220,6 +341,11 @@ class VannaFlaskApp:
220
341
  @self.flask_app.route("/api/v0/generate_plotly_figure", methods=["GET"])
221
342
  @self.requires_cache(["df", "question", "sql"])
222
343
  def generate_plotly_figure(id: str, df, question, sql):
344
+ chart_instructions = flask.request.args.get('chart_instructions')
345
+
346
+ if chart_instructions is not None:
347
+ question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
348
+
223
349
  try:
224
350
  code = vn.generate_plotly_code(
225
351
  question=question,
@@ -319,6 +445,7 @@ class VannaFlaskApp:
319
445
  }
320
446
  )
321
447
  else:
448
+ cache.set(id=id, field="followup_questions", value=[])
322
449
  return jsonify(
323
450
  {
324
451
  "type": "question_list",
@@ -351,9 +478,9 @@ class VannaFlaskApp:
351
478
 
352
479
  @self.flask_app.route("/api/v0/load_question", methods=["GET"])
353
480
  @self.requires_cache(
354
- ["question", "sql", "df", "fig_json", "followup_questions"]
481
+ ["question", "sql", "df", "fig_json"]
355
482
  )
356
- def load_question(id: str, question, sql, df, fig_json, followup_questions):
483
+ def load_question(id: str, question, sql, df, fig_json):
357
484
  try:
358
485
  return jsonify(
359
486
  {
@@ -363,7 +490,6 @@ class VannaFlaskApp:
363
490
  "sql": sql,
364
491
  "df": df.head(10).to_json(orient="records"),
365
492
  "fig": fig_json,
366
- "followup_questions": followup_questions,
367
493
  }
368
494
  )
369
495
 
@@ -424,16 +550,31 @@ class VannaFlaskApp:
424
550
  def hello(path: str):
425
551
  return html_content
426
552
 
427
- def run(self):
428
- try:
429
- from google.colab import output
553
def run(self, *args, **kwargs):
    """
    Run the Flask app.

    With any arguments, they are forwarded unchanged to Flask's run method.
    With none, the app is served on host 0.0.0.0, port 8084, and the URL to
    open is printed first (the Google Colab proxy URL when it can be
    obtained, otherwise http://localhost:8084).

    Args:
        *args: Arguments to pass to Flask's run method.
        **kwargs: Keyword arguments to pass to Flask's run method.

    Returns:
        None
    """
    if args or kwargs:
        self.flask_app.run(*args, **kwargs)
        return

    try:
        from google.colab import output

        output.serve_kernel_port_as_window(8084)
        from google.colab.output import eval_js

        app_url = eval_js("google.colab.kernel.proxyPort(8084)")
    except Exception:
        # Not running in Colab (or the proxy lookup failed): best-effort
        # fallback to the local address. KeyboardInterrupt and SystemExit
        # are deliberately not swallowed here.
        app_url = "http://localhost:8084"

    print("Your app is running at:")
    print(app_url)

    self.flask_app.run(host="0.0.0.0", port=8084, debug=False)