vanna 0.5.5__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/advanced/__init__.py +26 -0
- vanna/base/base.py +24 -24
- vanna/flask/__init__.py +124 -10
- vanna/flask/assets.py +36 -16
- vanna/qdrant/qdrant.py +3 -1
- vanna/vannadb/vannadb_vector.py +179 -1
- vanna/vllm/vllm.py +16 -1
- {vanna-0.5.5.dist-info → vanna-0.6.0.dist-info}/METADATA +2 -2
- {vanna-0.5.5.dist-info → vanna-0.6.0.dist-info}/RECORD +10 -9
- {vanna-0.5.5.dist-info → vanna-0.6.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class VannaAdvanced(ABC):
|
|
5
|
+
def __init__(self, config=None):
|
|
6
|
+
self.config = config
|
|
7
|
+
|
|
8
|
+
@abstractmethod
|
|
9
|
+
def get_function(self, question: str, additional_data: dict = {}) -> dict:
|
|
10
|
+
pass
|
|
11
|
+
|
|
12
|
+
@abstractmethod
|
|
13
|
+
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
@abstractmethod
|
|
17
|
+
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
@abstractmethod
|
|
21
|
+
def delete_function(self, function_name: str) -> bool:
|
|
22
|
+
pass
|
|
23
|
+
|
|
24
|
+
@abstractmethod
|
|
25
|
+
def get_all_functions(self) -> list:
|
|
26
|
+
pass
|
vanna/base/base.py
CHANGED
|
@@ -79,9 +79,10 @@ class VannaBase(ABC):
|
|
|
79
79
|
self.static_documentation = ""
|
|
80
80
|
self.dialect = self.config.get("dialect", "SQL")
|
|
81
81
|
self.language = self.config.get("language", None)
|
|
82
|
+
self.max_tokens = self.config.get("max_tokens", 14000)
|
|
82
83
|
|
|
83
84
|
def log(self, message: str, title: str = "Info"):
|
|
84
|
-
print(message)
|
|
85
|
+
print(f"{title}: {message}")
|
|
85
86
|
|
|
86
87
|
def _response_language(self) -> str:
|
|
87
88
|
if self.language is None:
|
|
@@ -559,14 +560,14 @@ class VannaBase(ABC):
|
|
|
559
560
|
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
|
560
561
|
|
|
561
562
|
initial_prompt = self.add_ddl_to_prompt(
|
|
562
|
-
initial_prompt, ddl_list, max_tokens=
|
|
563
|
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
563
564
|
)
|
|
564
565
|
|
|
565
566
|
if self.static_documentation != "":
|
|
566
567
|
doc_list.append(self.static_documentation)
|
|
567
568
|
|
|
568
569
|
initial_prompt = self.add_documentation_to_prompt(
|
|
569
|
-
initial_prompt, doc_list, max_tokens=
|
|
570
|
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
|
570
571
|
)
|
|
571
572
|
|
|
572
573
|
initial_prompt += (
|
|
@@ -603,15 +604,15 @@ class VannaBase(ABC):
|
|
|
603
604
|
initial_prompt = f"The user initially asked the question: '{question}': \n\n"
|
|
604
605
|
|
|
605
606
|
initial_prompt = self.add_ddl_to_prompt(
|
|
606
|
-
initial_prompt, ddl_list, max_tokens=
|
|
607
|
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
607
608
|
)
|
|
608
609
|
|
|
609
610
|
initial_prompt = self.add_documentation_to_prompt(
|
|
610
|
-
initial_prompt, doc_list, max_tokens=
|
|
611
|
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
|
611
612
|
)
|
|
612
613
|
|
|
613
614
|
initial_prompt = self.add_sql_to_prompt(
|
|
614
|
-
initial_prompt, question_sql_list, max_tokens=
|
|
615
|
+
initial_prompt, question_sql_list, max_tokens=self.max_tokens
|
|
615
616
|
)
|
|
616
617
|
|
|
617
618
|
message_log = [self.system_message(initial_prompt)]
|
|
@@ -1022,11 +1023,11 @@ class VannaBase(ABC):
|
|
|
1022
1023
|
):
|
|
1023
1024
|
|
|
1024
1025
|
try:
|
|
1025
|
-
|
|
1026
|
+
import clickhouse_connect
|
|
1026
1027
|
except ImportError:
|
|
1027
1028
|
raise DependencyError(
|
|
1028
1029
|
"You need to install required dependencies to execute this method,"
|
|
1029
|
-
" run command: \npip install
|
|
1030
|
+
" run command: \npip install clickhouse_connect"
|
|
1030
1031
|
)
|
|
1031
1032
|
|
|
1032
1033
|
if not host:
|
|
@@ -1062,12 +1063,13 @@ class VannaBase(ABC):
|
|
|
1062
1063
|
conn = None
|
|
1063
1064
|
|
|
1064
1065
|
try:
|
|
1065
|
-
conn =
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1066
|
+
conn = clickhouse_connect.get_client(
|
|
1067
|
+
host=host,
|
|
1068
|
+
port=port,
|
|
1069
|
+
username=user,
|
|
1070
|
+
password=password,
|
|
1071
|
+
database=dbname,
|
|
1072
|
+
)
|
|
1071
1073
|
print(conn)
|
|
1072
1074
|
except Exception as e:
|
|
1073
1075
|
raise ValidationError(e)
|
|
@@ -1075,19 +1077,16 @@ class VannaBase(ABC):
|
|
|
1075
1077
|
def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
|
|
1076
1078
|
if conn:
|
|
1077
1079
|
try:
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
results = cs.fetchall()
|
|
1080
|
+
result = conn.query(sql)
|
|
1081
|
+
results = result.result_rows
|
|
1081
1082
|
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
)
|
|
1086
|
-
return df
|
|
1083
|
+
# Create a pandas dataframe from the results
|
|
1084
|
+
df = pd.DataFrame(results, columns=result.column_names)
|
|
1085
|
+
return df
|
|
1087
1086
|
|
|
1088
1087
|
except Exception as e:
|
|
1089
1088
|
raise e
|
|
1090
|
-
|
|
1089
|
+
|
|
1091
1090
|
self.run_sql_is_set = True
|
|
1092
1091
|
self.run_sql = run_sql_clickhouse
|
|
1093
1092
|
|
|
@@ -1597,6 +1596,7 @@ class VannaBase(ABC):
|
|
|
1597
1596
|
print_results: bool = True,
|
|
1598
1597
|
auto_train: bool = True,
|
|
1599
1598
|
visualize: bool = True, # if False, will not generate plotly code
|
|
1599
|
+
allow_llm_to_see_data: bool = False,
|
|
1600
1600
|
) -> Union[
|
|
1601
1601
|
Tuple[
|
|
1602
1602
|
Union[str, None],
|
|
@@ -1627,7 +1627,7 @@ class VannaBase(ABC):
|
|
|
1627
1627
|
question = input("Enter a question: ")
|
|
1628
1628
|
|
|
1629
1629
|
try:
|
|
1630
|
-
sql = self.generate_sql(question=question)
|
|
1630
|
+
sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
|
|
1631
1631
|
except Exception as e:
|
|
1632
1632
|
print(e)
|
|
1633
1633
|
return None, None, None
|
vanna/flask/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import os
|
|
3
4
|
import sys
|
|
4
5
|
import uuid
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
@@ -7,7 +8,7 @@ from functools import wraps
|
|
|
7
8
|
|
|
8
9
|
import flask
|
|
9
10
|
import requests
|
|
10
|
-
from flask import Flask, Response, jsonify, request
|
|
11
|
+
from flask import Flask, Response, jsonify, request, send_from_directory
|
|
11
12
|
from flask_sock import Sock
|
|
12
13
|
|
|
13
14
|
from .assets import css_content, html_content, js_content
|
|
@@ -151,7 +152,10 @@ class VannaFlaskApp:
|
|
|
151
152
|
auto_fix_sql=True,
|
|
152
153
|
ask_results_correct=True,
|
|
153
154
|
followup_questions=True,
|
|
154
|
-
summarization=True
|
|
155
|
+
summarization=True,
|
|
156
|
+
function_generation=True,
|
|
157
|
+
index_html_path=None,
|
|
158
|
+
assets_folder=None,
|
|
155
159
|
):
|
|
156
160
|
"""
|
|
157
161
|
Expose a Flask app that can be used to interact with a Vanna instance.
|
|
@@ -176,6 +180,8 @@ class VannaFlaskApp:
|
|
|
176
180
|
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
|
|
177
181
|
followup_questions: Whether to show followup questions. Defaults to True.
|
|
178
182
|
summarization: Whether to show summarization. Defaults to True.
|
|
183
|
+
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
|
|
184
|
+
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
|
|
179
185
|
|
|
180
186
|
Returns:
|
|
181
187
|
None
|
|
@@ -202,6 +208,9 @@ class VannaFlaskApp:
|
|
|
202
208
|
self.ask_results_correct = ask_results_correct
|
|
203
209
|
self.followup_questions = followup_questions
|
|
204
210
|
self.summarization = summarization
|
|
211
|
+
self.function_generation = function_generation and hasattr(vn, "get_function")
|
|
212
|
+
self.index_html_path = index_html_path
|
|
213
|
+
self.assets_folder = assets_folder
|
|
205
214
|
|
|
206
215
|
log = logging.getLogger("werkzeug")
|
|
207
216
|
log.setLevel(logging.ERROR)
|
|
@@ -247,6 +256,7 @@ class VannaFlaskApp:
|
|
|
247
256
|
"ask_results_correct": self.ask_results_correct,
|
|
248
257
|
"followup_questions": self.followup_questions,
|
|
249
258
|
"summarization": self.summarization,
|
|
259
|
+
"function_generation": self.function_generation,
|
|
250
260
|
}
|
|
251
261
|
|
|
252
262
|
config = self.auth.override_config_for_user(user, config)
|
|
@@ -345,6 +355,56 @@ class VannaFlaskApp:
|
|
|
345
355
|
}
|
|
346
356
|
)
|
|
347
357
|
|
|
358
|
+
@self.flask_app.route("/api/v0/get_function", methods=["GET"])
|
|
359
|
+
@self.requires_auth
|
|
360
|
+
def get_function(user: any):
|
|
361
|
+
question = flask.request.args.get("question")
|
|
362
|
+
|
|
363
|
+
if question is None:
|
|
364
|
+
return jsonify({"type": "error", "error": "No question provided"})
|
|
365
|
+
|
|
366
|
+
if not hasattr(vn, "get_function"):
|
|
367
|
+
return jsonify({"type": "error", "error": "This setup does not support function generation."})
|
|
368
|
+
|
|
369
|
+
id = self.cache.generate_id(question=question)
|
|
370
|
+
function = vn.get_function(question=question)
|
|
371
|
+
|
|
372
|
+
if function is None:
|
|
373
|
+
return jsonify({"type": "error", "error": "No function found"})
|
|
374
|
+
|
|
375
|
+
if 'instantiated_sql' not in function:
|
|
376
|
+
self.vn.log(f"No instantiated SQL found for {question} in {function}")
|
|
377
|
+
return jsonify({"type": "error", "error": "No instantiated SQL found"})
|
|
378
|
+
|
|
379
|
+
self.cache.set(id=id, field="question", value=question)
|
|
380
|
+
self.cache.set(id=id, field="sql", value=function['instantiated_sql'])
|
|
381
|
+
|
|
382
|
+
if 'instantiated_post_processing_code' in function and function['instantiated_post_processing_code'] is not None and len(function['instantiated_post_processing_code']) > 0:
|
|
383
|
+
self.cache.set(id=id, field="plotly_code", value=function['instantiated_post_processing_code'])
|
|
384
|
+
|
|
385
|
+
return jsonify(
|
|
386
|
+
{
|
|
387
|
+
"type": "function",
|
|
388
|
+
"id": id,
|
|
389
|
+
"function": function,
|
|
390
|
+
}
|
|
391
|
+
)
|
|
392
|
+
|
|
393
|
+
@self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
|
|
394
|
+
@self.requires_auth
|
|
395
|
+
def get_all_functions(user: any):
|
|
396
|
+
if not hasattr(vn, "get_all_functions"):
|
|
397
|
+
return jsonify({"type": "error", "error": "This setup does not support function generation."})
|
|
398
|
+
|
|
399
|
+
functions = vn.get_all_functions()
|
|
400
|
+
|
|
401
|
+
return jsonify(
|
|
402
|
+
{
|
|
403
|
+
"type": "functions",
|
|
404
|
+
"functions": functions,
|
|
405
|
+
}
|
|
406
|
+
)
|
|
407
|
+
|
|
348
408
|
@self.flask_app.route("/api/v0/run_sql", methods=["GET"])
|
|
349
409
|
@self.requires_auth
|
|
350
410
|
@self.requires_cache(["sql"])
|
|
@@ -438,11 +498,18 @@ class VannaFlaskApp:
|
|
|
438
498
|
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
|
|
439
499
|
|
|
440
500
|
try:
|
|
441
|
-
code
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
501
|
+
# If chart_instructions is not set then attempt to retrieve the code from the cache
|
|
502
|
+
if chart_instructions is None or len(chart_instructions) == 0:
|
|
503
|
+
code = self.cache.get(id=id, field="plotly_code")
|
|
504
|
+
|
|
505
|
+
if code is None:
|
|
506
|
+
code = vn.generate_plotly_code(
|
|
507
|
+
question=question,
|
|
508
|
+
sql=sql,
|
|
509
|
+
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
|
|
510
|
+
)
|
|
511
|
+
self.cache.set(id=id, field="plotly_code", value=code)
|
|
512
|
+
|
|
446
513
|
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
|
|
447
514
|
fig_json = fig.to_json()
|
|
448
515
|
|
|
@@ -518,6 +585,46 @@ class VannaFlaskApp:
|
|
|
518
585
|
print("TRAINING ERROR", e)
|
|
519
586
|
return jsonify({"type": "error", "error": str(e)})
|
|
520
587
|
|
|
588
|
+
@self.flask_app.route("/api/v0/create_function", methods=["GET"])
|
|
589
|
+
@self.requires_auth
|
|
590
|
+
@self.requires_cache(["question", "sql"])
|
|
591
|
+
def create_function(user: any, id: str, question: str, sql: str):
|
|
592
|
+
plotly_code = self.cache.get(id=id, field="plotly_code")
|
|
593
|
+
|
|
594
|
+
if plotly_code is None:
|
|
595
|
+
plotly_code = ""
|
|
596
|
+
|
|
597
|
+
function_data = self.vn.create_function(question=question, sql=sql, plotly_code=plotly_code)
|
|
598
|
+
|
|
599
|
+
return jsonify(
|
|
600
|
+
{
|
|
601
|
+
"type": "function_template",
|
|
602
|
+
"id": id,
|
|
603
|
+
"function_template": function_data,
|
|
604
|
+
}
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
@self.flask_app.route("/api/v0/update_function", methods=["POST"])
|
|
608
|
+
@self.requires_auth
|
|
609
|
+
def update_function(user: any):
|
|
610
|
+
old_function_name = flask.request.json.get("old_function_name")
|
|
611
|
+
updated_function = flask.request.json.get("updated_function")
|
|
612
|
+
|
|
613
|
+
print("old_function_name", old_function_name)
|
|
614
|
+
print("updated_function", updated_function)
|
|
615
|
+
|
|
616
|
+
updated = vn.update_function(old_function_name=old_function_name, updated_function=updated_function)
|
|
617
|
+
|
|
618
|
+
return jsonify({"success": updated})
|
|
619
|
+
|
|
620
|
+
@self.flask_app.route("/api/v0/delete_function", methods=["POST"])
|
|
621
|
+
@self.requires_auth
|
|
622
|
+
def delete_function(user: any):
|
|
623
|
+
function_name = flask.request.json.get("function_name")
|
|
624
|
+
|
|
625
|
+
return jsonify({"success": vn.delete_function(function_name=function_name)})
|
|
626
|
+
|
|
627
|
+
|
|
521
628
|
@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
|
|
522
629
|
@self.requires_auth
|
|
523
630
|
@self.requires_cache(["df", "question", "sql"])
|
|
@@ -578,8 +685,8 @@ class VannaFlaskApp:
|
|
|
578
685
|
@self.flask_app.route("/api/v0/load_question", methods=["GET"])
|
|
579
686
|
@self.requires_auth
|
|
580
687
|
@self.requires_cache(
|
|
581
|
-
["question", "sql", "df"
|
|
582
|
-
optional_fields=["summary"]
|
|
688
|
+
["question", "sql", "df"],
|
|
689
|
+
optional_fields=["summary", "fig_json"]
|
|
583
690
|
)
|
|
584
691
|
def load_question(user: any, id: str, question, sql, df, fig_json, summary):
|
|
585
692
|
try:
|
|
@@ -616,6 +723,9 @@ class VannaFlaskApp:
|
|
|
616
723
|
|
|
617
724
|
@self.flask_app.route("/assets/<path:filename>")
|
|
618
725
|
def proxy_assets(filename):
|
|
726
|
+
if self.assets_folder:
|
|
727
|
+
return send_from_directory(self.assets_folder, filename)
|
|
728
|
+
|
|
619
729
|
if ".css" in filename:
|
|
620
730
|
return Response(css_content, mimetype="text/css")
|
|
621
731
|
|
|
@@ -663,6 +773,10 @@ class VannaFlaskApp:
|
|
|
663
773
|
@self.flask_app.route("/", defaults={"path": ""})
|
|
664
774
|
@self.flask_app.route("/<path:path>")
|
|
665
775
|
def hello(path: str):
|
|
776
|
+
if self.index_html_path:
|
|
777
|
+
directory = os.path.dirname(self.index_html_path)
|
|
778
|
+
filename = os.path.basename(self.index_html_path)
|
|
779
|
+
return send_from_directory(directory=directory, path=filename)
|
|
666
780
|
return html_content
|
|
667
781
|
|
|
668
782
|
def run(self, *args, **kwargs):
|
|
@@ -692,4 +806,4 @@ class VannaFlaskApp:
|
|
|
692
806
|
print("Your app is running at:")
|
|
693
807
|
print("http://localhost:8084")
|
|
694
808
|
|
|
695
|
-
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)
|
|
809
|
+
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
|