vanna 0.6.5__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. {vanna-0.6.5 → vanna-0.7.0}/PKG-INFO +11 -3
  2. {vanna-0.6.5 → vanna-0.7.0}/pyproject.toml +4 -3
  3. vanna-0.7.0/src/vanna/azuresearch/__init__.py +1 -0
  4. vanna-0.7.0/src/vanna/azuresearch/azuresearch_vector.py +236 -0
  5. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/base/base.py +66 -26
  6. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/flask/__init__.py +28 -2
  7. vanna-0.7.0/src/vanna/flask/assets.py +58 -0
  8. vanna-0.7.0/src/vanna/google/__init__.py +2 -0
  9. vanna-0.7.0/src/vanna/google/bigquery_vector.py +230 -0
  10. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/mistral/mistral.py +8 -6
  11. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/ollama/ollama.py +3 -1
  12. vanna-0.7.0/src/vanna/qianwen/QianwenAI_chat.py +133 -0
  13. vanna-0.7.0/src/vanna/qianwen/QianwenAI_embeddings.py +46 -0
  14. vanna-0.7.0/src/vanna/qianwen/__init__.py +2 -0
  15. vanna-0.6.5/src/vanna/flask/assets.py +0 -58
  16. vanna-0.6.5/src/vanna/google/__init__.py +0 -1
  17. {vanna-0.6.5 → vanna-0.7.0}/README.md +0 -0
  18. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/ZhipuAI/ZhipuAI_Chat.py +0 -0
  19. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/ZhipuAI/ZhipuAI_embeddings.py +0 -0
  20. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/ZhipuAI/__init__.py +0 -0
  21. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/__init__.py +0 -0
  22. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/advanced/__init__.py +0 -0
  23. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/anthropic/__init__.py +0 -0
  24. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/anthropic/anthropic_chat.py +0 -0
  25. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/base/__init__.py +0 -0
  26. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/bedrock/__init__.py +0 -0
  27. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/bedrock/bedrock_converse.py +0 -0
  28. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/chromadb/__init__.py +0 -0
  29. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/chromadb/chromadb_vector.py +0 -0
  30. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/exceptions/__init__.py +0 -0
  31. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/flask/auth.py +0 -0
  32. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/google/gemini_chat.py +0 -0
  33. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/hf/__init__.py +0 -0
  34. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/hf/hf.py +0 -0
  35. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/local.py +0 -0
  36. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/marqo/__init__.py +0 -0
  37. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/marqo/marqo.py +0 -0
  38. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/milvus/__init__.py +0 -0
  39. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/milvus/milvus_vector.py +0 -0
  40. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/mistral/__init__.py +0 -0
  41. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/mock/__init__.py +0 -0
  42. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/mock/embedding.py +0 -0
  43. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/mock/llm.py +0 -0
  44. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/mock/vectordb.py +0 -0
  45. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/ollama/__init__.py +0 -0
  46. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/openai/__init__.py +0 -0
  47. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/openai/openai_chat.py +0 -0
  48. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/openai/openai_embeddings.py +0 -0
  49. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/opensearch/__init__.py +0 -0
  50. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/opensearch/opensearch_vector.py +0 -0
  51. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/pinecone/__init__.py +0 -0
  52. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/pinecone/pinecone_vector.py +0 -0
  53. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/qdrant/__init__.py +0 -0
  54. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/qdrant/qdrant.py +0 -0
  55. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/qianfan/Qianfan_Chat.py +0 -0
  56. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/qianfan/Qianfan_embeddings.py +0 -0
  57. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/qianfan/__init__.py +0 -0
  58. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/remote.py +0 -0
  59. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/types/__init__.py +0 -0
  60. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/utils.py +0 -0
  61. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/vannadb/__init__.py +0 -0
  62. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/vannadb/vannadb_vector.py +0 -0
  63. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/vllm/__init__.py +0 -0
  64. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/vllm/vllm.py +0 -0
  65. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/weaviate/__init__.py +0 -0
  66. {vanna-0.6.5 → vanna-0.7.0}/src/vanna/weaviate/weaviate_vector.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.6.5
3
+ Version: 0.7.0
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -26,7 +26,7 @@ Requires-Dist: snowflake-connector-python ; extra == "all"
26
26
  Requires-Dist: duckdb ; extra == "all"
27
27
  Requires-Dist: openai ; extra == "all"
28
28
  Requires-Dist: qianfan ; extra == "all"
29
- Requires-Dist: mistralai ; extra == "all"
29
+ Requires-Dist: mistralai>=1.0.0 ; extra == "all"
30
30
  Requires-Dist: chromadb ; extra == "all"
31
31
  Requires-Dist: anthropic ; extra == "all"
32
32
  Requires-Dist: zhipuai ; extra == "all"
@@ -43,7 +43,14 @@ Requires-Dist: transformers ; extra == "all"
43
43
  Requires-Dist: pinecone-client ; extra == "all"
44
44
  Requires-Dist: pymilvus[model] ; extra == "all"
45
45
  Requires-Dist: weaviate-client ; extra == "all"
46
+ Requires-Dist: azure-search-documents ; extra == "all"
47
+ Requires-Dist: azure-identity ; extra == "all"
48
+ Requires-Dist: azure-common ; extra == "all"
46
49
  Requires-Dist: anthropic ; extra == "anthropic"
50
+ Requires-Dist: azure-search-documents ; extra == "azuresearch"
51
+ Requires-Dist: azure-identity ; extra == "azuresearch"
52
+ Requires-Dist: azure-common ; extra == "azuresearch"
53
+ Requires-Dist: fastembed ; extra == "azuresearch"
47
54
  Requires-Dist: boto3 ; extra == "bedrock"
48
55
  Requires-Dist: botocore ; extra == "bedrock"
49
56
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
@@ -56,7 +63,7 @@ Requires-Dist: google-cloud-aiplatform ; extra == "google"
56
63
  Requires-Dist: transformers ; extra == "hf"
57
64
  Requires-Dist: marqo ; extra == "marqo"
58
65
  Requires-Dist: pymilvus[model] ; extra == "milvus"
59
- Requires-Dist: mistralai ; extra == "mistralai"
66
+ Requires-Dist: mistralai>=1.0.0 ; extra == "mistralai"
60
67
  Requires-Dist: PyMySQL ; extra == "mysql"
61
68
  Requires-Dist: ollama ; extra == "ollama"
62
69
  Requires-Dist: httpx ; extra == "ollama"
@@ -79,6 +86,7 @@ Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
79
86
  Project-URL: Homepage, https://github.com/vanna-ai/vanna
80
87
  Provides-Extra: all
81
88
  Provides-Extra: anthropic
89
+ Provides-Extra: azuresearch
82
90
  Provides-Extra: bedrock
83
91
  Provides-Extra: bigquery
84
92
  Provides-Extra: chromadb
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
4
4
 
5
5
  [project]
6
6
  name = "vanna"
7
- version = "0.6.5"
7
+ version = "0.7.0"
8
8
  authors = [
9
9
  { name="Zain Hoda", email="zain@vanna.ai" },
10
10
  ]
@@ -33,12 +33,12 @@ bigquery = ["google-cloud-bigquery"]
33
33
  snowflake = ["snowflake-connector-python"]
34
34
  duckdb = ["duckdb"]
35
35
  google = ["google-generativeai", "google-cloud-aiplatform"]
36
- all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
36
+ all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common"]
37
37
  test = ["tox"]
38
38
  chromadb = ["chromadb"]
39
39
  openai = ["openai"]
40
40
  qianfan = ["qianfan"]
41
- mistralai = ["mistralai"]
41
+ mistralai = ["mistralai>=1.0.0"]
42
42
  anthropic = ["anthropic"]
43
43
  gemini = ["google-generativeai"]
44
44
  marqo = ["marqo"]
@@ -52,3 +52,4 @@ hf = ["transformers"]
52
52
  milvus = ["pymilvus[model]"]
53
53
  bedrock = ["boto3", "botocore"]
54
54
  weaviate = ["weaviate-client"]
55
+ azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
@@ -0,0 +1 @@
1
+ from .azuresearch_vector import AzureAISearch_VectorStore
@@ -0,0 +1,236 @@
1
+ import ast
2
+ import json
3
+ from typing import List
4
+
5
+ import pandas as pd
6
+ from azure.core.credentials import AzureKeyCredential
7
+ from azure.search.documents import SearchClient
8
+ from azure.search.documents.indexes import SearchIndexClient
9
+ from azure.search.documents.indexes.models import (
10
+ ExhaustiveKnnAlgorithmConfiguration,
11
+ ExhaustiveKnnParameters,
12
+ SearchableField,
13
+ SearchField,
14
+ SearchFieldDataType,
15
+ SearchIndex,
16
+ VectorSearch,
17
+ VectorSearchAlgorithmKind,
18
+ VectorSearchAlgorithmMetric,
19
+ VectorSearchProfile,
20
+ )
21
+ from azure.search.documents.models import VectorFilterMode, VectorizedQuery
22
+ from fastembed import TextEmbedding
23
+
24
+ from ..base import VannaBase
25
+ from ..utils import deterministic_uuid
26
+
27
+
28
+ class AzureAISearch_VectorStore(VannaBase):
29
+ """
30
+ AzureAISearch_VectorStore is a class that provides a vector store for Azure AI Search.
31
+
32
+ Args:
33
+ config (dict): Configuration dictionary. Defaults to {}. You must provide an API key in the config.
34
+ - azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
35
+ - azure_search_api_key (str): Azure Search API key.
36
+ - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
37
+ - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
38
+ - index_name (str, optional): Name of the index. Defaults to "vanna-index".
39
+ - n_results (int, optional): Number of results to return. Defaults to 10.
40
+ - n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
41
+ - n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
42
+ - n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.
43
+
44
+ Raises:
45
+ ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
46
+ """
47
+ def __init__(self, config=None):
48
+ VannaBase.__init__(self, config=config)
49
+
50
+ self.config = config or None
51
+
52
+ if config is None:
53
+ raise ValueError(
54
+ "config is required, pass an API key, 'azure_search_api_key', in the config."
55
+ )
56
+
57
+ azure_search_endpoint = config.get("azure_search_endpoint", "https://azcognetive.search.windows.net")
58
+ azure_search_api_key = config.get("azure_search_api_key")
59
+
60
+ self.dimensions = config.get("dimensions", 384)
61
+ self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
62
+
63
+ self.index_name = config.get("index_name", "vanna-index")
64
+
65
+ self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
66
+ self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
67
+ self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
68
+
69
+ if not azure_search_api_key:
70
+ raise ValueError(
71
+ "'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
72
+ )
73
+
74
+ self.index_client = SearchIndexClient(
75
+ endpoint=azure_search_endpoint,
76
+ credential=AzureKeyCredential(azure_search_api_key)
77
+ )
78
+
79
+ self.search_client = SearchClient(
80
+ endpoint=azure_search_endpoint,
81
+ index_name=self.index_name,
82
+ credential=AzureKeyCredential(azure_search_api_key)
83
+ )
84
+
85
+ if self.index_name not in self._get_indexes():
86
+ self._create_index()
87
+
88
+ def _create_index(self) -> bool:
89
+ fields = [
90
+ SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
91
+ SearchableField(name="document", type=SearchFieldDataType.String, searchable=True, filterable=True),
92
+ SearchField(name="type", type=SearchFieldDataType.String, filterable=True, searchable=True),
93
+ SearchField(name="document_vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=self.dimensions, vector_search_profile_name="ExhaustiveKnnProfile"),
94
+ ]
95
+
96
+ vector_search = VectorSearch(
97
+ algorithms=[
98
+ ExhaustiveKnnAlgorithmConfiguration(
99
+ name="ExhaustiveKnn",
100
+ kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
101
+ parameters=ExhaustiveKnnParameters(
102
+ metric=VectorSearchAlgorithmMetric.COSINE
103
+ )
104
+ )
105
+ ],
106
+ profiles=[
107
+ VectorSearchProfile(
108
+ name="ExhaustiveKnnProfile",
109
+ algorithm_configuration_name="ExhaustiveKnn",
110
+ )
111
+ ]
112
+ )
113
+
114
+ index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
115
+ result = self.index_client.create_or_update_index(index)
116
+ print(f'{result.name} created')
117
+
118
+ def _get_indexes(self) -> list:
119
+ return [index for index in self.index_client.list_index_names()]
120
+
121
+ def add_ddl(self, ddl: str) -> str:
122
+ id = deterministic_uuid(ddl) + "-ddl"
123
+ document = {
124
+ "id": id,
125
+ "document": ddl,
126
+ "type": "ddl",
127
+ "document_vector": self.generate_embedding(ddl)
128
+ }
129
+ self.search_client.upload_documents(documents=[document])
130
+ return id
131
+
132
+ def add_documentation(self, doc: str) -> str:
133
+ id = deterministic_uuid(doc) + "-doc"
134
+ document = {
135
+ "id": id,
136
+ "document": doc,
137
+ "type": "doc",
138
+ "document_vector": self.generate_embedding(doc)
139
+ }
140
+ self.search_client.upload_documents(documents=[document])
141
+ return id
142
+
143
+ def add_question_sql(self, question: str, sql: str) -> str:
144
+ question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
145
+ id = deterministic_uuid(question_sql_json) + "-sql"
146
+ document = {
147
+ "id": id,
148
+ "document": question_sql_json,
149
+ "type": "sql",
150
+ "document_vector": self.generate_embedding(question_sql_json)
151
+ }
152
+ self.search_client.upload_documents(documents=[document])
153
+ return id
154
+
155
+ def get_related_ddl(self, text: str) -> List[str]:
156
+ result = []
157
+ vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
158
+ df = pd.DataFrame(
159
+ self.search_client.search(
160
+ top=self.n_results_ddl,
161
+ vector_queries=[vector_query],
162
+ select=["id", "document", "type"],
163
+ filter=f"type eq 'ddl'"
164
+ )
165
+ )
166
+
167
+ if len(df):
168
+ result = df["document"].tolist()
169
+ return result
170
+
171
+ def get_related_documentation(self, text: str) -> List[str]:
172
+ result = []
173
+ vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
174
+
175
+ df = pd.DataFrame(
176
+ self.search_client.search(
177
+ top=self.n_results_documentation,
178
+ vector_queries=[vector_query],
179
+ select=["id", "document", "type"],
180
+ filter=f"type eq 'doc'",
181
+ vector_filter_mode=VectorFilterMode.PRE_FILTER
182
+ )
183
+ )
184
+
185
+ if len(df):
186
+ result = df["document"].tolist()
187
+ return result
188
+
189
+ def get_similar_question_sql(self, text: str) -> List[str]:
190
+ result = []
191
+ # Vectorize the text
192
+ vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
193
+ df = pd.DataFrame(
194
+ self.search_client.search(
195
+ top=self.n_results_sql,
196
+ vector_queries=[vector_query],
197
+ select=["id", "document", "type"],
198
+ filter=f"type eq 'sql'"
199
+ )
200
+ )
201
+
202
+ if len(df): # Check if there is similar query and the result is not empty
203
+ result = [ast.literal_eval(element) for element in df["document"].tolist()]
204
+
205
+ return result
206
+
207
+ def get_training_data(self) -> List[str]:
208
+
209
+ search = self.search_client.search(
210
+ search_text="*",
211
+ select=['id', 'document', 'type'],
212
+ filter=f"(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
213
+ ).by_page()
214
+
215
+ df = pd.DataFrame([item for page in search for item in page])
216
+
217
+ if len(df):
218
+ df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
219
+ df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
220
+ df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]
221
+
222
+ return df[["id", "question", "content", "type"]]
223
+
224
+ return pd.DataFrame()
225
+
226
+ def remove_training_data(self, id: str) -> bool:
227
+ result = self.search_client.delete_documents(documents=[{'id':id}])
228
+ return result[0].succeeded
229
+
230
+ def remove_index(self):
231
+ self.index_client.delete_index(self.index_name)
232
+
233
+ def generate_embedding(self, data: str, **kwargs) -> List[float]:
234
+ embedding_model = TextEmbedding(model_name=self.fastembed_model)
235
+ embedding = next(embedding_model.embed(data))
236
+ return embedding.tolist()
@@ -15,7 +15,7 @@ r"""
15
15
 
16
16
  # Open-Source and Extending
17
17
 
18
- Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](/docs/local.html).
18
+ Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](https://vanna.ai/docs/postgres-ollama-chromadb/).
19
19
 
20
20
  The following is an example of where various functions are implemented in the codebase when using the default "local" version of Vanna. `vanna.base.VannaBase` is the base class which provides a `vanna.base.VannaBase.ask` and `vanna.base.VannaBase.train` function. Those rely on abstract methods which are implemented in the subclasses `vanna.openai_chat.OpenAI_Chat` and `vanna.chromadb_vector.ChromaDB_VectorStore`. `vanna.openai_chat.OpenAI_Chat` uses the OpenAI API to generate SQL and Plotly code. `vanna.chromadb_vector.ChromaDB_VectorStore` uses ChromaDB to store training data and generate embeddings.
21
21
 
@@ -256,6 +256,33 @@ class VannaBase(ABC):
256
256
 
257
257
  return False
258
258
 
259
+ def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
260
+ """
261
+ **Example:**
262
+ ```python
263
+ rewritten_question = vn.generate_rewritten_question("Who are the top 5 customers by sales?", "Show me their email addresses")
264
+ ```
265
+
266
+ Generate a rewritten question by combining the last question and the new question if they are related. If the new question is self-contained and not related to the last question, return the new question.
267
+
268
+ Args:
269
+ last_question (str): The previous question that was asked.
270
+ new_question (str): The new question to be combined with the last question.
271
+ **kwargs: Additional keyword arguments.
272
+
273
+ Returns:
274
+ str: The combined question if related, otherwise the new question.
275
+ """
276
+ if last_question is None:
277
+ return new_question
278
+
279
+ prompt = [
280
+ self.system_message("Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
281
+ self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
282
+ ]
283
+
284
+ return self.submit_prompt(prompt=prompt, **kwargs)
285
+
259
286
  def generate_followup_questions(
260
287
  self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
261
288
  ) -> list:
@@ -437,7 +464,7 @@ class VannaBase(ABC):
437
464
  pass
438
465
 
439
466
  @abstractmethod
440
- def remove_training_data(id: str, **kwargs) -> bool:
467
+ def remove_training_data(self, id: str, **kwargs) -> bool:
441
468
  """
442
469
  Example:
443
470
  ```python
@@ -840,6 +867,7 @@ class VannaBase(ABC):
840
867
  port: int = None,
841
868
  **kwargs
842
869
  ):
870
+
843
871
  """
844
872
  Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
845
873
  **Example:**
@@ -913,26 +941,44 @@ class VannaBase(ABC):
913
941
  except psycopg2.Error as e:
914
942
  raise ValidationError(e)
915
943
 
944
+ def connect_to_db():
945
+ return psycopg2.connect(host=host, dbname=dbname,
946
+ user=user, password=password, port=port, **kwargs)
947
+
948
+
916
949
  def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
917
- if conn:
918
- try:
919
- cs = conn.cursor()
920
- cs.execute(sql)
921
- results = cs.fetchall()
950
+ conn = None
951
+ try:
952
+ conn = connect_to_db() # Initial connection attempt
953
+ cs = conn.cursor()
954
+ cs.execute(sql)
955
+ results = cs.fetchall()
922
956
 
923
- # Create a pandas dataframe from the results
924
- df = pd.DataFrame(
925
- results, columns=[desc[0] for desc in cs.description]
926
- )
927
- return df
957
+ # Create a pandas dataframe from the results
958
+ df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
959
+ return df
928
960
 
929
- except psycopg2.Error as e:
961
+ except psycopg2.InterfaceError as e:
962
+ # Attempt to reconnect and retry the operation
963
+ if conn:
964
+ conn.close() # Ensure any existing connection is closed
965
+ conn = connect_to_db()
966
+ cs = conn.cursor()
967
+ cs.execute(sql)
968
+ results = cs.fetchall()
969
+
970
+ # Create a pandas dataframe from the results
971
+ df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
972
+ return df
973
+
974
+ except psycopg2.Error as e:
975
+ if conn:
930
976
  conn.rollback()
931
977
  raise ValidationError(e)
932
978
 
933
- except Exception as e:
934
- conn.rollback()
935
- raise e
979
+ except Exception as e:
980
+ conn.rollback()
981
+ raise e
936
982
 
937
983
  self.dialect = "PostgreSQL"
938
984
  self.run_sql_is_set = True
@@ -1276,15 +1322,9 @@ class VannaBase(ABC):
1276
1322
 
1277
1323
  def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
1278
1324
  if conn:
1279
- try:
1280
- job = conn.query(sql)
1281
- df = job.result().to_dataframe()
1282
- return df
1283
- except GoogleAPIError as error:
1284
- errors = []
1285
- for error in error.errors:
1286
- errors.append(error["message"])
1287
- raise errors
1325
+ job = conn.query(sql)
1326
+ df = job.result().to_dataframe()
1327
+ return df
1288
1328
  return None
1289
1329
 
1290
1330
  self.dialect = "BigQuery SQL"
@@ -1671,7 +1711,7 @@ class VannaBase(ABC):
1671
1711
 
1672
1712
  if self.run_sql_is_set is False:
1673
1713
  print(
1674
- "If you want to run the SQL query, connect to a database first. See here: https://vanna.ai/docs/databases.html"
1714
+ "If you want to run the SQL query, connect to a database first."
1675
1715
  )
1676
1716
 
1677
1717
  if print_results:
@@ -5,6 +5,7 @@ import sys
5
5
  import uuid
6
6
  from abc import ABC, abstractmethod
7
7
  from functools import wraps
8
+ import importlib.metadata
8
9
 
9
10
  import flask
10
11
  import requests
@@ -12,9 +13,9 @@ from flasgger import Swagger
12
13
  from flask import Flask, Response, jsonify, request, send_from_directory
13
14
  from flask_sock import Sock
14
15
 
16
+ from ..base import VannaBase
15
17
  from .assets import css_content, html_content, js_content
16
18
  from .auth import AuthInterface, NoAuth
17
- from ..base import VannaBase
18
19
 
19
20
 
20
21
  class Cache(ABC):
@@ -353,6 +354,30 @@ class VannaFlaskAPI:
353
354
  }
354
355
  )
355
356
 
357
+ @self.flask_app.route("/api/v0/generate_rewritten_question", methods=["GET"])
358
+ @self.requires_auth
359
+ def generate_rewritten_question(user: any):
360
+ """
361
+ Generate a rewritten question
362
+ ---
363
+ parameters:
364
+ - name: last_question
365
+ in: query
366
+ type: string
367
+ required: true
368
+ - name: new_question
369
+ in: query
370
+ type: string
371
+ required: true
372
+ """
373
+
374
+ last_question = flask.request.args.get("last_question")
375
+ new_question = flask.request.args.get("new_question")
376
+
377
+ rewritten_question = self.vn.generate_rewritten_question(last_question, new_question)
378
+
379
+ return jsonify({"type": "rewritten_question", "question": rewritten_question})
380
+
356
381
  @self.flask_app.route("/api/v0/get_function", methods=["GET"])
357
382
  @self.requires_auth
358
383
  def get_function(user: any):
@@ -1211,7 +1236,8 @@ class VannaFlaskApp(VannaFlaskAPI):
1211
1236
  self.config["ask_results_correct"] = ask_results_correct
1212
1237
  self.config["followup_questions"] = followup_questions
1213
1238
  self.config["summarization"] = summarization
1214
- self.config["function_generation"] = function_generation
1239
+ self.config["function_generation"] = function_generation and hasattr(vn, "get_function")
1240
+ self.config["version"] = importlib.metadata.version('vanna')
1215
1241
 
1216
1242
  self.index_html_path = index_html_path
1217
1243
  self.assets_folder = assets_folder