vanna 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/base/base.py +15 -14
- vanna/chromadb/chromadb_vector.py +34 -28
- vanna/flask.py +331 -0
- vanna/marqo/marqo.py +65 -61
- vanna/remote.py +2 -2
- vanna/vannadb/vannadb_vector.py +18 -15
- {vanna-0.0.31.dist-info → vanna-0.0.33.dist-info}/METADATA +4 -2
- {vanna-0.0.31.dist-info → vanna-0.0.33.dist-info}/RECORD +9 -8
- {vanna-0.0.31.dist-info → vanna-0.0.33.dist-info}/WHEEL +0 -0
vanna/base/base.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
|
+
import re
|
|
3
4
|
import sqlite3
|
|
4
5
|
import traceback
|
|
5
|
-
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
7
|
from typing import List, Tuple, Union
|
|
8
8
|
from urllib.parse import urlparse
|
|
@@ -12,7 +12,6 @@ import plotly
|
|
|
12
12
|
import plotly.express as px
|
|
13
13
|
import plotly.graph_objects as go
|
|
14
14
|
import requests
|
|
15
|
-
import re
|
|
16
15
|
|
|
17
16
|
from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
|
|
18
17
|
from ..types import TrainingPlan, TrainingPlanItem
|
|
@@ -50,8 +49,8 @@ class VannaBase(ABC):
|
|
|
50
49
|
**kwargs,
|
|
51
50
|
)
|
|
52
51
|
llm_response = self.submit_prompt(prompt, **kwargs)
|
|
53
|
-
|
|
54
|
-
numbers_removed = re.sub(r
|
|
52
|
+
|
|
53
|
+
numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
|
|
55
54
|
return numbers_removed.split("\n")
|
|
56
55
|
|
|
57
56
|
def generate_questions(self, **kwargs) -> list[str]:
|
|
@@ -65,7 +64,7 @@ class VannaBase(ABC):
|
|
|
65
64
|
"""
|
|
66
65
|
question_sql = self.get_similar_question_sql(question="", **kwargs)
|
|
67
66
|
|
|
68
|
-
return [q[
|
|
67
|
+
return [q["question"] for q in question_sql]
|
|
69
68
|
|
|
70
69
|
# ----------------- Use Any Embeddings API ----------------- #
|
|
71
70
|
@abstractmethod
|
|
@@ -94,7 +93,7 @@ class VannaBase(ABC):
|
|
|
94
93
|
pass
|
|
95
94
|
|
|
96
95
|
@abstractmethod
|
|
97
|
-
def add_documentation(self,
|
|
96
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
98
97
|
pass
|
|
99
98
|
|
|
100
99
|
@abstractmethod
|
|
@@ -120,12 +119,12 @@ class VannaBase(ABC):
|
|
|
120
119
|
|
|
121
120
|
@abstractmethod
|
|
122
121
|
def get_followup_questions_prompt(
|
|
123
|
-
self,
|
|
124
|
-
question: str,
|
|
122
|
+
self,
|
|
123
|
+
question: str,
|
|
125
124
|
question_sql_list: list,
|
|
126
125
|
ddl_list: list,
|
|
127
|
-
doc_list: list,
|
|
128
|
-
**kwargs
|
|
126
|
+
doc_list: list,
|
|
127
|
+
**kwargs,
|
|
129
128
|
):
|
|
130
129
|
pass
|
|
131
130
|
|
|
@@ -248,7 +247,7 @@ class VannaBase(ABC):
|
|
|
248
247
|
url = path
|
|
249
248
|
|
|
250
249
|
# Connect to the database
|
|
251
|
-
conn = sqlite3.connect(url)
|
|
250
|
+
conn = sqlite3.connect(url, check_same_thread=False)
|
|
252
251
|
|
|
253
252
|
def run_sql_sqlite(sql: str):
|
|
254
253
|
return pd.read_sql_query(sql, conn)
|
|
@@ -829,9 +828,11 @@ class VannaBase(ABC):
|
|
|
829
828
|
fig = ldict.get("fig", None)
|
|
830
829
|
except Exception as e:
|
|
831
830
|
# Inspect data types
|
|
832
|
-
numeric_cols = df.select_dtypes(include=[
|
|
833
|
-
categorical_cols = df.select_dtypes(
|
|
834
|
-
|
|
831
|
+
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
|
|
832
|
+
categorical_cols = df.select_dtypes(
|
|
833
|
+
include=["object", "category"]
|
|
834
|
+
).columns.tolist()
|
|
835
|
+
|
|
835
836
|
# Decision-making for plot type
|
|
836
837
|
if len(numeric_cols) >= 2:
|
|
837
838
|
# Use the first two numeric columns for a scatter plot
|
|
vanna/chromadb/chromadb_vector.py
CHANGED
|
@@ -3,9 +3,9 @@ import uuid
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
|
|
5
5
|
import chromadb
|
|
6
|
+
import pandas as pd
|
|
6
7
|
from chromadb.config import Settings
|
|
7
8
|
from chromadb.utils import embedding_functions
|
|
8
|
-
import pandas as pd
|
|
9
9
|
|
|
10
10
|
from ..base import VannaBase
|
|
11
11
|
|
|
@@ -47,7 +47,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
47
47
|
"sql": sql,
|
|
48
48
|
}
|
|
49
49
|
)
|
|
50
|
-
id = str(uuid.uuid4())+"-sql"
|
|
50
|
+
id = str(uuid.uuid4()) + "-sql"
|
|
51
51
|
self.sql_collection.add(
|
|
52
52
|
documents=question_sql_json,
|
|
53
53
|
embeddings=self.generate_embedding(question_sql_json),
|
|
@@ -57,7 +57,7 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
57
57
|
return id
|
|
58
58
|
|
|
59
59
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
60
|
-
id = str(uuid.uuid4())+"-ddl"
|
|
60
|
+
id = str(uuid.uuid4()) + "-ddl"
|
|
61
61
|
self.ddl_collection.add(
|
|
62
62
|
documents=ddl,
|
|
63
63
|
embeddings=self.generate_embedding(ddl),
|
|
@@ -65,11 +65,11 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
65
65
|
)
|
|
66
66
|
return id
|
|
67
67
|
|
|
68
|
-
def add_documentation(self,
|
|
69
|
-
id = str(uuid.uuid4())+"-doc"
|
|
68
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
69
|
+
id = str(uuid.uuid4()) + "-doc"
|
|
70
70
|
self.documentation_collection.add(
|
|
71
|
-
documents=
|
|
72
|
-
embeddings=self.generate_embedding(
|
|
71
|
+
documents=documentation,
|
|
72
|
+
embeddings=self.generate_embedding(documentation),
|
|
73
73
|
ids=id,
|
|
74
74
|
)
|
|
75
75
|
return id
|
|
@@ -81,15 +81,17 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
81
81
|
|
|
82
82
|
if sql_data is not None:
|
|
83
83
|
# Extract the documents and ids
|
|
84
|
-
documents = [json.loads(doc) for doc in sql_data[
|
|
85
|
-
ids = sql_data[
|
|
84
|
+
documents = [json.loads(doc) for doc in sql_data["documents"]]
|
|
85
|
+
ids = sql_data["ids"]
|
|
86
86
|
|
|
87
87
|
# Create a DataFrame
|
|
88
|
-
df_sql = pd.DataFrame(
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
88
|
+
df_sql = pd.DataFrame(
|
|
89
|
+
{
|
|
90
|
+
"id": ids,
|
|
91
|
+
"question": [doc["question"] for doc in documents],
|
|
92
|
+
"content": [doc["sql"] for doc in documents],
|
|
93
|
+
}
|
|
94
|
+
)
|
|
93
95
|
|
|
94
96
|
df_sql["training_data_type"] = "sql"
|
|
95
97
|
|
|
@@ -99,15 +101,17 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
99
101
|
|
|
100
102
|
if ddl_data is not None:
|
|
101
103
|
# Extract the documents and ids
|
|
102
|
-
documents = [doc for doc in ddl_data[
|
|
103
|
-
ids = ddl_data[
|
|
104
|
+
documents = [doc for doc in ddl_data["documents"]]
|
|
105
|
+
ids = ddl_data["ids"]
|
|
104
106
|
|
|
105
107
|
# Create a DataFrame
|
|
106
|
-
df_ddl = pd.DataFrame(
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
108
|
+
df_ddl = pd.DataFrame(
|
|
109
|
+
{
|
|
110
|
+
"id": ids,
|
|
111
|
+
"question": [None for doc in documents],
|
|
112
|
+
"content": [doc for doc in documents],
|
|
113
|
+
}
|
|
114
|
+
)
|
|
111
115
|
|
|
112
116
|
df_ddl["training_data_type"] = "ddl"
|
|
113
117
|
|
|
@@ -117,15 +121,17 @@ class ChromaDB_VectorStore(VannaBase):
|
|
|
117
121
|
|
|
118
122
|
if doc_data is not None:
|
|
119
123
|
# Extract the documents and ids
|
|
120
|
-
documents = [doc for doc in doc_data[
|
|
121
|
-
ids = doc_data[
|
|
124
|
+
documents = [doc for doc in doc_data["documents"]]
|
|
125
|
+
ids = doc_data["ids"]
|
|
122
126
|
|
|
123
127
|
# Create a DataFrame
|
|
124
|
-
df_doc = pd.DataFrame(
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
128
|
+
df_doc = pd.DataFrame(
|
|
129
|
+
{
|
|
130
|
+
"id": ids,
|
|
131
|
+
"question": [None for doc in documents],
|
|
132
|
+
"content": [doc for doc in documents],
|
|
133
|
+
}
|
|
134
|
+
)
|
|
129
135
|
|
|
130
136
|
df_doc["training_data_type"] = "documentation"
|
|
131
137
|
|
vanna/flask.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
import flask
|
|
2
|
+
from flask import Flask, Response, jsonify, request
|
|
3
|
+
import logging
|
|
4
|
+
import requests
|
|
5
|
+
from functools import wraps
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
import uuid
|
|
9
|
+
|
|
10
|
+
class Cache(ABC):
|
|
11
|
+
@abstractmethod
|
|
12
|
+
def generate_id(self, *args, **kwargs):
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
@abstractmethod
|
|
16
|
+
def get(self, id, field):
|
|
17
|
+
pass
|
|
18
|
+
|
|
19
|
+
@abstractmethod
|
|
20
|
+
def get_all(self, field_list) -> list:
|
|
21
|
+
pass
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
def set(self, id, field, value):
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def delete(self, id):
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MemoryCache(Cache):
|
|
33
|
+
def __init__(self):
|
|
34
|
+
self.cache = {}
|
|
35
|
+
|
|
36
|
+
def generate_id(self, *args, **kwargs):
|
|
37
|
+
return str(uuid.uuid4())
|
|
38
|
+
|
|
39
|
+
def set(self, id, field, value):
|
|
40
|
+
if id not in self.cache:
|
|
41
|
+
self.cache[id] = {}
|
|
42
|
+
|
|
43
|
+
self.cache[id][field] = value
|
|
44
|
+
|
|
45
|
+
def get(self, id, field):
|
|
46
|
+
if id not in self.cache:
|
|
47
|
+
return None
|
|
48
|
+
|
|
49
|
+
if field not in self.cache[id]:
|
|
50
|
+
return None
|
|
51
|
+
|
|
52
|
+
return self.cache[id][field]
|
|
53
|
+
|
|
54
|
+
def get_all(self, field_list) -> list:
|
|
55
|
+
return [
|
|
56
|
+
{
|
|
57
|
+
"id": id,
|
|
58
|
+
**{
|
|
59
|
+
field: self.get(id=id, field=field)
|
|
60
|
+
for field in field_list
|
|
61
|
+
}
|
|
62
|
+
}
|
|
63
|
+
for id in self.cache
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
def delete(self, id):
|
|
67
|
+
if id in self.cache:
|
|
68
|
+
del self.cache[id]
|
|
69
|
+
|
|
70
|
+
class VannaFlaskApp:
|
|
71
|
+
flask_app = None
|
|
72
|
+
|
|
73
|
+
def requires_cache(self, fields):
|
|
74
|
+
def decorator(f):
|
|
75
|
+
@wraps(f)
|
|
76
|
+
def decorated(*args, **kwargs):
|
|
77
|
+
id = request.args.get('id')
|
|
78
|
+
|
|
79
|
+
if id is None:
|
|
80
|
+
return jsonify({"type": "error", "error": "No id provided"})
|
|
81
|
+
|
|
82
|
+
for field in fields:
|
|
83
|
+
if self.cache.get(id=id, field=field) is None:
|
|
84
|
+
return jsonify({"type": "error", "error": f"No {field} found"})
|
|
85
|
+
|
|
86
|
+
field_values = {field: self.cache.get(id=id, field=field) for field in fields}
|
|
87
|
+
|
|
88
|
+
# Add the id to the field_values
|
|
89
|
+
field_values['id'] = id
|
|
90
|
+
|
|
91
|
+
return f(*args, **field_values, **kwargs)
|
|
92
|
+
return decorated
|
|
93
|
+
return decorator
|
|
94
|
+
|
|
95
|
+
def __init__(self, vn, cache: Cache = MemoryCache()):
|
|
96
|
+
self.flask_app = Flask(__name__)
|
|
97
|
+
self.vn = vn
|
|
98
|
+
self.cache = cache
|
|
99
|
+
|
|
100
|
+
log = logging.getLogger('werkzeug')
|
|
101
|
+
log.setLevel(logging.ERROR)
|
|
102
|
+
|
|
103
|
+
@self.flask_app.route('/api/v0/generate_questions', methods=['GET'])
|
|
104
|
+
def generate_questions():
|
|
105
|
+
# If self has an _model attribute and model=='chinook'
|
|
106
|
+
if hasattr(self.vn, '_model') and self.vn._model == 'chinook':
|
|
107
|
+
return jsonify({
|
|
108
|
+
"type": "question_list",
|
|
109
|
+
"questions": ['What are the top 10 artists by sales?', 'What are the total sales per year by country?', 'Who is the top selling artist in each genre? Show the sales numbers.', 'How do the employees rank in terms of sales performance?', 'Which 5 cities have the most customers?'],
|
|
110
|
+
"header": "Here are some questions you can ask:"
|
|
111
|
+
})
|
|
112
|
+
|
|
113
|
+
@self.flask_app.route('/api/v0/generate_sql', methods=['GET'])
|
|
114
|
+
def generate_sql():
|
|
115
|
+
question = flask.request.args.get('question')
|
|
116
|
+
|
|
117
|
+
if question is None:
|
|
118
|
+
return jsonify({"type": "error", "error": "No question provided"})
|
|
119
|
+
|
|
120
|
+
id = self.cache.generate_id(question=question)
|
|
121
|
+
sql = vn.generate_sql(question=question)
|
|
122
|
+
|
|
123
|
+
self.cache.set(id=id, field='question', value=question)
|
|
124
|
+
self.cache.set(id=id, field='sql', value=sql)
|
|
125
|
+
|
|
126
|
+
return jsonify(
|
|
127
|
+
{
|
|
128
|
+
"type": "sql",
|
|
129
|
+
"id": id,
|
|
130
|
+
"text": sql,
|
|
131
|
+
})
|
|
132
|
+
|
|
133
|
+
@self.flask_app.route('/api/v0/run_sql', methods=['GET'])
|
|
134
|
+
@self.requires_cache(['sql'])
|
|
135
|
+
def run_sql(id: str, sql: str):
|
|
136
|
+
try:
|
|
137
|
+
if not vn.run_sql_is_set:
|
|
138
|
+
return jsonify({"type": "error", "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries."})
|
|
139
|
+
|
|
140
|
+
df = vn.run_sql(sql=sql)
|
|
141
|
+
|
|
142
|
+
cache.set(id=id, field='df', value=df)
|
|
143
|
+
|
|
144
|
+
return jsonify(
|
|
145
|
+
{
|
|
146
|
+
"type": "df",
|
|
147
|
+
"id": id,
|
|
148
|
+
"df": df.head(10).to_json(orient='records'),
|
|
149
|
+
})
|
|
150
|
+
|
|
151
|
+
except Exception as e:
|
|
152
|
+
return jsonify({"type": "error", "error": str(e)})
|
|
153
|
+
|
|
154
|
+
@self.flask_app.route('/api/v0/download_csv', methods=['GET'])
|
|
155
|
+
@self.requires_cache(['df'])
|
|
156
|
+
def download_csv(id: str, df):
|
|
157
|
+
csv = df.to_csv()
|
|
158
|
+
|
|
159
|
+
return Response(
|
|
160
|
+
csv,
|
|
161
|
+
mimetype="text/csv",
|
|
162
|
+
headers={"Content-disposition":
|
|
163
|
+
f"attachment; filename={id}.csv"})
|
|
164
|
+
|
|
165
|
+
@self.flask_app.route('/api/v0/generate_plotly_figure', methods=['GET'])
|
|
166
|
+
@self.requires_cache(['df', 'question', 'sql'])
|
|
167
|
+
def generate_plotly_figure(id: str, df, question, sql):
|
|
168
|
+
try:
|
|
169
|
+
code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
|
|
170
|
+
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
|
|
171
|
+
fig_json = fig.to_json()
|
|
172
|
+
|
|
173
|
+
cache.set(id=id, field='fig_json', value=fig_json)
|
|
174
|
+
|
|
175
|
+
return jsonify(
|
|
176
|
+
{
|
|
177
|
+
"type": "plotly_figure",
|
|
178
|
+
"id": id,
|
|
179
|
+
"fig": fig_json,
|
|
180
|
+
})
|
|
181
|
+
except Exception as e:
|
|
182
|
+
# Print the stack trace
|
|
183
|
+
import traceback
|
|
184
|
+
traceback.print_exc()
|
|
185
|
+
|
|
186
|
+
return jsonify({"type": "error", "error": str(e)})
|
|
187
|
+
|
|
188
|
+
@self.flask_app.route('/api/v0/get_training_data', methods=['GET'])
|
|
189
|
+
def get_training_data():
|
|
190
|
+
df = vn.get_training_data()
|
|
191
|
+
|
|
192
|
+
return jsonify(
|
|
193
|
+
{
|
|
194
|
+
"type": "df",
|
|
195
|
+
"id": "training_data",
|
|
196
|
+
"df": df.tail(25).to_json(orient='records'),
|
|
197
|
+
})
|
|
198
|
+
|
|
199
|
+
@self.flask_app.route('/api/v0/remove_training_data', methods=['POST'])
|
|
200
|
+
def remove_training_data():
|
|
201
|
+
# Get id from the JSON body
|
|
202
|
+
id = flask.request.json.get('id')
|
|
203
|
+
|
|
204
|
+
if id is None:
|
|
205
|
+
return jsonify({"type": "error", "error": "No id provided"})
|
|
206
|
+
|
|
207
|
+
if vn.remove_training_data(id=id):
|
|
208
|
+
return jsonify({"success": True})
|
|
209
|
+
else:
|
|
210
|
+
return jsonify({"type": "error", "error": "Couldn't remove training data"})
|
|
211
|
+
|
|
212
|
+
@self.flask_app.route('/api/v0/train', methods=['POST'])
|
|
213
|
+
def add_training_data():
|
|
214
|
+
question = flask.request.json.get('question')
|
|
215
|
+
sql = flask.request.json.get('sql')
|
|
216
|
+
ddl = flask.request.json.get('ddl')
|
|
217
|
+
documentation = flask.request.json.get('documentation')
|
|
218
|
+
|
|
219
|
+
try:
|
|
220
|
+
id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
|
|
221
|
+
|
|
222
|
+
return jsonify({"id": id})
|
|
223
|
+
except Exception as e:
|
|
224
|
+
print("TRAINING ERROR", e)
|
|
225
|
+
return jsonify({"type": "error", "error": str(e)})
|
|
226
|
+
|
|
227
|
+
@self.flask_app.route('/api/v0/generate_followup_questions', methods=['GET'])
|
|
228
|
+
@self.requires_cache(['df', 'question'])
|
|
229
|
+
def generate_followup_questions(id: str, df, question):
|
|
230
|
+
followup_questions = []
|
|
231
|
+
# followup_questions = vn.generate_followup_questions(question=question, df=df)
|
|
232
|
+
# if followup_questions is not None and len(followup_questions) > 5:
|
|
233
|
+
# followup_questions = followup_questions[:5]
|
|
234
|
+
|
|
235
|
+
cache.set(id=id, field='followup_questions', value=followup_questions)
|
|
236
|
+
|
|
237
|
+
return jsonify(
|
|
238
|
+
{
|
|
239
|
+
"type": "question_list",
|
|
240
|
+
"id": id,
|
|
241
|
+
"questions": followup_questions,
|
|
242
|
+
"header": "Followup Questions can be enabled in a future version if you allow the LLM to 'see' your query results."
|
|
243
|
+
})
|
|
244
|
+
|
|
245
|
+
@self.flask_app.route('/api/v0/load_question', methods=['GET'])
|
|
246
|
+
@self.requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
|
|
247
|
+
def load_question(id: str, question, sql, df, fig_json, followup_questions):
|
|
248
|
+
try:
|
|
249
|
+
return jsonify(
|
|
250
|
+
{
|
|
251
|
+
"type": "question_cache",
|
|
252
|
+
"id": id,
|
|
253
|
+
"question": question,
|
|
254
|
+
"sql": sql,
|
|
255
|
+
"df": df.head(10).to_json(orient='records'),
|
|
256
|
+
"fig": fig_json,
|
|
257
|
+
"followup_questions": followup_questions,
|
|
258
|
+
})
|
|
259
|
+
|
|
260
|
+
except Exception as e:
|
|
261
|
+
return jsonify({"type": "error", "error": str(e)})
|
|
262
|
+
|
|
263
|
+
@self.flask_app.route('/api/v0/get_question_history', methods=['GET'])
|
|
264
|
+
def get_question_history():
|
|
265
|
+
return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) })
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@self.flask_app.route('/api/v0/<path:catch_all>', methods=['GET', 'POST'])
|
|
269
|
+
def catch_all(catch_all):
|
|
270
|
+
return jsonify({"type": "error", "error": "The rest of the API is not ported yet."})
|
|
271
|
+
|
|
272
|
+
@self.flask_app.route('/assets/<path:filename>')
|
|
273
|
+
def proxy_assets(filename):
|
|
274
|
+
remote_url = f'https://vanna.ai/assets/{filename}'
|
|
275
|
+
response = requests.get(remote_url, stream=True)
|
|
276
|
+
|
|
277
|
+
# Check if the request to the remote URL was successful
|
|
278
|
+
if response.status_code == 200:
|
|
279
|
+
excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
|
|
280
|
+
headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers]
|
|
281
|
+
return Response(response.content, response.status_code, headers)
|
|
282
|
+
else:
|
|
283
|
+
return 'Error fetching file from remote server', response.status_code
|
|
284
|
+
|
|
285
|
+
# Proxy the /vanna.svg file to the remote server
|
|
286
|
+
@self.flask_app.route('/vanna.svg')
|
|
287
|
+
def proxy_vanna_svg():
|
|
288
|
+
remote_url = f'https://vanna.ai/img/vanna.svg'
|
|
289
|
+
response = requests.get(remote_url, stream=True)
|
|
290
|
+
|
|
291
|
+
# Check if the request to the remote URL was successful
|
|
292
|
+
if response.status_code == 200:
|
|
293
|
+
excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
|
|
294
|
+
headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers]
|
|
295
|
+
return Response(response.content, response.status_code, headers)
|
|
296
|
+
else:
|
|
297
|
+
return 'Error fetching file from remote server', response.status_code
|
|
298
|
+
|
|
299
|
+
@self.flask_app.route('/', defaults={'path': ''})
|
|
300
|
+
@self.flask_app.route('/<path:path>')
|
|
301
|
+
def hello(path: str):
|
|
302
|
+
return """
|
|
303
|
+
<!doctype html>
|
|
304
|
+
<html lang="en">
|
|
305
|
+
<head>
|
|
306
|
+
<meta charset="UTF-8" />
|
|
307
|
+
<link rel="icon" type="image/svg+xml" href="/vanna.svg" />
|
|
308
|
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
309
|
+
<link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
|
|
310
|
+
<script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
|
|
311
|
+
<title>Vanna.AI</title>
|
|
312
|
+
<script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
|
|
313
|
+
<link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
|
|
314
|
+
</head>
|
|
315
|
+
<body class="bg-white dark:bg-slate-900">
|
|
316
|
+
<div id="app"></div>
|
|
317
|
+
</body>
|
|
318
|
+
</html>
|
|
319
|
+
"""
|
|
320
|
+
|
|
321
|
+
def run(self):
|
|
322
|
+
try:
|
|
323
|
+
from google.colab import output
|
|
324
|
+
output.serve_kernel_port_as_window(8084)
|
|
325
|
+
from google.colab.output import eval_js
|
|
326
|
+
print("Your app is running at:")
|
|
327
|
+
print(eval_js("google.colab.kernel.proxyPort(8084)"))
|
|
328
|
+
except:
|
|
329
|
+
print("Your app is running at:")
|
|
330
|
+
print("http://localhost:8084")
|
|
331
|
+
self.flask_app.run(host='0.0.0.0', port=8084, debug=False)
|
vanna/marqo/marqo.py
CHANGED
|
@@ -3,7 +3,6 @@ import uuid
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
|
|
5
5
|
import marqo
|
|
6
|
-
|
|
7
6
|
import pandas as pd
|
|
8
7
|
|
|
9
8
|
from ..base import VannaBase
|
|
@@ -12,7 +11,7 @@ from ..base import VannaBase
|
|
|
12
11
|
class Marqo_VectorStore(VannaBase):
|
|
13
12
|
def __init__(self, config=None):
|
|
14
13
|
VannaBase.__init__(self, config=config)
|
|
15
|
-
|
|
14
|
+
|
|
16
15
|
if config is not None and "marqo_url" in config:
|
|
17
16
|
marqo_url = config["marqo_url"]
|
|
18
17
|
else:
|
|
@@ -22,7 +21,7 @@ class Marqo_VectorStore(VannaBase):
|
|
|
22
21
|
marqo_model = config["marqo_model"]
|
|
23
22
|
else:
|
|
24
23
|
marqo_model = "hf/all_datasets_v4_MiniLM-L6"
|
|
25
|
-
|
|
24
|
+
|
|
26
25
|
self.mq = marqo.Client(url=marqo_url)
|
|
27
26
|
|
|
28
27
|
for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]:
|
|
@@ -33,18 +32,17 @@ class Marqo_VectorStore(VannaBase):
|
|
|
33
32
|
print(f"Marqo index {index} already exists")
|
|
34
33
|
pass
|
|
35
34
|
|
|
36
|
-
|
|
37
35
|
def generate_embedding(self, data: str, **kwargs) -> list[float]:
|
|
38
36
|
# Marqo doesn't need to generate embeddings
|
|
39
|
-
pass
|
|
37
|
+
pass
|
|
40
38
|
|
|
41
39
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
|
|
42
|
-
id = str(uuid.uuid4())+"-sql"
|
|
43
|
-
question_sql_dict ={
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
40
|
+
id = str(uuid.uuid4()) + "-sql"
|
|
41
|
+
question_sql_dict = {
|
|
42
|
+
"question": question,
|
|
43
|
+
"sql": sql,
|
|
44
|
+
"_id": id,
|
|
45
|
+
}
|
|
48
46
|
|
|
49
47
|
self.mq.index("vanna-sql").add_documents(
|
|
50
48
|
[question_sql_dict],
|
|
@@ -54,11 +52,11 @@ class Marqo_VectorStore(VannaBase):
|
|
|
54
52
|
return id
|
|
55
53
|
|
|
56
54
|
def add_ddl(self, ddl: str, **kwargs) -> str:
|
|
57
|
-
id = str(uuid.uuid4())+"-ddl"
|
|
58
|
-
ddl_dict ={
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
55
|
+
id = str(uuid.uuid4()) + "-ddl"
|
|
56
|
+
ddl_dict = {
|
|
57
|
+
"ddl": ddl,
|
|
58
|
+
"_id": id,
|
|
59
|
+
}
|
|
62
60
|
|
|
63
61
|
self.mq.index("vanna-ddl").add_documents(
|
|
64
62
|
[ddl_dict],
|
|
@@ -66,13 +64,13 @@ class Marqo_VectorStore(VannaBase):
|
|
|
66
64
|
)
|
|
67
65
|
return id
|
|
68
66
|
|
|
69
|
-
def add_documentation(self,
|
|
70
|
-
id = str(uuid.uuid4())+"-doc"
|
|
71
|
-
doc_dict ={
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
67
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
68
|
+
id = str(uuid.uuid4()) + "-doc"
|
|
69
|
+
doc_dict = {
|
|
70
|
+
"doc": documentation,
|
|
71
|
+
"_id": id,
|
|
72
|
+
}
|
|
73
|
+
|
|
76
74
|
self.mq.index("vanna-doc").add_documents(
|
|
77
75
|
[doc_dict],
|
|
78
76
|
tensor_fields=["doc"],
|
|
@@ -80,31 +78,37 @@ class Marqo_VectorStore(VannaBase):
|
|
|
80
78
|
return id
|
|
81
79
|
|
|
82
80
|
def get_training_data(self, **kwargs) -> pd.DataFrame:
|
|
83
|
-
data = []
|
|
84
|
-
|
|
85
|
-
for hit in self.mq.index(
|
|
86
|
-
data.append(
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
81
|
+
data = []
|
|
82
|
+
|
|
83
|
+
for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]:
|
|
84
|
+
data.append(
|
|
85
|
+
{
|
|
86
|
+
"id": hit["_id"],
|
|
87
|
+
"training_data_type": "documentation",
|
|
88
|
+
"question": "",
|
|
89
|
+
"content": hit["doc"],
|
|
90
|
+
}
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]:
|
|
94
|
+
data.append(
|
|
95
|
+
{
|
|
96
|
+
"id": hit["_id"],
|
|
97
|
+
"training_data_type": "ddl",
|
|
98
|
+
"question": "",
|
|
99
|
+
"content": hit["ddl"],
|
|
100
|
+
}
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]:
|
|
104
|
+
data.append(
|
|
105
|
+
{
|
|
106
|
+
"id": hit["_id"],
|
|
107
|
+
"training_data_type": "sql",
|
|
108
|
+
"question": hit["question"],
|
|
109
|
+
"content": hit["sql"],
|
|
110
|
+
}
|
|
111
|
+
)
|
|
108
112
|
|
|
109
113
|
df = pd.DataFrame(data)
|
|
110
114
|
|
|
@@ -127,24 +131,24 @@ class Marqo_VectorStore(VannaBase):
|
|
|
127
131
|
@staticmethod
|
|
128
132
|
def _extract_documents(data) -> list:
|
|
129
133
|
# Check if 'hits' key is in the dictionary and if it's a list
|
|
130
|
-
if
|
|
134
|
+
if "hits" in data and isinstance(data["hits"], list):
|
|
131
135
|
# Iterate over each item in 'hits'
|
|
132
136
|
|
|
133
|
-
if len(data[
|
|
137
|
+
if len(data["hits"]) == 0:
|
|
134
138
|
return []
|
|
135
139
|
|
|
136
140
|
# If there is a "doc" key, return the value of that key
|
|
137
|
-
if "doc" in data[
|
|
138
|
-
return [hit["doc"] for hit in data[
|
|
139
|
-
|
|
141
|
+
if "doc" in data["hits"][0]:
|
|
142
|
+
return [hit["doc"] for hit in data["hits"]]
|
|
143
|
+
|
|
140
144
|
# If there is a "ddl" key, return the value of that key
|
|
141
|
-
if "ddl" in data[
|
|
142
|
-
return [hit["ddl"] for hit in data[
|
|
143
|
-
|
|
145
|
+
if "ddl" in data["hits"][0]:
|
|
146
|
+
return [hit["ddl"] for hit in data["hits"]]
|
|
147
|
+
|
|
144
148
|
# Otherwise return the entire hit
|
|
145
149
|
return [
|
|
146
|
-
{key: value for key, value in hit.items() if not key.startswith(
|
|
147
|
-
for hit in data[
|
|
150
|
+
{key: value for key, value in hit.items() if not key.startswith("_")}
|
|
151
|
+
for hit in data["hits"]
|
|
148
152
|
]
|
|
149
153
|
else:
|
|
150
154
|
# Return an empty list if 'hits' is not found or not a list
|
|
@@ -152,15 +156,15 @@ class Marqo_VectorStore(VannaBase):
|
|
|
152
156
|
|
|
153
157
|
def get_similar_question_sql(self, question: str, **kwargs) -> list:
|
|
154
158
|
return Marqo_VectorStore._extract_documents(
|
|
155
|
-
self.mq.index(
|
|
159
|
+
self.mq.index("vanna-sql").search(question)
|
|
156
160
|
)
|
|
157
161
|
|
|
158
162
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
159
163
|
return Marqo_VectorStore._extract_documents(
|
|
160
|
-
self.mq.index(
|
|
164
|
+
self.mq.index("vanna-ddl").search(question)
|
|
161
165
|
)
|
|
162
166
|
|
|
163
167
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
164
168
|
return Marqo_VectorStore._extract_documents(
|
|
165
|
-
self.mq.index(
|
|
169
|
+
self.mq.index("vanna-doc").search(question)
|
|
166
170
|
)
|
vanna/remote.py
CHANGED
|
@@ -197,7 +197,7 @@ class VannaDefault(VannaBase):
|
|
|
197
197
|
|
|
198
198
|
return status.id
|
|
199
199
|
|
|
200
|
-
def add_documentation(self,
|
|
200
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
201
201
|
"""
|
|
202
202
|
Adds documentation to the model's training data
|
|
203
203
|
|
|
@@ -469,4 +469,4 @@ class VannaDefault(VannaBase):
|
|
|
469
469
|
# Load the result into a dataclass
|
|
470
470
|
question_string_list = QuestionStringList(**d["result"])
|
|
471
471
|
|
|
472
|
-
return question_string_list.questions
|
|
472
|
+
return question_string_list.questions
|
vanna/vannadb/vannadb_vector.py
CHANGED
|
@@ -1,18 +1,21 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import json
|
|
3
|
+
from io import StringIO
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import requests
|
|
7
|
+
|
|
1
8
|
from ..base import VannaBase
|
|
2
9
|
from ..types import (
|
|
3
|
-
QuestionSQLPair,
|
|
4
|
-
StatusWithId,
|
|
5
|
-
StringData,
|
|
6
10
|
DataFrameJSON,
|
|
11
|
+
Question,
|
|
12
|
+
QuestionSQLPair,
|
|
7
13
|
Status,
|
|
14
|
+
StatusWithId,
|
|
15
|
+
StringData,
|
|
8
16
|
TrainingData,
|
|
9
|
-
Question,
|
|
10
17
|
)
|
|
11
|
-
|
|
12
|
-
import pandas as pd
|
|
13
|
-
import requests
|
|
14
|
-
import json
|
|
15
|
-
import dataclasses
|
|
18
|
+
|
|
16
19
|
|
|
17
20
|
class VannaDB_VectorStore(VannaBase):
|
|
18
21
|
def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
|
|
@@ -105,8 +108,8 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
105
108
|
|
|
106
109
|
return status.id
|
|
107
110
|
|
|
108
|
-
def add_documentation(self,
|
|
109
|
-
params = [StringData(data=
|
|
111
|
+
def add_documentation(self, documentation: str, **kwargs) -> str:
|
|
112
|
+
params = [StringData(data=documentation)]
|
|
110
113
|
|
|
111
114
|
d = self._rpc_call(method="add_documentation", params=params)
|
|
112
115
|
|
|
@@ -167,7 +170,7 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
167
170
|
training_data = self.related_training_data[question]
|
|
168
171
|
else:
|
|
169
172
|
training_data = self.get_related_training_data_cached(question)
|
|
170
|
-
|
|
173
|
+
|
|
171
174
|
return training_data.questions
|
|
172
175
|
|
|
173
176
|
def get_related_ddl(self, question: str, **kwargs) -> list:
|
|
@@ -175,7 +178,7 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
175
178
|
training_data = self.related_training_data[question]
|
|
176
179
|
else:
|
|
177
180
|
training_data = self.get_related_training_data_cached(question)
|
|
178
|
-
|
|
181
|
+
|
|
179
182
|
return training_data.ddl
|
|
180
183
|
|
|
181
184
|
def get_related_documentation(self, question: str, **kwargs) -> list:
|
|
@@ -183,5 +186,5 @@ class VannaDB_VectorStore(VannaBase):
|
|
|
183
186
|
training_data = self.related_training_data[question]
|
|
184
187
|
else:
|
|
185
188
|
training_data = self.get_related_training_data_cached(question)
|
|
186
|
-
|
|
187
|
-
return training_data.documentation
|
|
189
|
+
|
|
190
|
+
return training_data.documentation
|
|
{vanna-0.0.31.dist-info → vanna-0.0.33.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vanna
|
|
3
|
-
Version: 0.0.31
|
|
3
|
+
Version: 0.0.33
|
|
4
4
|
Summary: Generate SQL queries from natural language
|
|
5
5
|
Author-email: Zain Hoda <zain@vanna.ai>
|
|
6
6
|
Requires-Python: >=3.7
|
|
@@ -41,7 +41,7 @@ Provides-Extra: postgres
|
|
|
41
41
|
Provides-Extra: snowflake
|
|
42
42
|
Provides-Extra: test
|
|
43
43
|
|
|
44
|
-
|
|
44
|
+
|
|
45
45
|
|
|
46
46
|
| GitHub | PyPI | Documentation |
|
|
47
47
|
| ------ | ---- | ------------- |
|
|
@@ -52,6 +52,8 @@ Vanna is an MIT-licensed open-source Python RAG (Retrieval-Augmented Generation)
|
|
|
52
52
|
|
|
53
53
|
https://github.com/vanna-ai/vanna/assets/7146154/1901f47a-515d-4982-af50-f12761a3b2ce
|
|
54
54
|
|
|
55
|
+

|
|
56
|
+
|
|
55
57
|
## How Vanna works
|
|
56
58
|
Vanna works in two easy steps - train a RAG "model" on your data, and then ask questions which will return SQL queries that can be set up to automatically run on your database.
|
|
57
59
|
|
|
{vanna-0.0.31.dist-info → vanna-0.0.33.dist-info}/RECORD
CHANGED
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
vanna/__init__.py,sha256=thjmOUgHCboSxIkzQRKw-JvZLLFbnuyM7G5YIzmmmPQ,61545
|
|
2
|
+
vanna/flask.py,sha256=Kw7qjObb39J1BWX5PsjLRIbMJqntZI91K3tGlUxx5M0,12496
|
|
2
3
|
vanna/local.py,sha256=U5s8ybCRQhBUizi8I69o3jqOpTeu_6KGYY6DMwZxjG4,313
|
|
3
|
-
vanna/remote.py,sha256=
|
|
4
|
+
vanna/remote.py,sha256=yr0QSJCAKzziiPOa-mfsjGg1pVa5-vLj9vYl2VDlAfU,13584
|
|
4
5
|
vanna/utils.py,sha256=Q0H4eugPYg9SVpEoTWgvmuoJZZxOVRhNzrP97E5lyak,1472
|
|
5
6
|
vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
|
|
6
|
-
vanna/base/base.py,sha256=
|
|
7
|
+
vanna/base/base.py,sha256=w3qYB-8LhcP0bvV0MCc4VlGgmVzQ4TY-N3Ufix8DZs8,31204
|
|
7
8
|
vanna/chromadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
-
vanna/chromadb/chromadb_vector.py,sha256=
|
|
9
|
+
vanna/chromadb/chromadb_vector.py,sha256=4YGgWQNIw4QJFwtBRIW53vieXH8rTBez-cs7EZwxsNI,5893
|
|
9
10
|
vanna/exceptions/__init__.py,sha256=N76unE7sjbGGBz6LmCrPQAugFWr9cUFv8ErJxBrCTts,717
|
|
10
11
|
vanna/marqo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
vanna/marqo/marqo.py,sha256=
|
|
12
|
+
vanna/marqo/marqo.py,sha256=2OBuC5IZmGcFXN2Ah6GVPKHBYtkDXeSwhXsqUbxyU94,5285
|
|
12
13
|
vanna/mistral/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
14
|
vanna/mistral/mistral.py,sha256=A9dD8U-c12whGx8h_WOX15zUzaAJV-XLu_tpSiLamHo,8095
|
|
14
15
|
vanna/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -16,7 +17,7 @@ vanna/openai/openai_chat.py,sha256=U6wkXztJnQtABItUMDlBIDN6m3fqD6pMpa9gyQAQx8A,9
|
|
|
16
17
|
vanna/openai/openai_embeddings.py,sha256=kPtOqrKQYJnXe6My3pO9BWg-L3KIR1sJVqE3YoW0roA,1139
|
|
17
18
|
vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
|
|
18
19
|
vanna/vannadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
|
-
vanna/vannadb/vannadb_vector.py,sha256=
|
|
20
|
-
vanna-0.0.
|
|
21
|
-
vanna-0.0.
|
|
22
|
-
vanna-0.0.
|
|
20
|
+
vanna/vannadb/vannadb_vector.py,sha256=f4kddaJgTpZync7wnQi09QdODUuMtiHsK7WfKBUAmSo,5644
|
|
21
|
+
vanna-0.0.33.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
|
|
22
|
+
vanna-0.0.33.dist-info/METADATA,sha256=iLDk_AR5D179kN3392f67KeKoUyIDVNEWKEtLiomSTA,8780
|
|
23
|
+
vanna-0.0.33.dist-info/RECORD,,
|
|
File without changes
|