vanna 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/base/base.py CHANGED
@@ -1,8 +1,8 @@
  import json
  import os
+ import re
  import sqlite3
  import traceback
-
  from abc import ABC, abstractmethod
  from typing import List, Tuple, Union
  from urllib.parse import urlparse
@@ -12,7 +12,6 @@ import plotly
  import plotly.express as px
  import plotly.graph_objects as go
  import requests
- import re

  from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
  from ..types import TrainingPlan, TrainingPlanItem
@@ -50,8 +49,8 @@ class VannaBase(ABC):
  **kwargs,
  )
  llm_response = self.submit_prompt(prompt, **kwargs)
-
- numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE)
+
+ numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
  return numbers_removed.split("\n")

  def generate_questions(self, **kwargs) -> list[str]:
@@ -65,7 +64,7 @@ class VannaBase(ABC):
  """
  question_sql = self.get_similar_question_sql(question="", **kwargs)

- return [q['question'] for q in question_sql]
+ return [q["question"] for q in question_sql]

  # ----------------- Use Any Embeddings API ----------------- #
  @abstractmethod
@@ -94,7 +93,7 @@ class VannaBase(ABC):
  pass

  @abstractmethod
- def add_documentation(self, doc: str, **kwargs) -> str:
+ def add_documentation(self, documentation: str, **kwargs) -> str:
  pass

  @abstractmethod
@@ -120,12 +119,12 @@ class VannaBase(ABC):

  @abstractmethod
  def get_followup_questions_prompt(
- self,
- question: str,
+ self,
+ question: str,
  question_sql_list: list,
  ddl_list: list,
- doc_list: list,
- **kwargs
+ doc_list: list,
+ **kwargs,
  ):
  pass

@@ -248,7 +247,7 @@ class VannaBase(ABC):
  url = path

  # Connect to the database
- conn = sqlite3.connect(url)
+ conn = sqlite3.connect(url, check_same_thread=False)

  def run_sql_sqlite(sql: str):
  return pd.read_sql_query(sql, conn)
@@ -829,9 +828,11 @@ class VannaBase(ABC):
  fig = ldict.get("fig", None)
  except Exception as e:
  # Inspect data types
- numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
- categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
-
+ numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
+ categorical_cols = df.select_dtypes(
+ include=["object", "category"]
+ ).columns.tolist()
+
  # Decision-making for plot type
  if len(numeric_cols) >= 2:
  # Use the first two numeric columns for a scatter plot
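Two behavioral notes on the base.py hunks above. The sqlite3.connect(url, check_same_thread=False) change lets the single SQLite connection be used from worker threads, which a threaded web server such as the new Flask app needs. The new re.sub call strips leading "1. "-style numbering from each line of the LLM response before it is split into a question list. A minimal sketch of that stdlib call in isolation; the sample llm_response text below is illustrative, not taken from the package:

    import re

    # Hypothetical LLM output formatted as a numbered list
    llm_response = "1. What are total sales by year?\n2. Which country has the most customers?"

    # Same call as in the hunk above: drop the leading "N. " prefix on every line
    numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)

    print(numbers_removed.split("\n"))
    # ['What are total sales by year?', 'Which country has the most customers?']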
vanna/chromadb/chromadb_vector.py CHANGED
@@ -3,9 +3,9 @@ import uuid
  from abc import abstractmethod

  import chromadb
+ import pandas as pd
  from chromadb.config import Settings
  from chromadb.utils import embedding_functions
- import pandas as pd

  from ..base import VannaBase

@@ -47,7 +47,7 @@ class ChromaDB_VectorStore(VannaBase):
  "sql": sql,
  }
  )
- id = str(uuid.uuid4())+"-sql"
+ id = str(uuid.uuid4()) + "-sql"
  self.sql_collection.add(
  documents=question_sql_json,
  embeddings=self.generate_embedding(question_sql_json),
@@ -57,7 +57,7 @@ class ChromaDB_VectorStore(VannaBase):
  return id

  def add_ddl(self, ddl: str, **kwargs) -> str:
- id = str(uuid.uuid4())+"-ddl"
+ id = str(uuid.uuid4()) + "-ddl"
  self.ddl_collection.add(
  documents=ddl,
  embeddings=self.generate_embedding(ddl),
@@ -65,11 +65,11 @@ class ChromaDB_VectorStore(VannaBase):
  )
  return id

- def add_documentation(self, doc: str, **kwargs) -> str:
- id = str(uuid.uuid4())+"-doc"
+ def add_documentation(self, documentation: str, **kwargs) -> str:
+ id = str(uuid.uuid4()) + "-doc"
  self.documentation_collection.add(
- documents=doc,
- embeddings=self.generate_embedding(doc),
+ documents=documentation,
+ embeddings=self.generate_embedding(documentation),
  ids=id,
  )
  return id
@@ -81,15 +81,17 @@ class ChromaDB_VectorStore(VannaBase):

  if sql_data is not None:
  # Extract the documents and ids
- documents = [json.loads(doc) for doc in sql_data['documents']]
- ids = sql_data['ids']
+ documents = [json.loads(doc) for doc in sql_data["documents"]]
+ ids = sql_data["ids"]

  # Create a DataFrame
- df_sql = pd.DataFrame({
- 'id': ids,
- 'question': [doc['question'] for doc in documents],
- 'content': [doc['sql'] for doc in documents]
- })
+ df_sql = pd.DataFrame(
+ {
+ "id": ids,
+ "question": [doc["question"] for doc in documents],
+ "content": [doc["sql"] for doc in documents],
+ }
+ )

  df_sql["training_data_type"] = "sql"

@@ -99,15 +101,17 @@ class ChromaDB_VectorStore(VannaBase):

  if ddl_data is not None:
  # Extract the documents and ids
- documents = [doc for doc in ddl_data['documents']]
- ids = ddl_data['ids']
+ documents = [doc for doc in ddl_data["documents"]]
+ ids = ddl_data["ids"]

  # Create a DataFrame
- df_ddl = pd.DataFrame({
- 'id': ids,
- 'question': [None for doc in documents],
- 'content': [doc for doc in documents]
- })
+ df_ddl = pd.DataFrame(
+ {
+ "id": ids,
+ "question": [None for doc in documents],
+ "content": [doc for doc in documents],
+ }
+ )

  df_ddl["training_data_type"] = "ddl"

@@ -117,15 +121,17 @@ class ChromaDB_VectorStore(VannaBase):

  if doc_data is not None:
  # Extract the documents and ids
- documents = [doc for doc in doc_data['documents']]
- ids = doc_data['ids']
+ documents = [doc for doc in doc_data["documents"]]
+ ids = doc_data["ids"]

  # Create a DataFrame
- df_doc = pd.DataFrame({
- 'id': ids,
- 'question': [None for doc in documents],
- 'content': [doc for doc in documents]
- })
+ df_doc = pd.DataFrame(
+ {
+ "id": ids,
+ "question": [None for doc in documents],
+ "content": [doc for doc in documents],
+ }
+ )

  df_doc["training_data_type"] = "documentation"
vanna/flask.py ADDED
@@ -0,0 +1,331 @@
+ import flask
+ from flask import Flask, Response, jsonify, request
+ import logging
+ import requests
+ from functools import wraps
+
+ from abc import ABC, abstractmethod
+ import uuid
+
+ class Cache(ABC):
+ @abstractmethod
+ def generate_id(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def get(self, id, field):
+ pass
+
+ @abstractmethod
+ def get_all(self, field_list) -> list:
+ pass
+
+ @abstractmethod
+ def set(self, id, field, value):
+ pass
+
+ @abstractmethod
+ def delete(self, id):
+ pass
+
+
+ class MemoryCache(Cache):
+ def __init__(self):
+ self.cache = {}
+
+ def generate_id(self, *args, **kwargs):
+ return str(uuid.uuid4())
+
+ def set(self, id, field, value):
+ if id not in self.cache:
+ self.cache[id] = {}
+
+ self.cache[id][field] = value
+
+ def get(self, id, field):
+ if id not in self.cache:
+ return None
+
+ if field not in self.cache[id]:
+ return None
+
+ return self.cache[id][field]
+
+ def get_all(self, field_list) -> list:
+ return [
+ {
+ "id": id,
+ **{
+ field: self.get(id=id, field=field)
+ for field in field_list
+ }
+ }
+ for id in self.cache
+ ]
+
+ def delete(self, id):
+ if id in self.cache:
+ del self.cache[id]
+
+ class VannaFlaskApp:
+ flask_app = None
+
+ def requires_cache(self, fields):
+ def decorator(f):
+ @wraps(f)
+ def decorated(*args, **kwargs):
+ id = request.args.get('id')
+
+ if id is None:
+ return jsonify({"type": "error", "error": "No id provided"})
+
+ for field in fields:
+ if self.cache.get(id=id, field=field) is None:
+ return jsonify({"type": "error", "error": f"No {field} found"})
+
+ field_values = {field: self.cache.get(id=id, field=field) for field in fields}
+
+ # Add the id to the field_values
+ field_values['id'] = id
+
+ return f(*args, **field_values, **kwargs)
+ return decorated
+ return decorator
+
+ def __init__(self, vn, cache: Cache = MemoryCache()):
+ self.flask_app = Flask(__name__)
+ self.vn = vn
+ self.cache = cache
+
+ log = logging.getLogger('werkzeug')
+ log.setLevel(logging.ERROR)
+
+ @self.flask_app.route('/api/v0/generate_questions', methods=['GET'])
+ def generate_questions():
+ # If self has an _model attribute and model=='chinook'
+ if hasattr(self.vn, '_model') and self.vn._model == 'chinook':
+ return jsonify({
+ "type": "question_list",
+ "questions": ['What are the top 10 artists by sales?', 'What are the total sales per year by country?', 'Who is the top selling artist in each genre? Show the sales numbers.', 'How do the employees rank in terms of sales performance?', 'Which 5 cities have the most customers?'],
+ "header": "Here are some questions you can ask:"
+ })
+
+ @self.flask_app.route('/api/v0/generate_sql', methods=['GET'])
+ def generate_sql():
+ question = flask.request.args.get('question')
+
+ if question is None:
+ return jsonify({"type": "error", "error": "No question provided"})
+
+ id = self.cache.generate_id(question=question)
+ sql = vn.generate_sql(question=question)
+
+ self.cache.set(id=id, field='question', value=question)
+ self.cache.set(id=id, field='sql', value=sql)
+
+ return jsonify(
+ {
+ "type": "sql",
+ "id": id,
+ "text": sql,
+ })
+
+ @self.flask_app.route('/api/v0/run_sql', methods=['GET'])
+ @self.requires_cache(['sql'])
+ def run_sql(id: str, sql: str):
+ try:
+ if not vn.run_sql_is_set:
+ return jsonify({"type": "error", "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries."})
+
+ df = vn.run_sql(sql=sql)
+
+ cache.set(id=id, field='df', value=df)
+
+ return jsonify(
+ {
+ "type": "df",
+ "id": id,
+ "df": df.head(10).to_json(orient='records'),
+ })
+
+ except Exception as e:
+ return jsonify({"type": "error", "error": str(e)})
+
+ @self.flask_app.route('/api/v0/download_csv', methods=['GET'])
+ @self.requires_cache(['df'])
+ def download_csv(id: str, df):
+ csv = df.to_csv()
+
+ return Response(
+ csv,
+ mimetype="text/csv",
+ headers={"Content-disposition":
+ f"attachment; filename={id}.csv"})
+
+ @self.flask_app.route('/api/v0/generate_plotly_figure', methods=['GET'])
+ @self.requires_cache(['df', 'question', 'sql'])
+ def generate_plotly_figure(id: str, df, question, sql):
+ try:
+ code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
+ fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
+ fig_json = fig.to_json()
+
+ cache.set(id=id, field='fig_json', value=fig_json)
+
+ return jsonify(
+ {
+ "type": "plotly_figure",
+ "id": id,
+ "fig": fig_json,
+ })
+ except Exception as e:
+ # Print the stack trace
+ import traceback
+ traceback.print_exc()
+
+ return jsonify({"type": "error", "error": str(e)})
+
+ @self.flask_app.route('/api/v0/get_training_data', methods=['GET'])
+ def get_training_data():
+ df = vn.get_training_data()
+
+ return jsonify(
+ {
+ "type": "df",
+ "id": "training_data",
+ "df": df.tail(25).to_json(orient='records'),
+ })
+
+ @self.flask_app.route('/api/v0/remove_training_data', methods=['POST'])
+ def remove_training_data():
+ # Get id from the JSON body
+ id = flask.request.json.get('id')
+
+ if id is None:
+ return jsonify({"type": "error", "error": "No id provided"})
+
+ if vn.remove_training_data(id=id):
+ return jsonify({"success": True})
+ else:
+ return jsonify({"type": "error", "error": "Couldn't remove training data"})
+
+ @self.flask_app.route('/api/v0/train', methods=['POST'])
+ def add_training_data():
+ question = flask.request.json.get('question')
+ sql = flask.request.json.get('sql')
+ ddl = flask.request.json.get('ddl')
+ documentation = flask.request.json.get('documentation')
+
+ try:
+ id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
+
+ return jsonify({"id": id})
+ except Exception as e:
+ print("TRAINING ERROR", e)
+ return jsonify({"type": "error", "error": str(e)})
+
+ @self.flask_app.route('/api/v0/generate_followup_questions', methods=['GET'])
+ @self.requires_cache(['df', 'question'])
+ def generate_followup_questions(id: str, df, question):
+ followup_questions = []
+ # followup_questions = vn.generate_followup_questions(question=question, df=df)
+ # if followup_questions is not None and len(followup_questions) > 5:
+ # followup_questions = followup_questions[:5]
+
+ cache.set(id=id, field='followup_questions', value=followup_questions)
+
+ return jsonify(
+ {
+ "type": "question_list",
+ "id": id,
+ "questions": followup_questions,
+ "header": "Followup Questions can be enabled in a future version if you allow the LLM to 'see' your query results."
+ })
+
+ @self.flask_app.route('/api/v0/load_question', methods=['GET'])
+ @self.requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
+ def load_question(id: str, question, sql, df, fig_json, followup_questions):
+ try:
+ return jsonify(
+ {
+ "type": "question_cache",
+ "id": id,
+ "question": question,
+ "sql": sql,
+ "df": df.head(10).to_json(orient='records'),
+ "fig": fig_json,
+ "followup_questions": followup_questions,
+ })
+
+ except Exception as e:
+ return jsonify({"type": "error", "error": str(e)})
+
+ @self.flask_app.route('/api/v0/get_question_history', methods=['GET'])
+ def get_question_history():
+ return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) })
+
+
+ @self.flask_app.route('/api/v0/<path:catch_all>', methods=['GET', 'POST'])
+ def catch_all(catch_all):
+ return jsonify({"type": "error", "error": "The rest of the API is not ported yet."})
+
+ @self.flask_app.route('/assets/<path:filename>')
+ def proxy_assets(filename):
+ remote_url = f'https://vanna.ai/assets/{filename}'
+ response = requests.get(remote_url, stream=True)
+
+ # Check if the request to the remote URL was successful
+ if response.status_code == 200:
+ excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
+ headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers]
+ return Response(response.content, response.status_code, headers)
+ else:
+ return 'Error fetching file from remote server', response.status_code
+
+ # Proxy the /vanna.svg file to the remote server
+ @self.flask_app.route('/vanna.svg')
+ def proxy_vanna_svg():
+ remote_url = f'https://vanna.ai/img/vanna.svg'
+ response = requests.get(remote_url, stream=True)
+
+ # Check if the request to the remote URL was successful
+ if response.status_code == 200:
+ excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
+ headers = [(name, value) for (name, value) in response.raw.headers.items() if name.lower() not in excluded_headers]
+ return Response(response.content, response.status_code, headers)
+ else:
+ return 'Error fetching file from remote server', response.status_code
+
+ @self.flask_app.route('/', defaults={'path': ''})
+ @self.flask_app.route('/<path:path>')
+ def hello(path: str):
+ return """
+ <!doctype html>
+ <html lang="en">
+ <head>
+ <meta charset="UTF-8" />
+ <link rel="icon" type="image/svg+xml" href="/vanna.svg" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
+ <script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
+ <title>Vanna.AI</title>
+ <script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
+ <link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
+ </head>
+ <body class="bg-white dark:bg-slate-900">
+ <div id="app"></div>
+ </body>
+ </html>
+ """
+
+ def run(self):
+ try:
+ from google.colab import output
+ output.serve_kernel_port_as_window(8084)
+ from google.colab.output import eval_js
+ print("Your app is running at:")
+ print(eval_js("google.colab.kernel.proxyPort(8084)"))
+ except:
+ print("Your app is running at:")
+ print("http://localhost:8084")
+ self.flask_app.run(host='0.0.0.0', port=8084, debug=False)
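The new vanna/flask.py wraps a Vanna instance in a small Flask app: MemoryCache keeps per-question state (question, sql, df, fig_json, followup_questions) keyed by a generated id, the /api/v0/... routes drive the generate-SQL / run-SQL / plot workflow, and the remaining routes proxy the web UI assets from vanna.ai. A minimal usage sketch; VannaDefault, the model name, the API key, and the SQLite path are placeholders for whatever configured Vanna instance you already have:

    from vanna.remote import VannaDefault
    from vanna.flask import VannaFlaskApp

    vn = VannaDefault(model="my-model", api_key="my-api-key")  # placeholder credentials
    vn.connect_to_sqlite("my-database.sqlite")  # any vn.connect_to_... method works here

    # MemoryCache() is the default cache; pass your own Cache subclass to persist results
    VannaFlaskApp(vn).run()  # serves the UI and JSON API on port 8084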
vanna/marqo/marqo.py CHANGED
@@ -3,7 +3,6 @@ import uuid
  from abc import abstractmethod

  import marqo
-
  import pandas as pd

  from ..base import VannaBase
@@ -12,7 +11,7 @@ from ..base import VannaBase
  class Marqo_VectorStore(VannaBase):
  def __init__(self, config=None):
  VannaBase.__init__(self, config=config)
-
+
  if config is not None and "marqo_url" in config:
  marqo_url = config["marqo_url"]
  else:
@@ -22,7 +21,7 @@ class Marqo_VectorStore(VannaBase):
  marqo_model = config["marqo_model"]
  else:
  marqo_model = "hf/all_datasets_v4_MiniLM-L6"
-
+
  self.mq = marqo.Client(url=marqo_url)

  for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]:
@@ -33,18 +32,17 @@ class Marqo_VectorStore(VannaBase):
  print(f"Marqo index {index} already exists")
  pass

-
  def generate_embedding(self, data: str, **kwargs) -> list[float]:
  # Marqo doesn't need to generate embeddings
- pass
+ pass

  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
- id = str(uuid.uuid4())+"-sql"
- question_sql_dict ={
- "question": question,
- "sql": sql,
- "_id": id,
- }
+ id = str(uuid.uuid4()) + "-sql"
+ question_sql_dict = {
+ "question": question,
+ "sql": sql,
+ "_id": id,
+ }

  self.mq.index("vanna-sql").add_documents(
  [question_sql_dict],
@@ -54,11 +52,11 @@ class Marqo_VectorStore(VannaBase):
  return id

  def add_ddl(self, ddl: str, **kwargs) -> str:
- id = str(uuid.uuid4())+"-ddl"
- ddl_dict ={
- "ddl": ddl,
- "_id": id,
- }
+ id = str(uuid.uuid4()) + "-ddl"
+ ddl_dict = {
+ "ddl": ddl,
+ "_id": id,
+ }

  self.mq.index("vanna-ddl").add_documents(
  [ddl_dict],
@@ -66,13 +64,13 @@ class Marqo_VectorStore(VannaBase):
  )
  return id

- def add_documentation(self, doc: str, **kwargs) -> str:
- id = str(uuid.uuid4())+"-doc"
- doc_dict ={
- "doc": doc,
- "_id": id,
- }
-
+ def add_documentation(self, documentation: str, **kwargs) -> str:
+ id = str(uuid.uuid4()) + "-doc"
+ doc_dict = {
+ "doc": documentation,
+ "_id": id,
+ }
+
  self.mq.index("vanna-doc").add_documents(
  [doc_dict],
  tensor_fields=["doc"],
@@ -80,31 +78,37 @@ class Marqo_VectorStore(VannaBase):
  return id

  def get_training_data(self, **kwargs) -> pd.DataFrame:
- data = []
-
- for hit in self.mq.index('vanna-doc').search("", limit=1000)['hits']:
- data.append({
- "id": hit["_id"],
- "training_data_type": "documentation",
- "question": "",
- "content": hit["doc"],
- })
-
- for hit in self.mq.index('vanna-ddl').search("", limit=1000)['hits']:
- data.append({
- "id": hit["_id"],
- "training_data_type": "ddl",
- "question": "",
- "content": hit["ddl"],
- })
-
- for hit in self.mq.index('vanna-sql').search("", limit=1000)['hits']:
- data.append({
- "id": hit["_id"],
- "training_data_type": "sql",
- "question": hit["question"],
- "content": hit["sql"],
- })
+ data = []
+
+ for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]:
+ data.append(
+ {
+ "id": hit["_id"],
+ "training_data_type": "documentation",
+ "question": "",
+ "content": hit["doc"],
+ }
+ )
+
+ for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]:
+ data.append(
+ {
+ "id": hit["_id"],
+ "training_data_type": "ddl",
+ "question": "",
+ "content": hit["ddl"],
+ }
+ )
+
+ for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]:
+ data.append(
+ {
+ "id": hit["_id"],
+ "training_data_type": "sql",
+ "question": hit["question"],
+ "content": hit["sql"],
+ }
+ )

  df = pd.DataFrame(data)

@@ -127,24 +131,24 @@ class Marqo_VectorStore(VannaBase):
  @staticmethod
  def _extract_documents(data) -> list:
  # Check if 'hits' key is in the dictionary and if it's a list
- if 'hits' in data and isinstance(data['hits'], list):
+ if "hits" in data and isinstance(data["hits"], list):
  # Iterate over each item in 'hits'

- if len(data['hits']) == 0:
+ if len(data["hits"]) == 0:
  return []

  # If there is a "doc" key, return the value of that key
- if "doc" in data['hits'][0]:
- return [hit["doc"] for hit in data['hits']]
-
+ if "doc" in data["hits"][0]:
+ return [hit["doc"] for hit in data["hits"]]
+
  # If there is a "ddl" key, return the value of that key
- if "ddl" in data['hits'][0]:
- return [hit["ddl"] for hit in data['hits']]
-
+ if "ddl" in data["hits"][0]:
+ return [hit["ddl"] for hit in data["hits"]]
+
  # Otherwise return the entire hit
  return [
- {key: value for key, value in hit.items() if not key.startswith('_')}
- for hit in data['hits']
+ {key: value for key, value in hit.items() if not key.startswith("_")}
+ for hit in data["hits"]
  ]
  else:
  # Return an empty list if 'hits' is not found or not a list
@@ -152,15 +156,15 @@ class Marqo_VectorStore(VannaBase):

  def get_similar_question_sql(self, question: str, **kwargs) -> list:
  return Marqo_VectorStore._extract_documents(
- self.mq.index('vanna-sql').search(question)
+ self.mq.index("vanna-sql").search(question)
  )

  def get_related_ddl(self, question: str, **kwargs) -> list:
  return Marqo_VectorStore._extract_documents(
- self.mq.index('vanna-ddl').search(question)
+ self.mq.index("vanna-ddl").search(question)
  )

  def get_related_documentation(self, question: str, **kwargs) -> list:
  return Marqo_VectorStore._extract_documents(
- self.mq.index('vanna-doc').search(question)
+ self.mq.index("vanna-doc").search(question)
  )
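Marqo_VectorStore reads marqo_url and marqo_model from its config (falling back to the hf/all_datasets_v4_MiniLM-L6 model) and keeps training data in the vanna-sql, vanna-ddl, and vanna-doc indexes. A sketch of pairing it with an LLM class via the usual Vanna mix-in pattern; OpenAI_Chat and every config value shown here are illustrative assumptions, not taken from this diff:

    from vanna.marqo.marqo import Marqo_VectorStore
    from vanna.openai.openai_chat import OpenAI_Chat  # assumed LLM mix-in

    class MyVanna(Marqo_VectorStore, OpenAI_Chat):
        def __init__(self, config=None):
            Marqo_VectorStore.__init__(self, config=config)
            OpenAI_Chat.__init__(self, config=config)

    # marqo_url points at a locally running Marqo server (illustrative values throughout)
    vn = MyVanna(config={"marqo_url": "http://localhost:8882", "api_key": "sk-placeholder", "model": "gpt-4"})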
vanna/remote.py CHANGED
@@ -197,7 +197,7 @@ class VannaDefault(VannaBase):

  return status.id

- def add_documentation(self, doc: str, **kwargs) -> str:
+ def add_documentation(self, documentation: str, **kwargs) -> str:
  """
  Adds documentation to the model's training data

@@ -469,4 +469,4 @@ class VannaDefault(VannaBase):
  # Load the result into a dataclass
  question_string_list = QuestionStringList(**d["result"])

- return question_string_list.questions
+ return question_string_list.questions
vanna/vannadb/vannadb_vector.py CHANGED
@@ -1,18 +1,21 @@
+ import dataclasses
+ import json
+ from io import StringIO
+
+ import pandas as pd
+ import requests
+
  from ..base import VannaBase
  from ..types import (
- QuestionSQLPair,
- StatusWithId,
- StringData,
  DataFrameJSON,
+ Question,
+ QuestionSQLPair,
  Status,
+ StatusWithId,
+ StringData,
  TrainingData,
- Question,
  )
- from io import StringIO
- import pandas as pd
- import requests
- import json
- import dataclasses
+

  class VannaDB_VectorStore(VannaBase):
  def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
@@ -105,8 +108,8 @@ class VannaDB_VectorStore(VannaBase):

  return status.id

- def add_documentation(self, doc: str, **kwargs) -> str:
- params = [StringData(data=doc)]
+ def add_documentation(self, documentation: str, **kwargs) -> str:
+ params = [StringData(data=documentation)]

  d = self._rpc_call(method="add_documentation", params=params)

@@ -167,7 +170,7 @@ class VannaDB_VectorStore(VannaBase):
  training_data = self.related_training_data[question]
  else:
  training_data = self.get_related_training_data_cached(question)
-
+
  return training_data.questions

  def get_related_ddl(self, question: str, **kwargs) -> list:
@@ -175,7 +178,7 @@ class VannaDB_VectorStore(VannaBase):
  training_data = self.related_training_data[question]
  else:
  training_data = self.get_related_training_data_cached(question)
-
+
  return training_data.ddl

  def get_related_documentation(self, question: str, **kwargs) -> list:
@@ -183,5 +186,5 @@ class VannaDB_VectorStore(VannaBase):
  training_data = self.related_training_data[question]
  else:
  training_data = self.get_related_training_data_cached(question)
-
- return training_data.documentation
+
+ return training_data.documentation
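The add_documentation parameter rename from doc to documentation runs through VannaBase, ChromaDB_VectorStore, Marqo_VectorStore, VannaDefault, and VannaDB_VectorStore in this release. Positional calls are unaffected; only callers that used the old keyword need to change. A hedged before/after sketch, where vn stands for any configured Vanna instance and the documentation string is illustrative:

    # 0.0.31 keyword:
    # vn.add_documentation(doc="Our fiscal year starts in February.")

    # 0.0.33 keyword:
    vn.add_documentation(documentation="Our fiscal year starts in February.")

    # Positional form, unchanged across both versions:
    vn.add_documentation("Our fiscal year starts in February.")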
vanna-0.0.31.dist-info/METADATA → vanna-0.0.33.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vanna
- Version: 0.0.31
+ Version: 0.0.33
  Summary: Generate SQL queries from natural language
  Author-email: Zain Hoda <zain@vanna.ai>
  Requires-Python: >=3.7
@@ -41,7 +41,7 @@ Provides-Extra: postgres
  Provides-Extra: snowflake
  Provides-Extra: test

- ![](https://img.vanna.ai/vanna-github.svg)
+

  | GitHub | PyPI | Documentation |
  | ------ | ---- | ------------- |
@@ -52,6 +52,8 @@ Vanna is an MIT-licensed open-source Python RAG (Retrieval-Augmented Generation)

  https://github.com/vanna-ai/vanna/assets/7146154/1901f47a-515d-4982-af50-f12761a3b2ce

+ ![vanna-quadrants](https://github.com/vanna-ai/vanna/assets/7146154/1c7c88ba-c144-4ecf-a028-cf5ba7344ca2)
+
  ## How Vanna works
  Vanna works in two easy steps - train a RAG "model" on your data, and then ask questions which will return SQL queries that can be set up to automatically run on your database.
vanna-0.0.31.dist-info/RECORD → vanna-0.0.33.dist-info/RECORD CHANGED
@@ -1,14 +1,15 @@
  vanna/__init__.py,sha256=thjmOUgHCboSxIkzQRKw-JvZLLFbnuyM7G5YIzmmmPQ,61545
+ vanna/flask.py,sha256=Kw7qjObb39J1BWX5PsjLRIbMJqntZI91K3tGlUxx5M0,12496
  vanna/local.py,sha256=U5s8ybCRQhBUizi8I69o3jqOpTeu_6KGYY6DMwZxjG4,313
- vanna/remote.py,sha256=xWlF48eQXuc03NZrDpMQgvrM6dbbfbEjX_FEmQf_b5c,13573
+ vanna/remote.py,sha256=yr0QSJCAKzziiPOa-mfsjGg1pVa5-vLj9vYl2VDlAfU,13584
  vanna/utils.py,sha256=Q0H4eugPYg9SVpEoTWgvmuoJZZxOVRhNzrP97E5lyak,1472
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
- vanna/base/base.py,sha256=4y0FFUNIb8Y948RFAzt6sYOs7iKqp3ORwMvtgrda00o,31162
+ vanna/base/base.py,sha256=w3qYB-8LhcP0bvV0MCc4VlGgmVzQ4TY-N3Ufix8DZs8,31204
  vanna/chromadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- vanna/chromadb/chromadb_vector.py,sha256=af1n7htIkSnpd7h9906mkKSK9BpvNNQa48_z4FS-_nE,5716
+ vanna/chromadb/chromadb_vector.py,sha256=4YGgWQNIw4QJFwtBRIW53vieXH8rTBez-cs7EZwxsNI,5893
  vanna/exceptions/__init__.py,sha256=N76unE7sjbGGBz6LmCrPQAugFWr9cUFv8ErJxBrCTts,717
  vanna/marqo/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- vanna/marqo/marqo.py,sha256=8L2W6XRu37BvnIpVnlyGbg_w2r2bceDyLuVCDAhDqs0,5206
+ vanna/marqo/marqo.py,sha256=2OBuC5IZmGcFXN2Ah6GVPKHBYtkDXeSwhXsqUbxyU94,5285
  vanna/mistral/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  vanna/mistral/mistral.py,sha256=A9dD8U-c12whGx8h_WOX15zUzaAJV-XLu_tpSiLamHo,8095
  vanna/openai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +17,7 @@ vanna/openai/openai_chat.py,sha256=U6wkXztJnQtABItUMDlBIDN6m3fqD6pMpa9gyQAQx8A,9
  vanna/openai/openai_embeddings.py,sha256=kPtOqrKQYJnXe6My3pO9BWg-L3KIR1sJVqE3YoW0roA,1139
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
  vanna/vannadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- vanna/vannadb/vannadb_vector.py,sha256=zX_oT66LQSDeqO87I5xdKA87uQRQDl-ZrGOh8BYkUOU,5645
- vanna-0.0.31.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
- vanna-0.0.31.dist-info/METADATA,sha256=ILA-ZXW-UHaYqL6xKJrrJfu4WfoX45kawA8VN4X5itE,8715
- vanna-0.0.31.dist-info/RECORD,,
+ vanna/vannadb/vannadb_vector.py,sha256=f4kddaJgTpZync7wnQi09QdODUuMtiHsK7WfKBUAmSo,5644
+ vanna-0.0.33.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
+ vanna-0.0.33.dist-info/METADATA,sha256=iLDk_AR5D179kN3392f67KeKoUyIDVNEWKEtLiomSTA,8780
+ vanna-0.0.33.dist-info/RECORD,,
File without changes