vanna 0.4.3__tar.gz → 0.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vanna-0.4.3 → vanna-0.5.1}/PKG-INFO +20 -3
- {vanna-0.4.3 → vanna-0.5.1}/README.md +2 -2
- {vanna-0.4.3 → vanna-0.5.1}/pyproject.toml +8 -4
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/base/base.py +150 -41
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/flask/__init__.py +50 -9
- vanna-0.5.1/src/vanna/flask/assets.py +38 -0
- vanna-0.5.1/src/vanna/hf/__init__.py +1 -0
- vanna-0.5.1/src/vanna/hf/hf.py +79 -0
- vanna-0.5.1/src/vanna/ollama/ollama.py +101 -0
- vanna-0.5.1/src/vanna/opensearch/__init__.py +1 -0
- vanna-0.5.1/src/vanna/opensearch/opensearch_vector.py +289 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/qdrant/qdrant.py +2 -4
- vanna-0.5.1/src/vanna/vllm/__init__.py +1 -0
- vanna-0.4.3/src/vanna/ollama/ollama.py → vanna-0.5.1/src/vanna/vllm/vllm.py +7 -7
- vanna-0.4.3/src/vanna/flask/assets.py +0 -38
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/ZhipuAI/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/anthropic/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/anthropic/anthropic_chat.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/base/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/chromadb/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/chromadb/chromadb_vector.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/exceptions/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/flask/auth.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/google/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/google/gemini_chat.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/local.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/marqo/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/marqo/marqo.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/mistral/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/mistral/mistral.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/ollama/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/openai/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/openai/openai_chat.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/openai/openai_embeddings.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/qdrant/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/remote.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/types/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/utils.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/vannadb/__init__.py +0 -0
- {vanna-0.4.3 → vanna-0.5.1}/src/vanna/vannadb/vannadb_vector.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.4.3
|
|
3
|
+
Version: 0.5.1
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.9
|
|
@@ -15,6 +15,7 @@ Requires-Dist: pandas
|
|
|
15
15
|
Requires-Dist: sqlparse
|
|
16
16
|
Requires-Dist: kaleido
|
|
17
17
|
Requires-Dist: flask
|
|
18
|
+
Requires-Dist: flask-sock
|
|
18
19
|
Requires-Dist: sqlalchemy
|
|
19
20
|
Requires-Dist: psycopg2-binary ; extra == "all"
|
|
20
21
|
Requires-Dist: db-dtypes ; extra == "all"
|
|
@@ -32,6 +33,11 @@ Requires-Dist: google-generativeai ; extra == "all"
|
|
|
32
33
|
Requires-Dist: google-cloud-aiplatform ; extra == "all"
|
|
33
34
|
Requires-Dist: qdrant-client ; extra == "all"
|
|
34
35
|
Requires-Dist: fastembed ; extra == "all"
|
|
36
|
+
Requires-Dist: ollama ; extra == "all"
|
|
37
|
+
Requires-Dist: httpx ; extra == "all"
|
|
38
|
+
Requires-Dist: opensearch-py ; extra == "all"
|
|
39
|
+
Requires-Dist: opensearch-dsl ; extra == "all"
|
|
40
|
+
Requires-Dist: transformers ; extra == "all"
|
|
35
41
|
Requires-Dist: anthropic ; extra == "anthropic"
|
|
36
42
|
Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
|
|
37
43
|
Requires-Dist: chromadb ; extra == "chromadb"
|
|
@@ -39,15 +45,22 @@ Requires-Dist: duckdb ; extra == "duckdb"
|
|
|
39
45
|
Requires-Dist: google-generativeai ; extra == "gemini"
|
|
40
46
|
Requires-Dist: google-generativeai ; extra == "google"
|
|
41
47
|
Requires-Dist: google-cloud-aiplatform ; extra == "google"
|
|
48
|
+
Requires-Dist: transformers ; extra == "hf"
|
|
42
49
|
Requires-Dist: marqo ; extra == "marqo"
|
|
43
50
|
Requires-Dist: mistralai ; extra == "mistralai"
|
|
44
51
|
Requires-Dist: PyMySQL ; extra == "mysql"
|
|
52
|
+
Requires-Dist: ollama ; extra == "ollama"
|
|
53
|
+
Requires-Dist: httpx ; extra == "ollama"
|
|
45
54
|
Requires-Dist: openai ; extra == "openai"
|
|
55
|
+
Requires-Dist: opensearch-py ; extra == "opensearch"
|
|
56
|
+
Requires-Dist: opensearch-dsl ; extra == "opensearch"
|
|
46
57
|
Requires-Dist: psycopg2-binary ; extra == "postgres"
|
|
47
58
|
Requires-Dist: db-dtypes ; extra == "postgres"
|
|
48
59
|
Requires-Dist: qdrant-client ; extra == "qdrant"
|
|
60
|
+
Requires-Dist: fastembed ; extra == "qdrant"
|
|
49
61
|
Requires-Dist: snowflake-connector-python ; extra == "snowflake"
|
|
50
62
|
Requires-Dist: tox ; extra == "test"
|
|
63
|
+
Requires-Dist: vllm ; extra == "vllm"
|
|
51
64
|
Requires-Dist: zhipuai ; extra == "zhipuai"
|
|
52
65
|
Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
|
|
53
66
|
Project-URL: Homepage, https://github.com/vanna-ai/vanna
|
|
@@ -58,14 +71,18 @@ Provides-Extra: chromadb
|
|
|
58
71
|
Provides-Extra: duckdb
|
|
59
72
|
Provides-Extra: gemini
|
|
60
73
|
Provides-Extra: google
|
|
74
|
+
Provides-Extra: hf
|
|
61
75
|
Provides-Extra: marqo
|
|
62
76
|
Provides-Extra: mistralai
|
|
63
77
|
Provides-Extra: mysql
|
|
78
|
+
Provides-Extra: ollama
|
|
64
79
|
Provides-Extra: openai
|
|
80
|
+
Provides-Extra: opensearch
|
|
65
81
|
Provides-Extra: postgres
|
|
66
82
|
Provides-Extra: qdrant
|
|
67
83
|
Provides-Extra: snowflake
|
|
68
84
|
Provides-Extra: test
|
|
85
|
+
Provides-Extra: vllm
|
|
69
86
|
Provides-Extra: zhipuai
|
|
70
87
|
|
|
71
88
|
|
|
@@ -95,7 +112,7 @@ Vanna works in two easy steps - train a RAG "model" on your data, and then ask q
|
|
|
95
112
|
|
|
96
113
|
If you don't know what RAG is, don't worry -- you don't need to know how this works under the hood to use it. You just need to know that you "train" a model, which stores some metadata and then use it to "ask" questions.
|
|
97
114
|
|
|
98
|
-
See the [base class](src/vanna/base/base.py) for more details on how this works under the hood.
|
|
115
|
+
See the [base class](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) for more details on how this works under the hood.
|
|
99
116
|
|
|
100
117
|
## User Interfaces
|
|
101
118
|
These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface.
|
|
@@ -288,7 +305,7 @@ Fine-Tuning
|
|
|
288
305
|
- Expose to your end users via Slackbot, web app, Streamlit app, or a custom front end.
|
|
289
306
|
|
|
290
307
|
## Extending Vanna
|
|
291
|
-
Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
|
|
308
|
+
Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
|
|
292
309
|
|
|
293
310
|
## Vanna in 100 Seconds
|
|
294
311
|
|
|
@@ -25,7 +25,7 @@ Vanna works in two easy steps - train a RAG "model" on your data, and then ask q
|
|
|
25
25
|
|
|
26
26
|
If you don't know what RAG is, don't worry -- you don't need to know how this works under the hood to use it. You just need to know that you "train" a model, which stores some metadata and then use it to "ask" questions.
|
|
27
27
|
|
|
28
|
-
See the [base class](src/vanna/base/base.py) for more details on how this works under the hood.
|
|
28
|
+
See the [base class](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) for more details on how this works under the hood.
|
|
29
29
|
|
|
30
30
|
## User Interfaces
|
|
31
31
|
These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface.
|
|
@@ -218,7 +218,7 @@ Fine-Tuning
|
|
|
218
218
|
- Expose to your end users via Slackbot, web app, Streamlit app, or a custom front end.
|
|
219
219
|
|
|
220
220
|
## Extending Vanna
|
|
221
|
-
Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
|
|
221
|
+
Vanna is designed to connect to any database, LLM, and vector database. There's a [VannaBase](https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py) abstract base class that defines some basic functionality. The package provides implementations for use with OpenAI and ChromaDB. You can easily extend Vanna to use your own LLM or vector database. See the [documentation](https://vanna.ai/docs/) for more details.
|
|
222
222
|
|
|
223
223
|
## Vanna in 100 Seconds
|
|
224
224
|
|
|
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "vanna"
|
|
7
|
-
version = "0.4.3"
|
|
7
|
+
version = "0.5.1"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Zain Hoda", email="zain@vanna.ai" },
|
|
10
10
|
]
|
|
@@ -18,7 +18,7 @@ classifiers = [
|
|
|
18
18
|
"Operating System :: OS Independent",
|
|
19
19
|
]
|
|
20
20
|
dependencies = [
|
|
21
|
-
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "sqlalchemy"
|
|
21
|
+
"requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "flask-sock", "sqlalchemy"
|
|
22
22
|
]
|
|
23
23
|
|
|
24
24
|
[project.urls]
|
|
@@ -32,7 +32,7 @@ bigquery = ["google-cloud-bigquery"]
|
|
|
32
32
|
snowflake = ["snowflake-connector-python"]
|
|
33
33
|
duckdb = ["duckdb"]
|
|
34
34
|
google = ["google-generativeai", "google-cloud-aiplatform"]
|
|
35
|
-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed"]
|
|
35
|
+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers"]
|
|
36
36
|
test = ["tox"]
|
|
37
37
|
chromadb = ["chromadb"]
|
|
38
38
|
openai = ["openai"]
|
|
@@ -41,4 +41,8 @@ anthropic = ["anthropic"]
|
|
|
41
41
|
gemini = ["google-generativeai"]
|
|
42
42
|
marqo = ["marqo"]
|
|
43
43
|
zhipuai = ["zhipuai"]
|
|
44
|
-
|
|
44
|
+
ollama = ["ollama", "httpx"]
|
|
45
|
+
qdrant = ["qdrant-client", "fastembed"]
|
|
46
|
+
vllm = ["vllm"]
|
|
47
|
+
opensearch = ["opensearch-py", "opensearch-dsl"]
|
|
48
|
+
hf = ["transformers"]
|
|
@@ -62,6 +62,7 @@ import plotly
|
|
|
62
62
|
import plotly.express as px
|
|
63
63
|
import plotly.graph_objects as go
|
|
64
64
|
import requests
|
|
65
|
+
import sqlparse
|
|
65
66
|
|
|
66
67
|
from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
|
|
67
68
|
from ..types import TrainingPlan, TrainingPlanItem
|
|
@@ -70,14 +71,25 @@ from ..utils import validate_config_path
|
|
|
70
71
|
|
|
71
72
|
class VannaBase(ABC):
|
|
72
73
|
def __init__(self, config=None):
|
|
74
|
+
if config is None:
|
|
75
|
+
config = {}
|
|
76
|
+
|
|
73
77
|
self.config = config
|
|
74
78
|
self.run_sql_is_set = False
|
|
75
79
|
self.static_documentation = ""
|
|
80
|
+
self.dialect = self.config.get("dialect", "SQL")
|
|
81
|
+
self.language = self.config.get("language", None)
|
|
76
82
|
|
|
77
|
-
def log(self, message: str):
|
|
83
|
+
def log(self, message: str, title: str = "Info"):
|
|
78
84
|
print(message)
|
|
79
85
|
|
|
80
|
-
def
|
|
86
|
+
def _response_language(self) -> str:
|
|
87
|
+
if self.language is None:
|
|
88
|
+
return ""
|
|
89
|
+
|
|
90
|
+
return f"Respond in the {self.language} language."
|
|
91
|
+
|
|
92
|
+
def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
|
|
81
93
|
"""
|
|
82
94
|
Example:
|
|
83
95
|
```python
|
|
@@ -99,6 +111,7 @@ class VannaBase(ABC):
|
|
|
99
111
|
|
|
100
112
|
Args:
|
|
101
113
|
question (str): The question to generate a SQL query for.
|
|
114
|
+
allow_llm_to_see_data (bool): Whether to allow the LLM to see the data (for the purposes of introspecting the data to generate the final SQL).
|
|
102
115
|
|
|
103
116
|
Returns:
|
|
104
117
|
str: The SQL query that answers the question.
|
|
@@ -118,45 +131,129 @@ class VannaBase(ABC):
|
|
|
118
131
|
doc_list=doc_list,
|
|
119
132
|
**kwargs,
|
|
120
133
|
)
|
|
121
|
-
self.log(prompt)
|
|
134
|
+
self.log(title="SQL Prompt", message=prompt)
|
|
122
135
|
llm_response = self.submit_prompt(prompt, **kwargs)
|
|
123
|
-
self.log(llm_response)
|
|
136
|
+
self.log(title="LLM Response", message=llm_response)
|
|
137
|
+
|
|
138
|
+
if 'intermediate_sql' in llm_response:
|
|
139
|
+
if not allow_llm_to_see_data:
|
|
140
|
+
return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
|
|
141
|
+
|
|
142
|
+
if allow_llm_to_see_data:
|
|
143
|
+
intermediate_sql = self.extract_sql(llm_response)
|
|
144
|
+
|
|
145
|
+
try:
|
|
146
|
+
self.log(title="Running Intermediate SQL", message=intermediate_sql)
|
|
147
|
+
df = self.run_sql(intermediate_sql)
|
|
148
|
+
|
|
149
|
+
prompt = self.get_sql_prompt(
|
|
150
|
+
initial_prompt=initial_prompt,
|
|
151
|
+
question=question,
|
|
152
|
+
question_sql_list=question_sql_list,
|
|
153
|
+
ddl_list=ddl_list,
|
|
154
|
+
doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
|
|
155
|
+
**kwargs,
|
|
156
|
+
)
|
|
157
|
+
self.log(title="Final SQL Prompt", message=prompt)
|
|
158
|
+
llm_response = self.submit_prompt(prompt, **kwargs)
|
|
159
|
+
self.log(title="LLM Response", message=llm_response)
|
|
160
|
+
except Exception as e:
|
|
161
|
+
return f"Error running intermediate SQL: {e}"
|
|
162
|
+
|
|
163
|
+
|
|
124
164
|
return self.extract_sql(llm_response)
|
|
125
165
|
|
|
126
166
|
def extract_sql(self, llm_response: str) -> str:
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
# If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
|
|
133
|
-
sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
|
|
134
|
-
if sql:
|
|
135
|
-
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
|
|
136
|
-
)
|
|
137
|
-
return sql.group(0)
|
|
167
|
+
"""
|
|
168
|
+
Example:
|
|
169
|
+
```python
|
|
170
|
+
vn.extract_sql("Here's the SQL query in a code block: ```sql\nSELECT * FROM customers\n```")
|
|
171
|
+
```
|
|
138
172
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
if sql:
|
|
142
|
-
self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
|
|
143
|
-
return sql.group(1)
|
|
173
|
+
Extracts the SQL query from the LLM response. This is useful in case the LLM response contains other information besides the SQL query.
|
|
174
|
+
Override this function if your LLM responses need custom extraction logic.
|
|
144
175
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
176
|
+
Args:
|
|
177
|
+
llm_response (str): The LLM response.
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
str: The extracted SQL query.
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
# If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
|
|
184
|
+
sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)
|
|
185
|
+
if sqls:
|
|
186
|
+
sql = sqls[-1]
|
|
187
|
+
self.log(title="Extracted SQL", message=f"{sql}")
|
|
188
|
+
return sql
|
|
189
|
+
|
|
190
|
+
# If the llm_response is not markdown formatted, extract last sql by finding select and ; in the response
|
|
191
|
+
sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
|
|
192
|
+
if sqls:
|
|
193
|
+
sql = sqls[-1]
|
|
194
|
+
self.log(title="Extracted SQL", message=f"{sql}")
|
|
195
|
+
return sql
|
|
196
|
+
|
|
197
|
+
# If the llm_response contains a markdown code block, with or without the sql tag, extract the last sql from it
|
|
198
|
+
sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
|
|
199
|
+
if sqls:
|
|
200
|
+
sql = sqls[-1]
|
|
201
|
+
self.log(title="Extracted SQL", message=f"{sql}")
|
|
202
|
+
return sql
|
|
203
|
+
|
|
204
|
+
sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
|
|
205
|
+
if sqls:
|
|
206
|
+
sql = sqls[-1]
|
|
207
|
+
self.log(title="Extracted SQL", message=f"{sql}")
|
|
208
|
+
return sql
|
|
149
209
|
|
|
150
210
|
return llm_response
|
|
151
211
|
|
|
152
212
|
def is_sql_valid(self, sql: str) -> bool:
|
|
153
|
-
|
|
154
|
-
|
|
213
|
+
"""
|
|
214
|
+
Example:
|
|
215
|
+
```python
|
|
216
|
+
vn.is_sql_valid("SELECT * FROM customers")
|
|
217
|
+
```
|
|
218
|
+
Checks if the SQL query is valid. This is usually used to check if we should run the SQL query or not.
|
|
219
|
+
By default it checks if the SQL query is a SELECT statement. You can override this method to enable running other types of SQL queries.
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
sql (str): The SQL query to check.
|
|
223
|
+
|
|
224
|
+
Returns:
|
|
225
|
+
bool: True if the SQL query is valid, False otherwise.
|
|
226
|
+
"""
|
|
227
|
+
|
|
228
|
+
parsed = sqlparse.parse(sql)
|
|
229
|
+
|
|
230
|
+
for statement in parsed:
|
|
231
|
+
if statement.get_type() == 'SELECT':
|
|
232
|
+
return True
|
|
155
233
|
|
|
156
|
-
|
|
234
|
+
return False
|
|
235
|
+
|
|
236
|
+
def should_generate_chart(self, df: pd.DataFrame) -> bool:
|
|
237
|
+
"""
|
|
238
|
+
Example:
|
|
239
|
+
```python
|
|
240
|
+
vn.should_generate_chart(df)
|
|
241
|
+
```
|
|
242
|
+
|
|
243
|
+
Checks if a chart should be generated for the given DataFrame. By default, it checks if the DataFrame has more than one row and has numerical columns.
|
|
244
|
+
You can override this method to customize the logic for generating charts.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
df (pd.DataFrame): The DataFrame to check.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
bool: True if a chart should be generated, False otherwise.
|
|
251
|
+
"""
|
|
252
|
+
|
|
253
|
+
if len(df) > 1 and df.select_dtypes(include=['number']).shape[1] > 0:
|
|
157
254
|
return True
|
|
158
|
-
|
|
159
|
-
|
|
255
|
+
|
|
256
|
+
return False
|
|
160
257
|
|
|
161
258
|
def generate_followup_questions(
|
|
162
259
|
self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
|
|
@@ -184,7 +281,8 @@ class VannaBase(ABC):
|
|
|
184
281
|
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
185
282
|
),
|
|
186
283
|
self.user_message(
|
|
187
|
-
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
|
|
284
|
+
f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." +
|
|
285
|
+
self._response_language()
|
|
188
286
|
),
|
|
189
287
|
]
|
|
190
288
|
|
|
@@ -228,7 +326,8 @@ class VannaBase(ABC):
|
|
|
228
326
|
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
|
|
229
327
|
),
|
|
230
328
|
self.user_message(
|
|
231
|
-
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
|
|
329
|
+
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." +
|
|
330
|
+
self._response_language()
|
|
232
331
|
),
|
|
233
332
|
]
|
|
234
333
|
|
|
@@ -375,7 +474,7 @@ class VannaBase(ABC):
|
|
|
375
474
|
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
|
|
376
475
|
) -> str:
|
|
377
476
|
if len(ddl_list) > 0:
|
|
378
|
-
initial_prompt += "\
|
|
477
|
+
initial_prompt += "\n===Tables \n"
|
|
379
478
|
|
|
380
479
|
for ddl in ddl_list:
|
|
381
480
|
if (
|
|
@@ -394,7 +493,7 @@ class VannaBase(ABC):
|
|
|
394
493
|
max_tokens: int = 14000,
|
|
395
494
|
) -> str:
|
|
396
495
|
if len(documentation_list) > 0:
|
|
397
|
-
initial_prompt += "\
|
|
496
|
+
initial_prompt += "\n===Additional Context \n\n"
|
|
398
497
|
|
|
399
498
|
for documentation in documentation_list:
|
|
400
499
|
if (
|
|
@@ -410,7 +509,7 @@ class VannaBase(ABC):
|
|
|
410
509
|
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
|
|
411
510
|
) -> str:
|
|
412
511
|
if len(sql_list) > 0:
|
|
413
|
-
initial_prompt += "\
|
|
512
|
+
initial_prompt += "\n===Question-SQL Pairs\n\n"
|
|
414
513
|
|
|
415
514
|
for question in sql_list:
|
|
416
515
|
if (
|
|
@@ -456,7 +555,8 @@ class VannaBase(ABC):
|
|
|
456
555
|
"""
|
|
457
556
|
|
|
458
557
|
if initial_prompt is None:
|
|
459
|
-
initial_prompt = "
|
|
558
|
+
initial_prompt = f"You are a {self.dialect} expert. "
|
|
559
|
+
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
|
460
560
|
|
|
461
561
|
initial_prompt = self.add_ddl_to_prompt(
|
|
462
562
|
initial_prompt, ddl_list, max_tokens=14000
|
|
@@ -469,6 +569,15 @@ class VannaBase(ABC):
|
|
|
469
569
|
initial_prompt, doc_list, max_tokens=14000
|
|
470
570
|
)
|
|
471
571
|
|
|
572
|
+
initial_prompt += (
|
|
573
|
+
"===Response Guidelines \n"
|
|
574
|
+
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
|
575
|
+
"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
|
576
|
+
"3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
|
577
|
+
"4. Please use the most relevant table(s). \n"
|
|
578
|
+
"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
|
579
|
+
)
|
|
580
|
+
|
|
472
581
|
message_log = [self.system_message(initial_prompt)]
|
|
473
582
|
|
|
474
583
|
for example in question_sql_list:
|
|
@@ -676,7 +785,7 @@ class VannaBase(ABC):
|
|
|
676
785
|
|
|
677
786
|
return df
|
|
678
787
|
|
|
679
|
-
self.
|
|
788
|
+
self.dialect = "Snowflake SQL"
|
|
680
789
|
self.run_sql = run_sql_snowflake
|
|
681
790
|
self.run_sql_is_set = True
|
|
682
791
|
|
|
@@ -710,7 +819,7 @@ class VannaBase(ABC):
|
|
|
710
819
|
def run_sql_sqlite(sql: str):
|
|
711
820
|
return pd.read_sql_query(sql, conn)
|
|
712
821
|
|
|
713
|
-
self.
|
|
822
|
+
self.dialect = "SQLite"
|
|
714
823
|
self.run_sql = run_sql_sqlite
|
|
715
824
|
self.run_sql_is_set = True
|
|
716
825
|
|
|
@@ -815,7 +924,7 @@ class VannaBase(ABC):
|
|
|
815
924
|
conn.rollback()
|
|
816
925
|
raise e
|
|
817
926
|
|
|
818
|
-
self.
|
|
927
|
+
self.dialect = "PostgreSQL"
|
|
819
928
|
self.run_sql_is_set = True
|
|
820
929
|
self.run_sql = run_sql_postgres
|
|
821
930
|
|
|
@@ -1078,7 +1187,7 @@ class VannaBase(ABC):
|
|
|
1078
1187
|
raise errors
|
|
1079
1188
|
return None
|
|
1080
1189
|
|
|
1081
|
-
self.
|
|
1190
|
+
self.dialect = "BigQuery SQL"
|
|
1082
1191
|
self.run_sql_is_set = True
|
|
1083
1192
|
self.run_sql = run_sql_bigquery
|
|
1084
1193
|
|
|
@@ -1127,7 +1236,7 @@ class VannaBase(ABC):
|
|
|
1127
1236
|
def run_sql_duckdb(sql: str):
|
|
1128
1237
|
return conn.query(sql).to_df()
|
|
1129
1238
|
|
|
1130
|
-
self.
|
|
1239
|
+
self.dialect = "DuckDB SQL"
|
|
1131
1240
|
self.run_sql = run_sql_duckdb
|
|
1132
1241
|
self.run_sql_is_set = True
|
|
1133
1242
|
|
|
@@ -1174,7 +1283,7 @@ class VannaBase(ABC):
|
|
|
1174
1283
|
|
|
1175
1284
|
raise Exception("Couldn't run sql")
|
|
1176
1285
|
|
|
1177
|
-
self.
|
|
1286
|
+
self.dialect = "T-SQL / Microsoft SQL Server"
|
|
1178
1287
|
self.run_sql = run_sql_mssql
|
|
1179
1288
|
self.run_sql_is_set = True
|
|
1180
1289
|
|
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import logging
|
|
3
|
+
import sys
|
|
2
4
|
import uuid
|
|
3
5
|
from abc import ABC, abstractmethod
|
|
4
6
|
from functools import wraps
|
|
@@ -6,6 +8,7 @@ from functools import wraps
|
|
|
6
8
|
import flask
|
|
7
9
|
import requests
|
|
8
10
|
from flask import Flask, Response, jsonify, request
|
|
11
|
+
from flask_sock import Sock
|
|
9
12
|
|
|
10
13
|
from .assets import css_content, html_content, js_content
|
|
11
14
|
from .auth import AuthInterface, NoAuth
|
|
@@ -133,6 +136,7 @@ class VannaFlaskApp:
|
|
|
133
136
|
|
|
134
137
|
def __init__(self, vn, cache: Cache = MemoryCache(),
|
|
135
138
|
auth: AuthInterface = NoAuth(),
|
|
139
|
+
debug=True,
|
|
136
140
|
allow_llm_to_see_data=False,
|
|
137
141
|
logo="https://img.vanna.ai/vanna-flask.svg",
|
|
138
142
|
title="Welcome to Vanna.AI",
|
|
@@ -156,6 +160,7 @@ class VannaFlaskApp:
|
|
|
156
160
|
vn: The Vanna instance to interact with.
|
|
157
161
|
cache: The cache to use. Defaults to MemoryCache, which uses an in-memory cache. You can also pass in a custom cache that implements the Cache interface.
|
|
158
162
|
auth: The authentication method to use. Defaults to NoAuth, which doesn't require authentication. You can also pass in a custom authentication method that implements the AuthInterface interface.
|
|
163
|
+
debug: Show the debug console. Defaults to True.
|
|
159
164
|
allow_llm_to_see_data: Whether to allow the LLM to see data. Defaults to False.
|
|
160
165
|
logo: The logo to display in the UI. Defaults to the Vanna logo.
|
|
161
166
|
title: The title to display in the UI. Defaults to "Welcome to Vanna.AI".
|
|
@@ -176,7 +181,10 @@ class VannaFlaskApp:
|
|
|
176
181
|
None
|
|
177
182
|
"""
|
|
178
183
|
self.flask_app = Flask(__name__)
|
|
184
|
+
self.sock = Sock(self.flask_app)
|
|
185
|
+
self.ws_clients = []
|
|
179
186
|
self.vn = vn
|
|
187
|
+
self.debug = debug
|
|
180
188
|
self.auth = auth
|
|
181
189
|
self.cache = cache
|
|
182
190
|
self.allow_llm_to_see_data = allow_llm_to_see_data
|
|
@@ -198,6 +206,16 @@ class VannaFlaskApp:
|
|
|
198
206
|
log = logging.getLogger("werkzeug")
|
|
199
207
|
log.setLevel(logging.ERROR)
|
|
200
208
|
|
|
209
|
+
if "google.colab" in sys.modules:
|
|
210
|
+
self.debug = False
|
|
211
|
+
print("Google Colab doesn't support running websocket servers. Disabling debug mode.")
|
|
212
|
+
|
|
213
|
+
if self.debug:
|
|
214
|
+
def log(message, title="Info"):
|
|
215
|
+
[ws.send(json.dumps({'message': message, 'title': title})) for ws in self.ws_clients]
|
|
216
|
+
|
|
217
|
+
self.vn.log = log
|
|
218
|
+
|
|
201
219
|
@self.flask_app.route("/auth/login", methods=["POST"])
|
|
202
220
|
def login():
|
|
203
221
|
return self.auth.login_handler(flask.request)
|
|
@@ -214,6 +232,7 @@ class VannaFlaskApp:
|
|
|
214
232
|
@self.requires_auth
|
|
215
233
|
def get_config(user: any):
|
|
216
234
|
config = {
|
|
235
|
+
"debug": self.debug,
|
|
217
236
|
"logo": self.logo,
|
|
218
237
|
"title": self.title,
|
|
219
238
|
"subtitle": self.subtitle,
|
|
@@ -304,18 +323,27 @@ class VannaFlaskApp:
|
|
|
304
323
|
return jsonify({"type": "error", "error": "No question provided"})
|
|
305
324
|
|
|
306
325
|
id = self.cache.generate_id(question=question)
|
|
307
|
-
sql = vn.generate_sql(question=question)
|
|
326
|
+
sql = vn.generate_sql(question=question, allow_llm_to_see_data=self.allow_llm_to_see_data)
|
|
308
327
|
|
|
309
328
|
self.cache.set(id=id, field="question", value=question)
|
|
310
329
|
self.cache.set(id=id, field="sql", value=sql)
|
|
311
330
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
331
|
+
if vn.is_sql_valid(sql=sql):
|
|
332
|
+
return jsonify(
|
|
333
|
+
{
|
|
334
|
+
"type": "sql",
|
|
335
|
+
"id": id,
|
|
336
|
+
"text": sql,
|
|
337
|
+
}
|
|
338
|
+
)
|
|
339
|
+
else:
|
|
340
|
+
return jsonify(
|
|
341
|
+
{
|
|
342
|
+
"type": "text",
|
|
343
|
+
"id": id,
|
|
344
|
+
"text": sql,
|
|
345
|
+
}
|
|
346
|
+
)
|
|
319
347
|
|
|
320
348
|
@self.flask_app.route("/api/v0/run_sql", methods=["GET"])
|
|
321
349
|
@self.requires_auth
|
|
@@ -339,6 +367,7 @@ class VannaFlaskApp:
|
|
|
339
367
|
"type": "df",
|
|
340
368
|
"id": id,
|
|
341
369
|
"df": df.head(10).to_json(orient='records', date_format='iso'),
|
|
370
|
+
"should_generate_chart": self.chart and vn.should_generate_chart(df),
|
|
342
371
|
}
|
|
343
372
|
)
|
|
344
373
|
|
|
@@ -619,6 +648,18 @@ class VannaFlaskApp:
|
|
|
619
648
|
else:
|
|
620
649
|
return "Error fetching file from remote server", response.status_code
|
|
621
650
|
|
|
651
|
+
if self.debug:
|
|
652
|
+
@self.sock.route("/api/v0/log")
|
|
653
|
+
def sock_log(ws):
|
|
654
|
+
self.ws_clients.append(ws)
|
|
655
|
+
|
|
656
|
+
try:
|
|
657
|
+
while True:
|
|
658
|
+
message = ws.receive() # This example just reads and ignores to keep the socket open
|
|
659
|
+
finally:
|
|
660
|
+
self.ws_clients.remove(ws)
|
|
661
|
+
|
|
662
|
+
|
|
622
663
|
@self.flask_app.route("/", defaults={"path": ""})
|
|
623
664
|
@self.flask_app.route("/<path:path>")
|
|
624
665
|
def hello(path: str):
|
|
@@ -651,4 +692,4 @@ class VannaFlaskApp:
|
|
|
651
692
|
print("Your app is running at:")
|
|
652
693
|
print("http://localhost:8084")
|
|
653
694
|
|
|
654
|
-
self.flask_app.run(host="0.0.0.0", port=8084, debug=
|
|
695
|
+
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)
|