vanna 0.6.6__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- vanna/azuresearch/__init__.py +1 -0
- vanna/azuresearch/azuresearch_vector.py +236 -0
- vanna/base/base.py +62 -17
- vanna/flask/__init__.py +26 -0
- vanna/flask/assets.py +35 -35
- vanna/ollama/ollama.py +3 -1
- vanna/qianwen/QianwenAI_chat.py +133 -0
- vanna/qianwen/QianwenAI_embeddings.py +46 -0
- vanna/qianwen/__init__.py +2 -0
- vanna/weaviate/weaviate_vector.py +3 -2
- {vanna-0.6.6.dist-info → vanna-0.7.1.dist-info}/METADATA +9 -1
- {vanna-0.6.6.dist-info → vanna-0.7.1.dist-info}/RECORD +13 -8
- {vanna-0.6.6.dist-info → vanna-0.7.1.dist-info}/WHEEL +0 -0
vanna/azuresearch/__init__.py
ADDED
@@ -0,0 +1 @@
+from .azuresearch_vector import AzureAISearch_VectorStore
vanna/azuresearch/azuresearch_vector.py
ADDED
@@ -0,0 +1,236 @@
+import ast
+import json
+from typing import List
+
+import pandas as pd
+from azure.core.credentials import AzureKeyCredential
+from azure.search.documents import SearchClient
+from azure.search.documents.indexes import SearchIndexClient
+from azure.search.documents.indexes.models import (
+    ExhaustiveKnnAlgorithmConfiguration,
+    ExhaustiveKnnParameters,
+    SearchableField,
+    SearchField,
+    SearchFieldDataType,
+    SearchIndex,
+    VectorSearch,
+    VectorSearchAlgorithmKind,
+    VectorSearchAlgorithmMetric,
+    VectorSearchProfile,
+)
+from azure.search.documents.models import VectorFilterMode, VectorizedQuery
+from fastembed import TextEmbedding
+
+from ..base import VannaBase
+from ..utils import deterministic_uuid
+
+
+class AzureAISearch_VectorStore(VannaBase):
+    """
+    AzureAISearch_VectorStore is a class that provides a vector store for Azure AI Search.
+
+    Args:
+        config (dict): Configuration dictionary. Defaults to {}. You must provide an API key in the config.
+            - azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
+            - azure_search_api_key (str): Azure Search API key.
+            - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
+            - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
+            - index_name (str, optional): Name of the index. Defaults to "vanna-index".
+            - n_results (int, optional): Number of results to return. Defaults to 10.
+            - n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
+            - n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
+            - n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.
+
+    Raises:
+        ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
+    """
+    def __init__(self, config=None):
+        VannaBase.__init__(self, config=config)
+
+        self.config = config or None
+
+        if config is None:
+            raise ValueError(
+                "config is required, pass an API key, 'azure_search_api_key', in the config."
+            )
+
+        azure_search_endpoint = config.get("azure_search_endpoint", "https://azcognetive.search.windows.net")
+        azure_search_api_key = config.get("azure_search_api_key")
+
+        self.dimensions = config.get("dimensions", 384)
+        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
+
+        self.index_name = config.get("index_name", "vanna-index")
+
+        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
+        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
+        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
+
+        if not azure_search_api_key:
+            raise ValueError(
+                "'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
+            )
+
+        self.index_client = SearchIndexClient(
+            endpoint=azure_search_endpoint,
+            credential=AzureKeyCredential(azure_search_api_key)
+        )
+
+        self.search_client = SearchClient(
+            endpoint=azure_search_endpoint,
+            index_name=self.index_name,
+            credential=AzureKeyCredential(azure_search_api_key)
+        )
+
+        if self.index_name not in self._get_indexes():
+            self._create_index()
+
+    def _create_index(self) -> bool:
+        fields = [
+            SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
+            SearchableField(name="document", type=SearchFieldDataType.String, searchable=True, filterable=True),
+            SearchField(name="type", type=SearchFieldDataType.String, filterable=True, searchable=True),
+            SearchField(name="document_vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=self.dimensions, vector_search_profile_name="ExhaustiveKnnProfile"),
+        ]
+
+        vector_search = VectorSearch(
+            algorithms=[
+                ExhaustiveKnnAlgorithmConfiguration(
+                    name="ExhaustiveKnn",
+                    kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
+                    parameters=ExhaustiveKnnParameters(
+                        metric=VectorSearchAlgorithmMetric.COSINE
+                    )
+                )
+            ],
+            profiles=[
+                VectorSearchProfile(
+                    name="ExhaustiveKnnProfile",
+                    algorithm_configuration_name="ExhaustiveKnn",
+                )
+            ]
+        )
+
+        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
+        result = self.index_client.create_or_update_index(index)
+        print(f'{result.name} created')
+
+    def _get_indexes(self) -> list:
+        return [index for index in self.index_client.list_index_names()]
+
+    def add_ddl(self, ddl: str) -> str:
+        id = deterministic_uuid(ddl) + "-ddl"
+        document = {
+            "id": id,
+            "document": ddl,
+            "type": "ddl",
+            "document_vector": self.generate_embedding(ddl)
+        }
+        self.search_client.upload_documents(documents=[document])
+        return id
+
+    def add_documentation(self, doc: str) -> str:
+        id = deterministic_uuid(doc) + "-doc"
+        document = {
+            "id": id,
+            "document": doc,
+            "type": "doc",
+            "document_vector": self.generate_embedding(doc)
+        }
+        self.search_client.upload_documents(documents=[document])
+        return id
+
+    def add_question_sql(self, question: str, sql: str) -> str:
+        question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
+        id = deterministic_uuid(question_sql_json) + "-sql"
+        document = {
+            "id": id,
+            "document": question_sql_json,
+            "type": "sql",
+            "document_vector": self.generate_embedding(question_sql_json)
+        }
+        self.search_client.upload_documents(documents=[document])
+        return id
+
+    def get_related_ddl(self, text: str) -> List[str]:
+        result = []
+        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
+        df = pd.DataFrame(
+            self.search_client.search(
+                top=self.n_results_ddl,
+                vector_queries=[vector_query],
+                select=["id", "document", "type"],
+                filter=f"type eq 'ddl'"
+            )
+        )
+
+        if len(df):
+            result = df["document"].tolist()
+        return result
+
+    def get_related_documentation(self, text: str) -> List[str]:
+        result = []
+        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
+
+        df = pd.DataFrame(
+            self.search_client.search(
+                top=self.n_results_documentation,
+                vector_queries=[vector_query],
+                select=["id", "document", "type"],
+                filter=f"type eq 'doc'",
+                vector_filter_mode=VectorFilterMode.PRE_FILTER
+            )
+        )
+
+        if len(df):
+            result = df["document"].tolist()
+        return result
+
+    def get_similar_question_sql(self, text: str) -> List[str]:
+        result = []
+        # Vectorize the text
+        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
+        df = pd.DataFrame(
+            self.search_client.search(
+                top=self.n_results_sql,
+                vector_queries=[vector_query],
+                select=["id", "document", "type"],
+                filter=f"type eq 'sql'"
+            )
+        )
+
+        if len(df):  # Check if there is similar query and the result is not empty
+            result = [ast.literal_eval(element) for element in df["document"].tolist()]
+
+        return result
+
+    def get_training_data(self) -> List[str]:
+
+        search = self.search_client.search(
+            search_text="*",
+            select=['id', 'document', 'type'],
+            filter=f"(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
+        ).by_page()
+
+        df = pd.DataFrame([item for page in search for item in page])
+
+        if len(df):
+            df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
+            df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
+            df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]
+
+            return df[["id", "question", "content", "type"]]
+
+        return pd.DataFrame()
+
+    def remove_training_data(self, id: str) -> bool:
+        result = self.search_client.delete_documents(documents=[{'id':id}])
+        return result[0].succeeded
+
+    def remove_index(self):
+        self.index_client.delete_index(self.index_name)
+
+    def generate_embedding(self, data: str, **kwargs) -> List[float]:
+        embedding_model = TextEmbedding(model_name=self.fastembed_model)
+        embedding = next(embedding_model.embed(data))
+        return embedding.tolist()
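For orientation, the sketch below shows how the new store could be wired up following the multiple-inheritance pattern Vanna uses for its other vector stores, pairing it with the OpenAI chat class referenced in the base.py docstring. The `vanna.openai` import path, the endpoint/key placeholders, and the model name are assumptions, not values taken from this diff.

```python
# Sketch only: pair the new Azure AI Search store with an LLM class.
from vanna.azuresearch import AzureAISearch_VectorStore
from vanna.openai import OpenAI_Chat  # assumed import path for the OpenAI chat class

class MyVanna(AzureAISearch_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        AzureAISearch_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={
    "azure_search_endpoint": "https://<service>.search.windows.net",  # placeholder
    "azure_search_api_key": "<azure-search-admin-key>",               # placeholder
    "index_name": "vanna-index",
    "api_key": "<openai-api-key>",                                    # placeholder for OpenAI_Chat
    "model": "gpt-4o-mini",                                           # placeholder
})

# Training data is embedded with fastembed and uploaded as documents of type
# 'ddl', 'doc', or 'sql'; retrieval runs an exhaustive-kNN vector query on the index.
vn.add_ddl("CREATE TABLE customers (id INT, name TEXT, email TEXT)")
vn.add_question_sql(question="How many customers are there?",
                    sql="SELECT COUNT(*) FROM customers")
print(vn.get_similar_question_sql("count of customers"))
```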
vanna/base/base.py
CHANGED
@@ -15,7 +15,7 @@ r"""
 
 # Open-Source and Extending
 
-Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](/docs/
+Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](https://vanna.ai/docs/postgres-ollama-chromadb/).
 
 The following is an example of where various functions are implemented in the codebase when using the default "local" version of Vanna. `vanna.base.VannaBase` is the base class which provides a `vanna.base.VannaBase.ask` and `vanna.base.VannaBase.train` function. Those rely on abstract methods which are implemented in the subclasses `vanna.openai_chat.OpenAI_Chat` and `vanna.chromadb_vector.ChromaDB_VectorStore`. `vanna.openai_chat.OpenAI_Chat` uses the OpenAI API to generate SQL and Plotly code. `vanna.chromadb_vector.ChromaDB_VectorStore` uses ChromaDB to store training data and generate embeddings.
 
@@ -256,6 +256,33 @@ class VannaBase(ABC):
 
         return False
 
+    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
+        """
+        **Example:**
+        ```python
+        rewritten_question = vn.generate_rewritten_question("Who are the top 5 customers by sales?", "Show me their email addresses")
+        ```
+
+        Generate a rewritten question by combining the last question and the new question if they are related. If the new question is self-contained and not related to the last question, return the new question.
+
+        Args:
+            last_question (str): The previous question that was asked.
+            new_question (str): The new question to be combined with the last question.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The combined question if related, otherwise the new question.
+        """
+        if last_question is None:
+            return new_question
+
+        prompt = [
+            self.system_message("Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
+            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
+        ]
+
+        return self.submit_prompt(prompt=prompt, **kwargs)
+
     def generate_followup_questions(
         self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
     ) -> list:
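A short sketch of the intended use of the new method in a follow-up conversation, assuming `vn` is an already-configured Vanna instance; only `generate_rewritten_question` comes from this diff, the `generate_sql` call is the existing base-class API.

```python
# Sketch: resolve a follow-up question before generating SQL.
last_question = "Who are the top 5 customers by sales?"
new_question = "Show me their email addresses"

# Returns the second question unchanged if it is unrelated and self-contained,
# otherwise a single combined, SQL-answerable question.
question = vn.generate_rewritten_question(last_question, new_question)
sql = vn.generate_sql(question)
```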
@@ -840,6 +867,7 @@ class VannaBase(ABC):
         port: int = None,
         **kwargs
     ):
+
         """
         Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
         **Example:**
@@ -913,26 +941,44 @@ class VannaBase(ABC):
         except psycopg2.Error as e:
             raise ValidationError(e)
 
+        def connect_to_db():
+            return psycopg2.connect(host=host, dbname=dbname,
+                        user=user, password=password, port=port, **kwargs)
+
+
         def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
-
-
-
-
-
+            conn = None
+            try:
+                conn = connect_to_db()  # Initial connection attempt
+                cs = conn.cursor()
+                cs.execute(sql)
+                results = cs.fetchall()
 
-
-
-
-            )
-            return df
+                # Create a pandas dataframe from the results
+                df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
+                return df
 
-
+            except psycopg2.InterfaceError as e:
+                # Attempt to reconnect and retry the operation
+                if conn:
+                    conn.close()  # Ensure any existing connection is closed
+                conn = connect_to_db()
+                cs = conn.cursor()
+                cs.execute(sql)
+                results = cs.fetchall()
+
+                # Create a pandas dataframe from the results
+                df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
+                return df
+
+            except psycopg2.Error as e:
+                if conn:
                     conn.rollback()
                     raise ValidationError(e)
 
-
-
-
+            except Exception as e:
+                conn.rollback()
+                raise e
 
         self.dialect = "PostgreSQL"
         self.run_sql_is_set = True
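The hunk above reworks the nested `run_sql_postgres` so that a dropped connection triggers one reconnect-and-retry on `psycopg2.InterfaceError`. From the caller's side nothing changes; a minimal usage sketch (credentials and query are placeholders, not from the diff):

```python
# Sketch: connect_to_postgres still just sets vn.run_sql; reconnection is handled internally.
vn.connect_to_postgres(
    host="localhost",
    dbname="mydb",
    user="postgres",
    password="<password>",  # placeholder
    port=5432,
)

df = vn.run_sql("SELECT 1 AS ok")  # run_sql is set by connect_to_postgres
print(df)
```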
@@ -1279,7 +1325,6 @@ class VannaBase(ABC):
                 job = conn.query(sql)
                 df = job.result().to_dataframe()
                 return df
-
             return None
 
         self.dialect = "BigQuery SQL"
@@ -1666,7 +1711,7 @@ class VannaBase(ABC):
 
         if self.run_sql_is_set is False:
             print(
-                "If you want to run the SQL query, connect to a database first.
+                "If you want to run the SQL query, connect to a database first."
             )
 
         if print_results:
vanna/flask/__init__.py
CHANGED
@@ -5,6 +5,7 @@ import sys
 import uuid
 from abc import ABC, abstractmethod
 from functools import wraps
+import importlib.metadata
 
 import flask
 import requests
@@ -353,6 +354,30 @@ class VannaFlaskAPI:
                 }
             )
 
+        @self.flask_app.route("/api/v0/generate_rewritten_question", methods=["GET"])
+        @self.requires_auth
+        def generate_rewritten_question(user: any):
+            """
+            Generate a rewritten question
+            ---
+            parameters:
+              - name: last_question
+                in: query
+                type: string
+                required: true
+              - name: new_question
+                in: query
+                type: string
+                required: true
+            """
+
+            last_question = flask.request.args.get("last_question")
+            new_question = flask.request.args.get("new_question")
+
+            rewritten_question = self.vn.generate_rewritten_question(last_question, new_question)
+
+            return jsonify({"type": "rewritten_question", "question": rewritten_question})
+
         @self.flask_app.route("/api/v0/get_function", methods=["GET"])
         @self.requires_auth
         def get_function(user: any):
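The new route is a plain GET endpoint, so it can be exercised directly once the Flask app is running. A sketch of a client call is below; the host and port are assumptions, while the route, query parameters, and response shape come from the hunk above.

```python
# Sketch: call the new endpoint on a locally running Vanna Flask app.
import requests

resp = requests.get(
    "http://localhost:8084/api/v0/generate_rewritten_question",  # assumed host/port
    params={
        "last_question": "Who are the top 5 customers by sales?",
        "new_question": "Show me their email addresses",
    },
)
print(resp.json())  # {"type": "rewritten_question", "question": "..."}
```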
@@ -1212,6 +1237,7 @@ class VannaFlaskApp(VannaFlaskAPI):
         self.config["followup_questions"] = followup_questions
         self.config["summarization"] = summarization
         self.config["function_generation"] = function_generation and hasattr(vn, "get_function")
+        self.config["version"] = importlib.metadata.version('vanna')
 
         self.index_html_path = index_html_path
         self.assets_folder = assets_folder
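The new `version` config entry is populated with the standard-library `importlib.metadata` lookup (Python 3.8+), which reads the version of the installed distribution rather than a hard-coded string:

```python
# The same call the Flask app now uses to expose the package version.
import importlib.metadata

print(importlib.metadata.version("vanna"))  # e.g. "0.7.1" for this release
```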