vanna 0.1.1__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.1.1 → vanna-0.2.1}/PKG-INFO +2 -1
- {vanna-0.1.1 → vanna-0.2.1}/pyproject.toml +2 -2
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/base/base.py +91 -27
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/chromadb/chromadb_vector.py +22 -6
- vanna-0.1.1/src/vanna/flask.py → vanna-0.2.1/src/vanna/flask/__init__.py +59 -52
- vanna-0.2.1/src/vanna/flask/assets.py +36 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/openai/openai_chat.py +31 -25
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/openai/openai_embeddings.py +0 -2
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/remote.py +44 -70
- {vanna-0.1.1 → vanna-0.2.1}/README.md +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/base/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/local.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/types/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/utils.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.1.1 → vanna-0.2.1}/src/vanna/vannadb/vannadb_vector.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -23,6 +23,7 @@ Requires-Dist: snowflake-connector-python ; extra == "all"
|
|
|
23
23
|
Requires-Dist: duckdb ; extra == "all"
|
|
24
24
|
Requires-Dist: openai ; extra == "all"
|
|
25
25
|
Requires-Dist: mistralai ; extra == "all"
|
|
26
|
+
Requires-Dist: chromadb ; extra == "all"
|
|
26
27
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
27
28
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
28
29
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.
|
|
7
|
+
version = "0.2.1"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -30,7 +30,7 @@ postgres = ["psycopg2-binary", "db-dtypes"]
|
|
|
30
30
|
bigquery = ["google-cloud-bigquery"]
|
|
31
31
|
snowflake = ["snowflake-connector-python"]
|
|
32
32
|
duckdb = ["duckdb"]
|
|
33
|
-
all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai"]
|
|
33
|
+
all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb"]
|
|
34
34
|
test = ["tox"]
|
|
35
35
|
chromadb = ["chromadb"]
|
|
36
36
|
openai = ["openai"]
|
|
@@ -72,6 +72,7 @@ class VannaBase(ABC):
|
|
|
72
72
|
def __init__(self, config=None):
|
|
73
73
|
self.config = config
|
|
74
74
|
self.run_sql_is_set = False
|
|
75
|
+
self.static_documentation = ""
|
|
75
76
|
|
|
76
77
|
def log(self, message: str):
|
|
77
78
|
print(message)
|
|
@@ -140,18 +141,35 @@ class VannaBase(ABC):
|
|
|
140
141
|
else:
|
|
141
142
|
return False
|
|
142
143
|
|
|
143
|
-
def generate_followup_questions(
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
144
|
+
def generate_followup_questions(
|
|
145
|
+
self, question: str, sql: str, df: pd.DataFrame, **kwargs
|
|
146
|
+
) -> list:
|
|
147
|
+
"""
|
|
148
|
+
**Example:**
|
|
149
|
+
```python
|
|
150
|
+
vn.generate_followup_questions("What are the top 10 customers by sales?", df)
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
Generate a list of followup questions that you can ask Vanna.AI.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
question (str): The question that was asked.
|
|
157
|
+
df (pd.DataFrame): The results of the SQL query.
|
|
158
|
+
|
|
159
|
+
Returns:
|
|
160
|
+
list: A list of followup questions that you can ask Vanna.AI.
|
|
161
|
+
"""
|
|
162
|
+
|
|
163
|
+
message_log = [
|
|
164
|
+
self.system_message(
|
|
165
|
+
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
166
|
+
),
|
|
167
|
+
self.user_message(
|
|
168
|
+
"Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
169
|
+
),
|
|
170
|
+
]
|
|
171
|
+
|
|
172
|
+
llm_response = self.submit_prompt(message_log, **kwargs)
|
|
155
173
|
|
|
156
174
|
numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
|
|
157
175
|
return numbers_removed.split("\n")
|
|
@@ -169,6 +187,36 @@ class VannaBase(ABC):
|
|
|
169
187
|
|
|
170
188
|
return [q["question"] for q in question_sql]
|
|
171
189
|
|
|
190
|
+
def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
|
|
191
|
+
"""
|
|
192
|
+
**Example:**
|
|
193
|
+
```python
|
|
194
|
+
vn.generate_summary("What are the top 10 customers by sales?", df)
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
Generate a summary of the results of a SQL query.
|
|
198
|
+
|
|
199
|
+
Args:
|
|
200
|
+
question (str): The question that was asked.
|
|
201
|
+
df (pd.DataFrame): The results of the SQL query.
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
str: The summary of the results of the SQL query.
|
|
205
|
+
"""
|
|
206
|
+
|
|
207
|
+
message_log = [
|
|
208
|
+
self.system_message(
|
|
209
|
+
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
210
|
+
),
|
|
211
|
+
self.user_message(
|
|
212
|
+
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
213
|
+
),
|
|
214
|
+
]
|
|
215
|
+
|
|
216
|
+
summary = self.submit_prompt(message_log, **kwargs)
|
|
217
|
+
|
|
218
|
+
return summary
|
|
219
|
+
|
|
172
220
|
# ----------------- Use Any Embeddings API ----------------- #
|
|
173
221
|
@abstractmethod
|
|
174
222
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
@@ -184,7 +232,7 @@ class VannaBase(ABC):
|
|
|
184
232
|
question (str): The question to get similar questions and their corresponding SQL statements for.
|
|
185
233
|
|
|
186
234
|
Returns:
|
|
187
|
-
list: A list of similar questions and their corresponding SQL statements.
|
|
235
|
+
list: A list of similar questions and their corresponding SQL statements.
|
|
188
236
|
"""
|
|
189
237
|
pass
|
|
190
238
|
|
|
@@ -224,7 +272,7 @@ class VannaBase(ABC):
|
|
|
224
272
|
sql (str): The SQL query to add.
|
|
225
273
|
|
|
226
274
|
Returns:
|
|
227
|
-
str: The ID of the training data that was added.
|
|
275
|
+
str: The ID of the training data that was added.
|
|
228
276
|
"""
|
|
229
277
|
pass
|
|
230
278
|
|
|
@@ -232,7 +280,7 @@ class VannaBase(ABC):
|
|
|
232
280
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
233
281
|
"""
|
|
234
282
|
This method is used to add a DDL statement to the training data.
|
|
235
|
-
|
|
283
|
+
|
|
236
284
|
Args:
|
|
237
285
|
ddl (str): The DDL statement to add.
|
|
238
286
|
|
|
@@ -265,7 +313,7 @@ class VannaBase(ABC):
|
|
|
265
313
|
This method is used to get all the training data from the retrieval layer.
|
|
266
314
|
|
|
267
315
|
Returns:
|
|
268
|
-
pd.DataFrame: The training data.
|
|
316
|
+
pd.DataFrame: The training data.
|
|
269
317
|
"""
|
|
270
318
|
pass
|
|
271
319
|
|
|
@@ -321,7 +369,10 @@ class VannaBase(ABC):
|
|
|
321
369
|
return initial_prompt
|
|
322
370
|
|
|
323
371
|
def add_documentation_to_prompt(
|
|
324
|
-
self,
|
|
372
|
+
self,
|
|
373
|
+
initial_prompt: str,
|
|
374
|
+
documentation_list: list[str],
|
|
375
|
+
max_tokens: int = 14000,
|
|
325
376
|
) -> str:
|
|
326
377
|
if len(documentation_list) > 0:
|
|
327
378
|
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
@@ -389,6 +440,9 @@ class VannaBase(ABC):
|
|
|
389
440
|
initial_prompt, ddl_list, max_tokens=14000
|
|
390
441
|
)
|
|
391
442
|
|
|
443
|
+
if self.static_documentation != "":
|
|
444
|
+
doc_list.append(self.static_documentation)
|
|
445
|
+
|
|
392
446
|
initial_prompt = self.add_documentation_to_prompt(
|
|
393
447
|
initial_prompt, doc_list, max_tokens=14000
|
|
394
448
|
)
|
|
@@ -599,6 +653,7 @@ class VannaBase(ABC):
|
|
|
599
653
|
|
|
600
654
|
return df
|
|
601
655
|
|
|
656
|
+
self.static_documentation = "This is a Snowflake database"
|
|
602
657
|
self.run_sql = run_sql_snowflake
|
|
603
658
|
self.run_sql_is_set = True
|
|
604
659
|
|
|
@@ -632,6 +687,7 @@ class VannaBase(ABC):
|
|
|
632
687
|
def run_sql_sqlite(sql: str):
|
|
633
688
|
return pd.read_sql_query(sql, conn)
|
|
634
689
|
|
|
690
|
+
self.static_documentation = "This is a SQLite database"
|
|
635
691
|
self.run_sql = run_sql_sqlite
|
|
636
692
|
self.run_sql_is_set = True
|
|
637
693
|
|
|
@@ -731,11 +787,12 @@ class VannaBase(ABC):
|
|
|
731
787
|
except psycopg2.Error as e:
|
|
732
788
|
conn.rollback()
|
|
733
789
|
raise ValidationError(e)
|
|
734
|
-
|
|
790
|
+
|
|
735
791
|
except Exception as e:
|
|
736
792
|
conn.rollback()
|
|
737
793
|
raise e
|
|
738
794
|
|
|
795
|
+
self.static_documentation = "This is a Postgres database"
|
|
739
796
|
self.run_sql_is_set = True
|
|
740
797
|
self.run_sql = run_sql_postgres
|
|
741
798
|
|
|
@@ -825,6 +882,7 @@ class VannaBase(ABC):
|
|
|
825
882
|
raise errors
|
|
826
883
|
return None
|
|
827
884
|
|
|
885
|
+
self.static_documentation = "This is a BigQuery database"
|
|
828
886
|
self.run_sql_is_set = True
|
|
829
887
|
self.run_sql = run_sql_bigquery
|
|
830
888
|
|
|
@@ -847,13 +905,13 @@ class VannaBase(ABC):
|
|
|
847
905
|
" run command: \npip install vanna[duckdb]"
|
|
848
906
|
)
|
|
849
907
|
# URL of the database to download
|
|
850
|
-
if url==":memory:" or url=="":
|
|
851
|
-
path=":memory:"
|
|
908
|
+
if url == ":memory:" or url == "":
|
|
909
|
+
path = ":memory:"
|
|
852
910
|
else:
|
|
853
911
|
# Path to save the downloaded database
|
|
854
912
|
print(os.path.exists(url))
|
|
855
913
|
if os.path.exists(url):
|
|
856
|
-
path=url
|
|
914
|
+
path = url
|
|
857
915
|
elif url.startswith("md") or url.startswith("motherduck"):
|
|
858
916
|
path = url
|
|
859
917
|
else:
|
|
@@ -873,6 +931,7 @@ class VannaBase(ABC):
|
|
|
873
931
|
def run_sql_duckdb(sql: str):
|
|
874
932
|
return conn.query(sql).to_df()
|
|
875
933
|
|
|
934
|
+
self.static_documentation = "This is a DuckDB database"
|
|
876
935
|
self.run_sql = run_sql_duckdb
|
|
877
936
|
self.run_sql_is_set = True
|
|
878
937
|
|
|
@@ -895,17 +954,20 @@ class VannaBase(ABC):
|
|
|
895
954
|
)
|
|
896
955
|
|
|
897
956
|
try:
|
|
898
|
-
from sqlalchemy.engine import URL
|
|
899
957
|
import sqlalchemy as sa
|
|
958
|
+
from sqlalchemy.engine import URL
|
|
900
959
|
except ImportError:
|
|
901
960
|
raise DependencyError(
|
|
902
961
|
"You need to install required dependencies to execute this method,"
|
|
903
962
|
" run command: pip install sqlalchemy"
|
|
904
963
|
)
|
|
905
964
|
|
|
906
|
-
connection_url = URL.create(
|
|
965
|
+
connection_url = URL.create(
|
|
966
|
+
"mssql+pyodbc", query={"odbc_connect": odbc_conn_str}
|
|
967
|
+
)
|
|
907
968
|
|
|
908
969
|
from sqlalchemy import create_engine
|
|
970
|
+
|
|
909
971
|
engine = create_engine(connection_url)
|
|
910
972
|
|
|
911
973
|
def run_sql_mssql(sql: str):
|
|
@@ -913,9 +975,10 @@ class VannaBase(ABC):
|
|
|
913
975
|
with engine.begin() as conn:
|
|
914
976
|
df = pd.read_sql_query(sa.text(sql), conn)
|
|
915
977
|
return df
|
|
916
|
-
|
|
978
|
+
|
|
917
979
|
raise Exception("Couldn't run sql")
|
|
918
980
|
|
|
981
|
+
self.static_documentation = "This is a Microsoft SQL Server database"
|
|
919
982
|
self.run_sql = run_sql_mssql
|
|
920
983
|
self.run_sql_is_set = True
|
|
921
984
|
|
|
@@ -943,7 +1006,7 @@ class VannaBase(ABC):
|
|
|
943
1006
|
question: Union[str, None] = None,
|
|
944
1007
|
print_results: bool = True,
|
|
945
1008
|
auto_train: bool = True,
|
|
946
|
-
visualize: bool = True,
|
|
1009
|
+
visualize: bool = True, # if False, will not generate plotly code
|
|
947
1010
|
) -> Union[
|
|
948
1011
|
Tuple[
|
|
949
1012
|
Union[str, None],
|
|
@@ -1024,7 +1087,9 @@ class VannaBase(ABC):
|
|
|
1024
1087
|
display = __import__(
|
|
1025
1088
|
"IPython.display", fromlist=["display"]
|
|
1026
1089
|
).display
|
|
1027
|
-
Image = __import__(
|
|
1090
|
+
Image = __import__(
|
|
1091
|
+
"IPython.display", fromlist=["Image"]
|
|
1092
|
+
).Image
|
|
1028
1093
|
img_bytes = fig.to_image(format="png", scale=2)
|
|
1029
1094
|
display(Image(img_bytes))
|
|
1030
1095
|
except Exception as e:
|
|
@@ -1377,4 +1442,3 @@ class VannaBase(ABC):
|
|
|
1377
1442
|
fig.update_layout(template="plotly_dark")
|
|
1378
1443
|
|
|
1379
1444
|
return fig
|
|
1380
|
-
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from typing import List
|
|
3
2
|
import uuid
|
|
4
|
-
from
|
|
3
|
+
from typing import List
|
|
5
4
|
|
|
6
5
|
import chromadb
|
|
7
6
|
import pandas as pd
|
|
@@ -20,13 +19,28 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
20
19
|
if config is not None:
|
|
21
20
|
path = config.get("path", ".")
|
|
22
21
|
self.embedding_function = config.get("embedding_function", default_ef)
|
|
22
|
+
curr_client = config.get("client", "persistent")
|
|
23
|
+
self.n_results = config.get("n_results", 10)
|
|
23
24
|
else:
|
|
24
25
|
path = "."
|
|
25
26
|
self.embedding_function = default_ef
|
|
27
|
+
curr_client = "persistent" # defaults to persistent storage
|
|
28
|
+
self.n_results = 10 # defaults to 10 documents
|
|
29
|
+
|
|
30
|
+
if curr_client == "persistent":
|
|
31
|
+
self.chroma_client = chromadb.PersistentClient(
|
|
32
|
+
path=path, settings=Settings(anonymized_telemetry=False)
|
|
33
|
+
)
|
|
34
|
+
elif curr_client == "in-memory":
|
|
35
|
+
self.chroma_client = chromadb.EphemeralClient(
|
|
36
|
+
settings=Settings(anonymized_telemetry=False)
|
|
37
|
+
)
|
|
38
|
+
elif isinstance(curr_client, chromadb.api.client.Client):
|
|
39
|
+
# allow providing client directly
|
|
40
|
+
self.chroma_client = curr_client
|
|
41
|
+
else:
|
|
42
|
+
raise ValueError(f"Unsupported client was set in config: {curr_client}")
|
|
26
43
|
|
|
27
|
-
self.chroma_client = chromadb.PersistentClient(
|
|
28
|
-
path=path, settings=Settings(anonymized_telemetry=False)
|
|
29
|
-
)
|
|
30
44
|
self.documentation_collection = self.chroma_client.get_or_create_collection(
|
|
31
45
|
name="documentation", embedding_function=self.embedding_function
|
|
32
46
|
)
|
|
@@ -196,7 +210,8 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
196
210
|
query_results (pd.DataFrame): The dataframe to use.
|
|
197
211
|
|
|
198
212
|
Returns:
|
|
199
|
-
List[str] or None: The extracted documents, or an empty list or
|
|
213
|
+
List[str] or None: The extracted documents, or an empty list or
|
|
214
|
+
single document if an error occurred.
|
|
200
215
|
"""
|
|
201
216
|
if query_results is None:
|
|
202
217
|
return []
|
|
@@ -216,6 +231,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
216
231
|
return ChromaDB_VectorStore._extract_documents(
|
|
217
232
|
self.sql_collection.query(
|
|
218
233
|
query_texts=[question],
|
|
234
|
+
n_results=self.n_results,
|
|
219
235
|
)
|
|
220
236
|
)
|
|
221
237
|
|
|
@@ -7,6 +7,8 @@ import flask
|
|
|
7
7
|
import requests
|
|
8
8
|
from flask import Flask, Response, jsonify, request
|
|
9
9
|
|
|
10
|
+
from .assets import css_content, html_content, js_content
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class Cache(ABC):
|
|
12
14
|
@abstractmethod
|
|
@@ -92,10 +94,11 @@ class VannaFlaskApp:
|
|
|
92
94
|
|
|
93
95
|
return decorator
|
|
94
96
|
|
|
95
|
-
def __init__(self, vn, cache: Cache = MemoryCache()):
|
|
97
|
+
def __init__(self, vn, cache: Cache = MemoryCache(), allow_llm_to_see_data=False):
|
|
96
98
|
self.flask_app = Flask(__name__)
|
|
97
99
|
self.vn = vn
|
|
98
100
|
self.cache = cache
|
|
101
|
+
self.allow_llm_to_see_data = allow_llm_to_see_data
|
|
99
102
|
|
|
100
103
|
log = logging.getLogger("werkzeug")
|
|
101
104
|
log.setLevel(logging.ERROR)
|
|
@@ -296,23 +299,55 @@ class VannaFlaskApp:
|
|
|
296
299
|
return jsonify({"type": "error", "error": str(e)})
|
|
297
300
|
|
|
298
301
|
@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
|
|
299
|
-
@self.requires_cache(["df", "question"])
|
|
300
|
-
def generate_followup_questions(id: str, df, question):
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
302
|
+
@self.requires_cache(["df", "question", "sql"])
|
|
303
|
+
def generate_followup_questions(id: str, df, question, sql):
|
|
304
|
+
if self.allow_llm_to_see_data:
|
|
305
|
+
followup_questions = vn.generate_followup_questions(
|
|
306
|
+
question=question, sql=sql, df=df
|
|
307
|
+
)
|
|
308
|
+
if followup_questions is not None and len(followup_questions) > 5:
|
|
309
|
+
followup_questions = followup_questions[:5]
|
|
305
310
|
|
|
306
|
-
|
|
311
|
+
cache.set(id=id, field="followup_questions", value=followup_questions)
|
|
307
312
|
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
313
|
+
return jsonify(
|
|
314
|
+
{
|
|
315
|
+
"type": "question_list",
|
|
316
|
+
"id": id,
|
|
317
|
+
"questions": followup_questions,
|
|
318
|
+
"header": "Here are some potential followup questions:",
|
|
319
|
+
}
|
|
320
|
+
)
|
|
321
|
+
else:
|
|
322
|
+
return jsonify(
|
|
323
|
+
{
|
|
324
|
+
"type": "question_list",
|
|
325
|
+
"id": id,
|
|
326
|
+
"questions": [],
|
|
327
|
+
"header": "Followup Questions can be enabled if you set allow_llm_to_see_data=True",
|
|
328
|
+
}
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
@self.flask_app.route("/api/v0/generate_summary", methods=["GET"])
|
|
332
|
+
@self.requires_cache(["df", "question"])
|
|
333
|
+
def generate_summary(id: str, df, question):
|
|
334
|
+
if self.allow_llm_to_see_data:
|
|
335
|
+
summary = vn.generate_summary(question=question, df=df)
|
|
336
|
+
return jsonify(
|
|
337
|
+
{
|
|
338
|
+
"type": "text",
|
|
339
|
+
"id": id,
|
|
340
|
+
"text": summary,
|
|
341
|
+
}
|
|
342
|
+
)
|
|
343
|
+
else:
|
|
344
|
+
return jsonify(
|
|
345
|
+
{
|
|
346
|
+
"type": "text",
|
|
347
|
+
"id": id,
|
|
348
|
+
"text": "Summarization can be enabled if you set allow_llm_to_see_data=True",
|
|
349
|
+
}
|
|
350
|
+
)
|
|
316
351
|
|
|
317
352
|
@self.flask_app.route("/api/v0/load_question", methods=["GET"])
|
|
318
353
|
@self.requires_cache(
|
|
@@ -352,25 +387,14 @@ class VannaFlaskApp:
|
|
|
352
387
|
|
|
353
388
|
@self.flask_app.route("/assets/<path:filename>")
|
|
354
389
|
def proxy_assets(filename):
|
|
355
|
-
|
|
356
|
-
|
|
390
|
+
if ".css" in filename:
|
|
391
|
+
return Response(css_content, mimetype="text/css")
|
|
357
392
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
"transfer-encoding",
|
|
364
|
-
"connection",
|
|
365
|
-
]
|
|
366
|
-
headers = [
|
|
367
|
-
(name, value)
|
|
368
|
-
for (name, value) in response.raw.headers.items()
|
|
369
|
-
if name.lower() not in excluded_headers
|
|
370
|
-
]
|
|
371
|
-
return Response(response.content, response.status_code, headers)
|
|
372
|
-
else:
|
|
373
|
-
return "Error fetching file from remote server", response.status_code
|
|
393
|
+
if ".js" in filename:
|
|
394
|
+
return Response(js_content, mimetype="text/javascript")
|
|
395
|
+
|
|
396
|
+
# Return 404
|
|
397
|
+
return "File not found", 404
|
|
374
398
|
|
|
375
399
|
# Proxy the /vanna.svg file to the remote server
|
|
376
400
|
@self.flask_app.route("/vanna.svg")
|
|
@@ -398,24 +422,7 @@ class VannaFlaskApp:
|
|
|
398
422
|
@self.flask_app.route("/", defaults={"path": ""})
|
|
399
423
|
@self.flask_app.route("/<path:path>")
|
|
400
424
|
def hello(path: str):
|
|
401
|
-
return
|
|
402
|
-
<!doctype html>
|
|
403
|
-
<html lang="en">
|
|
404
|
-
<head>
|
|
405
|
-
<meta charset="UTF-8" />
|
|
406
|
-
<link rel="icon" type="image/svg+xml" href="/vanna.svg" />
|
|
407
|
-
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
408
|
-
<link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
|
|
409
|
-
<script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
|
|
410
|
-
<title>Vanna.AI</title>
|
|
411
|
-
<script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
|
|
412
|
-
<link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
|
|
413
|
-
</head>
|
|
414
|
-
<body class="bg-white dark:bg-slate-900">
|
|
415
|
-
<div id="app"></div>
|
|
416
|
-
</body>
|
|
417
|
-
</html>
|
|
418
|
-
"""
|
|
425
|
+
return html_content
|
|
419
426
|
|
|
420
427
|
def run(self):
|
|
421
428
|
try:
|