vanna 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.1.0 → vanna-0.2.0}/PKG-INFO +3 -1
- {vanna-0.1.0 → vanna-0.2.0}/pyproject.toml +3 -3
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/base/base.py +137 -24
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/chromadb/chromadb_vector.py +22 -6
- vanna-0.1.0/src/vanna/flask.py → vanna-0.2.0/src/vanna/flask/__init__.py +59 -52
- vanna-0.2.0/src/vanna/flask/assets.py +36 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/openai/openai_chat.py +25 -19
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/openai/openai_embeddings.py +0 -2
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/remote.py +44 -70
- {vanna-0.1.0 → vanna-0.2.0}/README.md +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/base/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/local.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/ollama/ollama.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/types/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/utils.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.1.0 → vanna-0.2.0}/src/vanna/vannadb/vannadb_vector.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -15,6 +15,7 @@ Requires-Dist: pandas
|
|
|
15
15
|
Requires-Dist: sqlparse
|
|
16
16
|
Requires-Dist: kaleido
|
|
17
17
|
Requires-Dist: flask
|
|
18
|
+
Requires-Dist: sqlalchemy
|
|
18
19
|
Requires-Dist: psycopg2-binary ; extra == "all"
|
|
19
20
|
Requires-Dist: db-dtypes ; extra == "all"
|
|
20
21
|
Requires-Dist: google-cloud-bigquery ; extra == "all"
|
|
@@ -22,6 +23,7 @@ Requires-Dist: snowflake-connector-python ; extra == "all"
|
|
|
22
23
|
Requires-Dist: duckdb ; extra == "all"
|
|
23
24
|
Requires-Dist: openai ; extra == "all"
|
|
24
25
|
Requires-Dist: mistralai ; extra == "all"
|
|
26
|
+
Requires-Dist: chromadb ; extra == "all"
|
|
25
27
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
26
28
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
27
29
|
Requires-Dist: duckdb ; extra == "duckdb"
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.1.0"
|
|
7
|
+
version = "0.2.0"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -18,7 +18,7 @@ classifiers = [
|
|
|
18
18
|
"Operating System :: OS Independent",
|
|
19
19
|
]
|
|
20
20
|
dependencies = [
|
|
21
|
-
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask"
|
|
21
|
+
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy"
|
|
22
22
|
]
|
|
23
23
|
|
|
24
24
|
[project.urls]
|
|
@@ -30,7 +30,7 @@ postgres = ["psycopg2-binary", "db-dtypes"]
|
|
|
30
30
|
bigquery = ["google-cloud-bigquery"]
|
|
31
31
|
snowflake = ["snowflake-connector-python"]
|
|
32
32
|
duckdb = ["duckdb"]
|
|
33
|
-
all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai"]
|
|
33
|
+
all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb"]
|
|
34
34
|
test = ["tox"]
|
|
35
35
|
chromadb = ["chromadb"]
|
|
36
36
|
openai = ["openai"]
|
|
@@ -72,6 +72,7 @@ class VannaBase(ABC):
|
|
|
72
72
|
def __init__(self, config=None):
|
|
73
73
|
self.config = config
|
|
74
74
|
self.run_sql_is_set = False
|
|
75
|
+
self.static_documentation = ""
|
|
75
76
|
|
|
76
77
|
def log(self, message: str):
|
|
77
78
|
print(message)
|
|
@@ -140,18 +141,35 @@ class VannaBase(ABC):
|
|
|
140
141
|
else:
|
|
141
142
|
return False
|
|
142
143
|
|
|
143
|
-
def generate_followup_questions(
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
144
|
+
def generate_followup_questions(
|
|
145
|
+
self, question: str, sql: str, df: pd.DataFrame, **kwargs
|
|
146
|
+
) -> list:
|
|
147
|
+
"""
|
|
148
|
+
**Example:**
|
|
149
|
+
```python
|
|
150
|
+
vn.generate_followup_questions("What are the top 10 customers by sales?", df)
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
Generate a list of followup questions that you can ask Vanna.AI.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
question (str): The question that was asked.
|
|
157
|
+
df (pd.DataFrame): The results of the SQL query.
|
|
158
|
+
|
|
159
|
+
Returns:
|
|
160
|
+
list: A list of followup questions that you can ask Vanna.AI.
|
|
161
|
+
"""
|
|
162
|
+
|
|
163
|
+
message_log = [
|
|
164
|
+
self.system_message(
|
|
165
|
+
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
166
|
+
),
|
|
167
|
+
self.user_message(
|
|
168
|
+
"Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
169
|
+
),
|
|
170
|
+
]
|
|
171
|
+
|
|
172
|
+
llm_response = self.submit_prompt(message_log, **kwargs)
|
|
155
173
|
|
|
156
174
|
numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
|
|
157
175
|
return numbers_removed.split("\n")
|
|
@@ -169,6 +187,36 @@ class VannaBase(ABC):
|
|
|
169
187
|
|
|
170
188
|
return [q["question"] for q in question_sql]
|
|
171
189
|
|
|
190
|
+
def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
|
|
191
|
+
"""
|
|
192
|
+
**Example:**
|
|
193
|
+
```python
|
|
194
|
+
vn.generate_summary("What are the top 10 customers by sales?", df)
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
Generate a summary of the results of a SQL query.
|
|
198
|
+
|
|
199
|
+
Args:
|
|
200
|
+
question (str): The question that was asked.
|
|
201
|
+
df (pd.DataFrame): The results of the SQL query.
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
str: The summary of the results of the SQL query.
|
|
205
|
+
"""
|
|
206
|
+
|
|
207
|
+
message_log = [
|
|
208
|
+
self.system_message(
|
|
209
|
+
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
210
|
+
),
|
|
211
|
+
self.user_message(
|
|
212
|
+
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
213
|
+
),
|
|
214
|
+
]
|
|
215
|
+
|
|
216
|
+
summary = self.submit_prompt(message_log, **kwargs)
|
|
217
|
+
|
|
218
|
+
return summary
|
|
219
|
+
|
|
172
220
|
# ----------------- Use Any Embeddings API ----------------- #
|
|
173
221
|
@abstractmethod
|
|
174
222
|
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
|
@@ -184,7 +232,7 @@ class VannaBase(ABC):
|
|
|
184
232
|
question (str): The question to get similar questions and their corresponding SQL statements for.
|
|
185
233
|
|
|
186
234
|
Returns:
|
|
187
|
-
list: A list of similar questions and their corresponding SQL statements.
|
|
235
|
+
list: A list of similar questions and their corresponding SQL statements.
|
|
188
236
|
"""
|
|
189
237
|
pass
|
|
190
238
|
|
|
@@ -224,7 +272,7 @@ class VannaBase(ABC):
|
|
|
224
272
|
sql (str): The SQL query to add.
|
|
225
273
|
|
|
226
274
|
Returns:
|
|
227
|
-
str: The ID of the training data that was added.
|
|
275
|
+
str: The ID of the training data that was added.
|
|
228
276
|
"""
|
|
229
277
|
pass
|
|
230
278
|
|
|
@@ -232,7 +280,7 @@ class VannaBase(ABC):
|
|
|
232
280
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
233
281
|
"""
|
|
234
282
|
This method is used to add a DDL statement to the training data.
|
|
235
|
-
|
|
283
|
+
|
|
236
284
|
Args:
|
|
237
285
|
ddl (str): The DDL statement to add.
|
|
238
286
|
|
|
@@ -265,7 +313,7 @@ class VannaBase(ABC):
|
|
|
265
313
|
This method is used to get all the training data from the retrieval layer.
|
|
266
314
|
|
|
267
315
|
Returns:
|
|
268
|
-
pd.DataFrame: The training data.
|
|
316
|
+
pd.DataFrame: The training data.
|
|
269
317
|
"""
|
|
270
318
|
pass
|
|
271
319
|
|
|
@@ -321,7 +369,10 @@ class VannaBase(ABC):
|
|
|
321
369
|
return initial_prompt
|
|
322
370
|
|
|
323
371
|
def add_documentation_to_prompt(
|
|
324
|
-
self,
|
|
372
|
+
self,
|
|
373
|
+
initial_prompt: str,
|
|
374
|
+
documentation_list: list[str],
|
|
375
|
+
max_tokens: int = 14000,
|
|
325
376
|
) -> str:
|
|
326
377
|
if len(documentation_list) > 0:
|
|
327
378
|
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
|
|
@@ -389,6 +440,9 @@ class VannaBase(ABC):
|
|
|
389
440
|
initial_prompt, ddl_list, max_tokens=14000
|
|
390
441
|
)
|
|
391
442
|
|
|
443
|
+
if self.static_documentation != "":
|
|
444
|
+
doc_list.append(self.static_documentation)
|
|
445
|
+
|
|
392
446
|
initial_prompt = self.add_documentation_to_prompt(
|
|
393
447
|
initial_prompt, doc_list, max_tokens=14000
|
|
394
448
|
)
|
|
@@ -599,6 +653,7 @@ class VannaBase(ABC):
|
|
|
599
653
|
|
|
600
654
|
return df
|
|
601
655
|
|
|
656
|
+
self.static_documentation = "This is a Snowflake database"
|
|
602
657
|
self.run_sql = run_sql_snowflake
|
|
603
658
|
self.run_sql_is_set = True
|
|
604
659
|
|
|
@@ -632,6 +687,7 @@ class VannaBase(ABC):
|
|
|
632
687
|
def run_sql_sqlite(sql: str):
|
|
633
688
|
return pd.read_sql_query(sql, conn)
|
|
634
689
|
|
|
690
|
+
self.static_documentation = "This is a SQLite database"
|
|
635
691
|
self.run_sql = run_sql_sqlite
|
|
636
692
|
self.run_sql_is_set = True
|
|
637
693
|
|
|
@@ -732,6 +788,11 @@ class VannaBase(ABC):
|
|
|
732
788
|
conn.rollback()
|
|
733
789
|
raise ValidationError(e)
|
|
734
790
|
|
|
791
|
+
except Exception as e:
|
|
792
|
+
conn.rollback()
|
|
793
|
+
raise e
|
|
794
|
+
|
|
795
|
+
self.static_documentation = "This is a Postgres database"
|
|
735
796
|
self.run_sql_is_set = True
|
|
736
797
|
self.run_sql = run_sql_postgres
|
|
737
798
|
|
|
@@ -821,6 +882,7 @@ class VannaBase(ABC):
|
|
|
821
882
|
raise errors
|
|
822
883
|
return None
|
|
823
884
|
|
|
885
|
+
self.static_documentation = "This is a BigQuery database"
|
|
824
886
|
self.run_sql_is_set = True
|
|
825
887
|
self.run_sql = run_sql_bigquery
|
|
826
888
|
|
|
@@ -829,7 +891,7 @@ class VannaBase(ABC):
|
|
|
829
891
|
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
830
892
|
|
|
831
893
|
Args:
|
|
832
|
-
url (str): The URL of the database to connect to.
|
|
894
|
+
url (str): The URL of the database to connect to. Use :memory: to create an in-memory database. Use md: or motherduck: to use the MotherDuck database.
|
|
833
895
|
init_sql (str, optional): SQL to run when connecting to the database. Defaults to None.
|
|
834
896
|
|
|
835
897
|
Returns:
|
|
@@ -843,13 +905,15 @@ class VannaBase(ABC):
|
|
|
843
905
|
" run command: \npip install vanna[duckdb]"
|
|
844
906
|
)
|
|
845
907
|
# URL of the database to download
|
|
846
|
-
if url==":memory:" or url=="":
|
|
847
|
-
path=":memory:"
|
|
908
|
+
if url == ":memory:" or url == "":
|
|
909
|
+
path = ":memory:"
|
|
848
910
|
else:
|
|
849
911
|
# Path to save the downloaded database
|
|
850
912
|
print(os.path.exists(url))
|
|
851
913
|
if os.path.exists(url):
|
|
852
|
-
path=url
|
|
914
|
+
path = url
|
|
915
|
+
elif url.startswith("md") or url.startswith("motherduck"):
|
|
916
|
+
path = url
|
|
853
917
|
else:
|
|
854
918
|
path = os.path.basename(urlparse(url).path)
|
|
855
919
|
# Download the database if it doesn't exist
|
|
@@ -867,9 +931,57 @@ class VannaBase(ABC):
|
|
|
867
931
|
def run_sql_duckdb(sql: str):
|
|
868
932
|
return conn.query(sql).to_df()
|
|
869
933
|
|
|
934
|
+
self.static_documentation = "This is a DuckDB database"
|
|
870
935
|
self.run_sql = run_sql_duckdb
|
|
871
936
|
self.run_sql_is_set = True
|
|
872
937
|
|
|
938
|
+
def connect_to_mssql(self, odbc_conn_str: str):
|
|
939
|
+
"""
|
|
940
|
+
Connect to a Microsoft SQL Server database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
|
|
941
|
+
|
|
942
|
+
Args:
|
|
943
|
+
odbc_conn_str (str): The ODBC connection string.
|
|
944
|
+
|
|
945
|
+
Returns:
|
|
946
|
+
None
|
|
947
|
+
"""
|
|
948
|
+
try:
|
|
949
|
+
import pyodbc
|
|
950
|
+
except ImportError:
|
|
951
|
+
raise DependencyError(
|
|
952
|
+
"You need to install required dependencies to execute this method,"
|
|
953
|
+
" run command: pip install pyodbc"
|
|
954
|
+
)
|
|
955
|
+
|
|
956
|
+
try:
|
|
957
|
+
import sqlalchemy as sa
|
|
958
|
+
from sqlalchemy.engine import URL
|
|
959
|
+
except ImportError:
|
|
960
|
+
raise DependencyError(
|
|
961
|
+
"You need to install required dependencies to execute this method,"
|
|
962
|
+
" run command: pip install sqlalchemy"
|
|
963
|
+
)
|
|
964
|
+
|
|
965
|
+
connection_url = URL.create(
|
|
966
|
+
"mssql+pyodbc", query={"odbc_connect": odbc_conn_str}
|
|
967
|
+
)
|
|
968
|
+
|
|
969
|
+
from sqlalchemy import create_engine
|
|
970
|
+
|
|
971
|
+
engine = create_engine(connection_url)
|
|
972
|
+
|
|
973
|
+
def run_sql_mssql(sql: str):
|
|
974
|
+
# Execute the SQL statement and return the result as a pandas DataFrame
|
|
975
|
+
with engine.begin() as conn:
|
|
976
|
+
df = pd.read_sql_query(sa.text(sql), conn)
|
|
977
|
+
return df
|
|
978
|
+
|
|
979
|
+
raise Exception("Couldn't run sql")
|
|
980
|
+
|
|
981
|
+
self.static_documentation = "This is a Microsoft SQL Server database"
|
|
982
|
+
self.run_sql = run_sql_mssql
|
|
983
|
+
self.run_sql_is_set = True
|
|
984
|
+
|
|
873
985
|
def run_sql(self, sql: str, **kwargs) -> pd.DataFrame:
|
|
874
986
|
"""
|
|
875
987
|
Example:
|
|
@@ -894,7 +1006,7 @@ class VannaBase(ABC):
|
|
|
894
1006
|
question: Union[str, None] = None,
|
|
895
1007
|
print_results: bool = True,
|
|
896
1008
|
auto_train: bool = True,
|
|
897
|
-
visualize: bool = True,
|
|
1009
|
+
visualize: bool = True, # if False, will not generate plotly code
|
|
898
1010
|
) -> Union[
|
|
899
1011
|
Tuple[
|
|
900
1012
|
Union[str, None],
|
|
@@ -975,7 +1087,9 @@ class VannaBase(ABC):
|
|
|
975
1087
|
display = __import__(
|
|
976
1088
|
"IPython.display", fromlist=["display"]
|
|
977
1089
|
).display
|
|
978
|
-
Image = __import__(
|
|
1090
|
+
Image = __import__(
|
|
1091
|
+
"IPython.display", fromlist=["Image"]
|
|
1092
|
+
).Image
|
|
979
1093
|
img_bytes = fig.to_image(format="png", scale=2)
|
|
980
1094
|
display(Image(img_bytes))
|
|
981
1095
|
except Exception as e:
|
|
@@ -1328,4 +1442,3 @@ class VannaBase(ABC):
|
|
|
1328
1442
|
fig.update_layout(template="plotly_dark")
|
|
1329
1443
|
|
|
1330
1444
|
return fig
|
|
1331
|
-
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from typing import List
|
|
3
2
|
import uuid
|
|
4
|
-
from
|
|
3
|
+
from typing import List
|
|
5
4
|
|
|
6
5
|
import chromadb
|
|
7
6
|
import pandas as pd
|
|
@@ -20,13 +19,28 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
20
19
|
if config is not None:
|
|
21
20
|
path = config.get("path", ".")
|
|
22
21
|
self.embedding_function = config.get("embedding_function", default_ef)
|
|
22
|
+
curr_client = config.get("client", "persistent")
|
|
23
|
+
self.n_results = config.get("n_results", 10)
|
|
23
24
|
else:
|
|
24
25
|
path = "."
|
|
25
26
|
self.embedding_function = default_ef
|
|
27
|
+
curr_client = "persistent" # defaults to persistent storage
|
|
28
|
+
self.n_results = 10 # defaults to 10 documents
|
|
29
|
+
|
|
30
|
+
if curr_client == "persistent":
|
|
31
|
+
self.chroma_client = chromadb.PersistentClient(
|
|
32
|
+
path=path, settings=Settings(anonymized_telemetry=False)
|
|
33
|
+
)
|
|
34
|
+
elif curr_client == "in-memory":
|
|
35
|
+
self.chroma_client = chromadb.EphemeralClient(
|
|
36
|
+
settings=Settings(anonymized_telemetry=False)
|
|
37
|
+
)
|
|
38
|
+
elif isinstance(curr_client, chromadb.api.client.Client):
|
|
39
|
+
# allow providing client directly
|
|
40
|
+
self.chroma_client = curr_client
|
|
41
|
+
else:
|
|
42
|
+
raise ValueError(f"Unsupported client was set in config: {curr_client}")
|
|
26
43
|
|
|
27
|
-
self.chroma_client = chromadb.PersistentClient(
|
|
28
|
-
path=path, settings=Settings(anonymized_telemetry=False)
|
|
29
|
-
)
|
|
30
44
|
self.documentation_collection = self.chroma_client.get_or_create_collection(
|
|
31
45
|
name="documentation", embedding_function=self.embedding_function
|
|
32
46
|
)
|
|
@@ -196,7 +210,8 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
196
210
|
query_results (pd.DataFrame): The dataframe to use.
|
|
197
211
|
|
|
198
212
|
Returns:
|
|
199
|
-
List[str] or None: The extracted documents, or an empty list or
|
|
213
|
+
List[str] or None: The extracted documents, or an empty list or
|
|
214
|
+
single document if an error occurred.
|
|
200
215
|
"""
|
|
201
216
|
if query_results is None:
|
|
202
217
|
return []
|
|
@@ -216,6 +231,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
216
231
|
return ChromaDB_VectorStore._extract_documents(
|
|
217
232
|
self.sql_collection.query(
|
|
218
233
|
query_texts=[question],
|
|
234
|
+
n_results=self.n_results,
|
|
219
235
|
)
|
|
220
236
|
)
|
|
221
237
|
|
|
@@ -7,6 +7,8 @@ import flask
|
|
|
7
7
|
import requests
|
|
8
8
|
from flask import Flask, Response, jsonify, request
|
|
9
9
|
|
|
10
|
+
from .assets import css_content, html_content, js_content
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class Cache(ABC):
|
|
12
14
|
@abstractmethod
|
|
@@ -92,10 +94,11 @@ class VannaFlaskApp:
|
|
|
92
94
|
|
|
93
95
|
return decorator
|
|
94
96
|
|
|
95
|
-
def __init__(self, vn, cache: Cache = MemoryCache()):
|
|
97
|
+
def __init__(self, vn, cache: Cache = MemoryCache(), allow_llm_to_see_data=False):
|
|
96
98
|
self.flask_app = Flask(__name__)
|
|
97
99
|
self.vn = vn
|
|
98
100
|
self.cache = cache
|
|
101
|
+
self.allow_llm_to_see_data = allow_llm_to_see_data
|
|
99
102
|
|
|
100
103
|
log = logging.getLogger("werkzeug")
|
|
101
104
|
log.setLevel(logging.ERROR)
|
|
@@ -296,23 +299,55 @@ class VannaFlaskApp:
|
|
|
296
299
|
return jsonify({"type": "error", "error": str(e)})
|
|
297
300
|
|
|
298
301
|
@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
|
|
299
|
-
@self.requires_cache(["df", "question"])
|
|
300
|
-
def generate_followup_questions(id: str, df, question):
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
302
|
+
@self.requires_cache(["df", "question", "sql"])
|
|
303
|
+
def generate_followup_questions(id: str, df, question, sql):
|
|
304
|
+
if self.allow_llm_to_see_data:
|
|
305
|
+
followup_questions = vn.generate_followup_questions(
|
|
306
|
+
question=question, sql=sql, df=df
|
|
307
|
+
)
|
|
308
|
+
if followup_questions is not None and len(followup_questions) > 5:
|
|
309
|
+
followup_questions = followup_questions[:5]
|
|
305
310
|
|
|
306
|
-
|
|
311
|
+
cache.set(id=id, field="followup_questions", value=followup_questions)
|
|
307
312
|
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
313
|
+
return jsonify(
|
|
314
|
+
{
|
|
315
|
+
"type": "question_list",
|
|
316
|
+
"id": id,
|
|
317
|
+
"questions": followup_questions,
|
|
318
|
+
"header": "Here are some potential followup questions:",
|
|
319
|
+
}
|
|
320
|
+
)
|
|
321
|
+
else:
|
|
322
|
+
return jsonify(
|
|
323
|
+
{
|
|
324
|
+
"type": "question_list",
|
|
325
|
+
"id": id,
|
|
326
|
+
"questions": [],
|
|
327
|
+
"header": "Followup Questions can be enabled if you set allow_llm_to_see_data=True",
|
|
328
|
+
}
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
@self.flask_app.route("/api/v0/generate_summary", methods=["GET"])
|
|
332
|
+
@self.requires_cache(["df", "question"])
|
|
333
|
+
def generate_summary(id: str, df, question):
|
|
334
|
+
if self.allow_llm_to_see_data:
|
|
335
|
+
summary = vn.generate_summary(question=question, df=df)
|
|
336
|
+
return jsonify(
|
|
337
|
+
{
|
|
338
|
+
"type": "text",
|
|
339
|
+
"id": id,
|
|
340
|
+
"text": summary,
|
|
341
|
+
}
|
|
342
|
+
)
|
|
343
|
+
else:
|
|
344
|
+
return jsonify(
|
|
345
|
+
{
|
|
346
|
+
"type": "text",
|
|
347
|
+
"id": id,
|
|
348
|
+
"text": "Summarization can be enabled if you set allow_llm_to_see_data=True",
|
|
349
|
+
}
|
|
350
|
+
)
|
|
316
351
|
|
|
317
352
|
@self.flask_app.route("/api/v0/load_question", methods=["GET"])
|
|
318
353
|
@self.requires_cache(
|
|
@@ -352,25 +387,14 @@ class VannaFlaskApp:
|
|
|
352
387
|
|
|
353
388
|
@self.flask_app.route("/assets/<path:filename>")
|
|
354
389
|
def proxy_assets(filename):
|
|
355
|
-
|
|
356
|
-
|
|
390
|
+
if ".css" in filename:
|
|
391
|
+
return Response(css_content, mimetype="text/css")
|
|
357
392
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
"transfer-encoding",
|
|
364
|
-
"connection",
|
|
365
|
-
]
|
|
366
|
-
headers = [
|
|
367
|
-
(name, value)
|
|
368
|
-
for (name, value) in response.raw.headers.items()
|
|
369
|
-
if name.lower() not in excluded_headers
|
|
370
|
-
]
|
|
371
|
-
return Response(response.content, response.status_code, headers)
|
|
372
|
-
else:
|
|
373
|
-
return "Error fetching file from remote server", response.status_code
|
|
393
|
+
if ".js" in filename:
|
|
394
|
+
return Response(js_content, mimetype="text/javascript")
|
|
395
|
+
|
|
396
|
+
# Return 404
|
|
397
|
+
return "File not found", 404
|
|
374
398
|
|
|
375
399
|
# Proxy the /vanna.svg file to the remote server
|
|
376
400
|
@self.flask_app.route("/vanna.svg")
|
|
@@ -398,24 +422,7 @@ class VannaFlaskApp:
|
|
|
398
422
|
@self.flask_app.route("/", defaults={"path": ""})
|
|
399
423
|
@self.flask_app.route("/<path:path>")
|
|
400
424
|
def hello(path: str):
|
|
401
|
-
return """
|
|
402
|
-
<!doctype html>
|
|
403
|
-
<html lang="en">
|
|
404
|
-
<head>
|
|
405
|
-
<meta charset="UTF-8" />
|
|
406
|
-
<link rel="icon" type="image/svg+xml" href="/vanna.svg" />
|
|
407
|
-
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
408
|
-
<link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
|
|
409
|
-
<script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
|
|
410
|
-
<title>Vanna.AI</title>
|
|
411
|
-
<script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
|
|
412
|
-
<link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
|
|
413
|
-
</head>
|
|
414
|
-
<body class="bg-white dark:bg-slate-900">
|
|
415
|
-
<div id="app"></div>
|
|
416
|
-
</body>
|
|
417
|
-
</html>
|
|
418
|
-
"""
|
|
425
|
+
return html_content
|
|
419
426
|
|
|
420
427
|
def run(self):
|
|
421
428
|
try:
|