vanna 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/advanced/__init__.py +26 -0
- vanna/base/base.py +25 -24
- vanna/flask/__init__.py +124 -10
- vanna/flask/assets.py +36 -16
- vanna/milvus/__init__.py +1 -0
- vanna/milvus/milvus_vector.py +305 -0
- vanna/qdrant/qdrant.py +12 -14
- vanna/vannadb/vannadb_vector.py +179 -1
- vanna/vllm/vllm.py +16 -1
- {vanna-0.5.5.dist-info → vanna-0.6.1.dist-info}/METADATA +5 -2
- {vanna-0.5.5.dist-info → vanna-0.6.1.dist-info}/RECORD +12 -9
- {vanna-0.5.5.dist-info → vanna-0.6.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class VannaAdvanced(ABC):
    """Abstract interface for Vanna backends that manage reusable "functions".

    A function is created from a question, its SQL and (optionally) Plotly
    code, and can later be looked up, listed, updated and deleted. This class
    only declares the contract; concrete subclasses provide storage.
    """

    def __init__(self, config=None):
        # Keep the raw configuration; interpretation is left to subclasses.
        self.config = config

    @abstractmethod
    def get_function(self, question: str, additional_data: "dict | None" = None) -> dict:
        """Return the stored function for *question*, as a dict.

        Args:
            question: The natural-language question to look up.
            additional_data: Optional extra values supplied by the caller;
                their use is implementation-defined. Defaults to ``None``
                instead of a shared mutable ``{}`` default.

        Returns:
            A dict describing the function. NOTE(review): callers appear to
            treat ``None`` as "no function found" — confirm per implementation.
        """
        pass

    @abstractmethod
    def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
        """Build a function template from a question, its SQL and Plotly code.

        Args:
            question: The question the SQL answers.
            sql: The SQL query for the question.
            plotly_code: Plotly code for the chart; may be an empty string.
            **kwargs: Implementation-specific options.

        Returns:
            A dict describing the created function template.
        """
        pass

    @abstractmethod
    def update_function(self, old_function_name: str, updated_function: dict) -> bool:
        """Replace the function named *old_function_name* with *updated_function*.

        Returns:
            True on success, False otherwise.
        """
        pass

    @abstractmethod
    def delete_function(self, function_name: str) -> bool:
        """Delete the function named *function_name*.

        Returns:
            True on success, False otherwise.
        """
        pass

    @abstractmethod
    def get_all_functions(self) -> list:
        """Return a list of all stored functions."""
        pass
|
vanna/base/base.py
CHANGED
|
@@ -79,9 +79,10 @@ class VannaBase(ABC):
|
|
|
79
79
|
self.static_documentation = ""
|
|
80
80
|
self.dialect = self.config.get("dialect", "SQL")
|
|
81
81
|
self.language = self.config.get("language", None)
|
|
82
|
+
self.max_tokens = self.config.get("max_tokens", 14000)
|
|
82
83
|
|
|
83
84
|
def log(self, message: str, title: str = "Info"):
    """Write *message* to stdout, prefixed by *title* and a colon."""
    line = "{}: {}".format(title, message)
    print(line)
|
|
85
86
|
|
|
86
87
|
def _response_language(self) -> str:
|
|
87
88
|
if self.language is None:
|
|
@@ -559,14 +560,14 @@ class VannaBase(ABC):
|
|
|
559
560
|
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
|
560
561
|
|
|
561
562
|
initial_prompt = self.add_ddl_to_prompt(
|
|
562
|
-
initial_prompt, ddl_list, max_tokens=
|
|
563
|
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
563
564
|
)
|
|
564
565
|
|
|
565
566
|
if self.static_documentation != "":
|
|
566
567
|
doc_list.append(self.static_documentation)
|
|
567
568
|
|
|
568
569
|
initial_prompt = self.add_documentation_to_prompt(
|
|
569
|
-
initial_prompt, doc_list, max_tokens=
|
|
570
|
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
|
570
571
|
)
|
|
571
572
|
|
|
572
573
|
initial_prompt += (
|
|
@@ -603,15 +604,15 @@ class VannaBase(ABC):
|
|
|
603
604
|
initial_prompt = f"The user initially asked the question: '{question}': \n\n"
|
|
604
605
|
|
|
605
606
|
initial_prompt = self.add_ddl_to_prompt(
|
|
606
|
-
initial_prompt, ddl_list, max_tokens=
|
|
607
|
+
initial_prompt, ddl_list, max_tokens=self.max_tokens
|
|
607
608
|
)
|
|
608
609
|
|
|
609
610
|
initial_prompt = self.add_documentation_to_prompt(
|
|
610
|
-
initial_prompt, doc_list, max_tokens=
|
|
611
|
+
initial_prompt, doc_list, max_tokens=self.max_tokens
|
|
611
612
|
)
|
|
612
613
|
|
|
613
614
|
initial_prompt = self.add_sql_to_prompt(
|
|
614
|
-
initial_prompt, question_sql_list, max_tokens=
|
|
615
|
+
initial_prompt, question_sql_list, max_tokens=self.max_tokens
|
|
615
616
|
)
|
|
616
617
|
|
|
617
618
|
message_log = [self.system_message(initial_prompt)]
|
|
@@ -991,6 +992,7 @@ class VannaBase(ABC):
|
|
|
991
992
|
def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
|
|
992
993
|
if conn:
|
|
993
994
|
try:
|
|
995
|
+
conn.ping(reconnect=True)
|
|
994
996
|
cs = conn.cursor()
|
|
995
997
|
cs.execute(sql)
|
|
996
998
|
results = cs.fetchall()
|
|
@@ -1022,11 +1024,11 @@ class VannaBase(ABC):
|
|
|
1022
1024
|
):
|
|
1023
1025
|
|
|
1024
1026
|
try:
|
|
1025
|
-
|
|
1027
|
+
import clickhouse_connect
|
|
1026
1028
|
except ImportError:
|
|
1027
1029
|
raise DependencyError(
|
|
1028
1030
|
"You need to install required dependencies to execute this method,"
|
|
1029
|
-
" run command: \npip install
|
|
1031
|
+
" run command: \npip install clickhouse_connect"
|
|
1030
1032
|
)
|
|
1031
1033
|
|
|
1032
1034
|
if not host:
|
|
@@ -1062,12 +1064,13 @@ class VannaBase(ABC):
|
|
|
1062
1064
|
conn = None
|
|
1063
1065
|
|
|
1064
1066
|
try:
|
|
1065
|
-
conn =
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1067
|
+
conn = clickhouse_connect.get_client(
|
|
1068
|
+
host=host,
|
|
1069
|
+
port=port,
|
|
1070
|
+
username=user,
|
|
1071
|
+
password=password,
|
|
1072
|
+
database=dbname,
|
|
1073
|
+
)
|
|
1071
1074
|
print(conn)
|
|
1072
1075
|
except Exception as e:
|
|
1073
1076
|
raise ValidationError(e)
|
|
@@ -1075,19 +1078,16 @@ class VannaBase(ABC):
|
|
|
1075
1078
|
def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]:
    """Execute *sql* on the ClickHouse client ``conn`` from the enclosing scope.

    Returns a DataFrame built from the query's result rows and column names,
    or None when no client is available.
    """
    if not conn:
        return None

    try:
        query_result = conn.query(sql)
        # One DataFrame row per result row, labelled with the query's column names.
        return pd.DataFrame(
            query_result.result_rows,
            columns=query_result.column_names,
        )
    except Exception as e:
        raise e
|
|
1090
|
-
|
|
1090
|
+
|
|
1091
1091
|
self.run_sql_is_set = True
|
|
1092
1092
|
self.run_sql = run_sql_clickhouse
|
|
1093
1093
|
|
|
@@ -1597,6 +1597,7 @@ class VannaBase(ABC):
|
|
|
1597
1597
|
print_results: bool = True,
|
|
1598
1598
|
auto_train: bool = True,
|
|
1599
1599
|
visualize: bool = True, # if False, will not generate plotly code
|
|
1600
|
+
allow_llm_to_see_data: bool = False,
|
|
1600
1601
|
) -> Union[
|
|
1601
1602
|
Tuple[
|
|
1602
1603
|
Union[str, None],
|
|
@@ -1627,7 +1628,7 @@ class VannaBase(ABC):
|
|
|
1627
1628
|
question = input("Enter a question: ")
|
|
1628
1629
|
|
|
1629
1630
|
try:
|
|
1630
|
-
sql = self.generate_sql(question=question)
|
|
1631
|
+
sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
|
|
1631
1632
|
except Exception as e:
|
|
1632
1633
|
print(e)
|
|
1633
1634
|
return None, None, None
|
vanna/flask/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import os
|
|
3
4
|
import sys
|
|
4
5
|
import uuid
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
@@ -7,7 +8,7 @@ from functools import wraps
|
|
|
7
8
|
|
|
8
9
|
import flask
|
|
9
10
|
import requests
|
|
10
|
-
from flask import Flask, Response, jsonify, request
|
|
11
|
+
from flask import Flask, Response, jsonify, request, send_from_directory
|
|
11
12
|
from flask_sock import Sock
|
|
12
13
|
|
|
13
14
|
from .assets import css_content, html_content, js_content
|
|
@@ -151,7 +152,10 @@ class VannaFlaskApp:
|
|
|
151
152
|
auto_fix_sql=True,
|
|
152
153
|
ask_results_correct=True,
|
|
153
154
|
followup_questions=True,
|
|
154
|
-
summarization=True
|
|
155
|
+
summarization=True,
|
|
156
|
+
function_generation=True,
|
|
157
|
+
index_html_path=None,
|
|
158
|
+
assets_folder=None,
|
|
155
159
|
):
|
|
156
160
|
"""
|
|
157
161
|
Expose a Flask app that can be used to interact with a Vanna instance.
|
|
@@ -176,6 +180,8 @@ class VannaFlaskApp:
|
|
|
176
180
|
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
|
|
177
181
|
followup_questions: Whether to show followup questions. Defaults to True.
|
|
178
182
|
summarization: Whether to show summarization. Defaults to True.
|
|
183
|
+
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
|
|
184
|
+
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
|
|
179
185
|
|
|
180
186
|
Returns:
|
|
181
187
|
None
|
|
@@ -202,6 +208,9 @@ class VannaFlaskApp:
|
|
|
202
208
|
self.ask_results_correct = ask_results_correct
|
|
203
209
|
self.followup_questions = followup_questions
|
|
204
210
|
self.summarization = summarization
|
|
211
|
+
self.function_generation = function_generation and hasattr(vn, "get_function")
|
|
212
|
+
self.index_html_path = index_html_path
|
|
213
|
+
self.assets_folder = assets_folder
|
|
205
214
|
|
|
206
215
|
log = logging.getLogger("werkzeug")
|
|
207
216
|
log.setLevel(logging.ERROR)
|
|
@@ -247,6 +256,7 @@ class VannaFlaskApp:
|
|
|
247
256
|
"ask_results_correct": self.ask_results_correct,
|
|
248
257
|
"followup_questions": self.followup_questions,
|
|
249
258
|
"summarization": self.summarization,
|
|
259
|
+
"function_generation": self.function_generation,
|
|
250
260
|
}
|
|
251
261
|
|
|
252
262
|
config = self.auth.override_config_for_user(user, config)
|
|
@@ -345,6 +355,56 @@ class VannaFlaskApp:
|
|
|
345
355
|
}
|
|
346
356
|
)
|
|
347
357
|
|
|
358
|
+
@self.flask_app.route("/api/v0/get_function", methods=["GET"])
@self.requires_auth
def get_function(user: any):
    """Look up a stored function for the ``question`` query parameter.

    On success, the question and the function's instantiated SQL are cached
    under a fresh id. Any non-empty instantiated post-processing code is
    cached as the Plotly code. The cached entry and the function are then
    returned to the client.
    """
    question = flask.request.args.get("question")
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})

    if not hasattr(vn, "get_function"):
        return jsonify({"type": "error", "error": "This setup does not support function generation."})

    cache_id = self.cache.generate_id(question=question)
    function = vn.get_function(question=question)

    if function is None:
        return jsonify({"type": "error", "error": "No function found"})

    if "instantiated_sql" not in function:
        self.vn.log(f"No instantiated SQL found for {question} in {function}")
        return jsonify({"type": "error", "error": "No instantiated SQL found"})

    self.cache.set(id=cache_id, field="question", value=question)
    self.cache.set(id=cache_id, field="sql", value=function["instantiated_sql"])

    post_processing_code = function.get("instantiated_post_processing_code")
    if post_processing_code is not None and len(post_processing_code) > 0:
        # Stored under "plotly_code" so the figure endpoint can reuse it
        # from the cache instead of generating new chart code.
        self.cache.set(id=cache_id, field="plotly_code", value=post_processing_code)

    return jsonify({"type": "function", "id": cache_id, "function": function})
|
|
392
|
+
|
|
393
|
+
@self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
@self.requires_auth
def get_all_functions(user: any):
    """List every stored function, or report that the backend lacks support."""
    if hasattr(vn, "get_all_functions"):
        return jsonify({"type": "functions", "functions": vn.get_all_functions()})

    return jsonify({"type": "error", "error": "This setup does not support function generation."})
|
|
407
|
+
|
|
348
408
|
@self.flask_app.route("/api/v0/run_sql", methods=["GET"])
|
|
349
409
|
@self.requires_auth
|
|
350
410
|
@self.requires_cache(["sql"])
|
|
@@ -438,11 +498,18 @@ class VannaFlaskApp:
|
|
|
438
498
|
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"
|
|
439
499
|
|
|
440
500
|
try:
|
|
441
|
-
code
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
501
|
+
# If chart_instructions is not set then attempt to retrieve the code from the cache
|
|
502
|
+
if chart_instructions is None or len(chart_instructions) == 0:
|
|
503
|
+
code = self.cache.get(id=id, field="plotly_code")
|
|
504
|
+
|
|
505
|
+
if code is None:
|
|
506
|
+
code = vn.generate_plotly_code(
|
|
507
|
+
question=question,
|
|
508
|
+
sql=sql,
|
|
509
|
+
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
|
|
510
|
+
)
|
|
511
|
+
self.cache.set(id=id, field="plotly_code", value=code)
|
|
512
|
+
|
|
446
513
|
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
|
|
447
514
|
fig_json = fig.to_json()
|
|
448
515
|
|
|
@@ -518,6 +585,46 @@ class VannaFlaskApp:
|
|
|
518
585
|
print("TRAINING ERROR", e)
|
|
519
586
|
return jsonify({"type": "error", "error": str(e)})
|
|
520
587
|
|
|
588
|
+
@self.flask_app.route("/api/v0/create_function", methods=["GET"])
@self.requires_auth
@self.requires_cache(["question", "sql"])
def create_function(user: any, id: str, question: str, sql: str):
    """Create a function template from the cached question/SQL for *id*.

    Cached Plotly code for the same id is included when present; otherwise
    an empty string is passed.
    """
    # Same capability check as the get_function/get_all_functions endpoints:
    # return a JSON error instead of raising AttributeError (HTTP 500) when
    # the configured Vanna instance does not implement function generation.
    if not hasattr(self.vn, "create_function"):
        return jsonify({"type": "error", "error": "This setup does not support function generation."})

    plotly_code = self.cache.get(id=id, field="plotly_code")
    if plotly_code is None:
        plotly_code = ""

    function_data = self.vn.create_function(question=question, sql=sql, plotly_code=plotly_code)

    return jsonify(
        {
            "type": "function_template",
            "id": id,
            "function_template": function_data,
        }
    )
|
|
606
|
+
|
|
607
|
+
@self.flask_app.route("/api/v0/update_function", methods=["POST"])
@self.requires_auth
def update_function(user: any):
    """Replace a stored function using the JSON request body.

    Expects a JSON object with ``old_function_name`` and ``updated_function``.
    Responds with ``{"success": <bool>}`` from the backend, or a JSON error
    when the backend lacks support or the body is missing/invalid.
    """
    if not hasattr(vn, "update_function"):
        return jsonify(
            {
                "type": "error",
                "error": "This setup does not support function generation.",
                "success": False,
            }
        )

    # get_json(silent=True) yields None rather than raising on a missing or
    # malformed body, so bad requests get a JSON error instead of a 500.
    payload = flask.request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    old_function_name = payload.get("old_function_name")
    updated_function = payload.get("updated_function")

    if old_function_name is None or updated_function is None:
        return jsonify(
            {
                "type": "error",
                "error": "old_function_name and updated_function are required",
                "success": False,
            }
        )

    updated = vn.update_function(old_function_name=old_function_name, updated_function=updated_function)

    return jsonify({"success": updated})
|
|
619
|
+
|
|
620
|
+
@self.flask_app.route("/api/v0/delete_function", methods=["POST"])
@self.requires_auth
def delete_function(user: any):
    """Delete the stored function named in the JSON body's ``function_name``.

    Responds with ``{"success": <bool>}`` from the backend, or a JSON error
    when the backend lacks support or the body is missing/invalid.
    """
    if not hasattr(vn, "delete_function"):
        return jsonify(
            {
                "type": "error",
                "error": "This setup does not support function generation.",
                "success": False,
            }
        )

    # Tolerate a missing/non-JSON/non-object body instead of failing with a 500.
    payload = flask.request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    function_name = payload.get("function_name")
    if function_name is None:
        return jsonify({"type": "error", "error": "function_name is required", "success": False})

    return jsonify({"success": vn.delete_function(function_name=function_name)})
|
|
626
|
+
|
|
627
|
+
|
|
521
628
|
@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
|
|
522
629
|
@self.requires_auth
|
|
523
630
|
@self.requires_cache(["df", "question", "sql"])
|
|
@@ -578,8 +685,8 @@ class VannaFlaskApp:
|
|
|
578
685
|
@self.flask_app.route("/api/v0/load_question", methods=["GET"])
|
|
579
686
|
@self.requires_auth
|
|
580
687
|
@self.requires_cache(
|
|
581
|
-
["question", "sql", "df"
|
|
582
|
-
optional_fields=["summary"]
|
|
688
|
+
["question", "sql", "df"],
|
|
689
|
+
optional_fields=["summary", "fig_json"]
|
|
583
690
|
)
|
|
584
691
|
def load_question(user: any, id: str, question, sql, df, fig_json, summary):
|
|
585
692
|
try:
|
|
@@ -616,6 +723,9 @@ class VannaFlaskApp:
|
|
|
616
723
|
|
|
617
724
|
@self.flask_app.route("/assets/<path:filename>")
|
|
618
725
|
def proxy_assets(filename):
|
|
726
|
+
if self.assets_folder:
|
|
727
|
+
return send_from_directory(self.assets_folder, filename)
|
|
728
|
+
|
|
619
729
|
if ".css" in filename:
|
|
620
730
|
return Response(css_content, mimetype="text/css")
|
|
621
731
|
|
|
@@ -663,6 +773,10 @@ class VannaFlaskApp:
|
|
|
663
773
|
@self.flask_app.route("/", defaults={"path": ""})
@self.flask_app.route("/<path:path>")
def hello(path: str):
    """Serve the front-end index page for any path not matched by another route.

    When ``index_html_path`` is configured, that file is served from disk;
    otherwise the bundled ``html_content`` string is returned.
    """
    if self.index_html_path:
        # send_from_directory resolves a relative directory against the Flask
        # app's root_path (the package directory), not the process's working
        # directory, so make the user-supplied path absolute first.
        index_path = os.path.abspath(self.index_html_path)
        directory, filename = os.path.split(index_path)
        return send_from_directory(directory=directory, path=filename)
    return html_content
|
|
667
781
|
|
|
668
782
|
def run(self, *args, **kwargs):
|
|
@@ -692,4 +806,4 @@ class VannaFlaskApp:
|
|
|
692
806
|
print("Your app is running at:")
|
|
693
807
|
print("http://localhost:8084")
|
|
694
808
|
|
|
695
|
-
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)
|
|
809
|
+
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
|