vanna-0.3.3-py3-none-any.whl → vanna-0.4.0-py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
vanna/flask/auth.py ADDED
@@ -0,0 +1,55 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import flask
4
+
5
+
6
+ class AuthInterface(ABC):
7
+ @abstractmethod
8
+ def get_user(self, flask_request) -> any:
9
+ pass
10
+
11
+ @abstractmethod
12
+ def is_logged_in(self, user: any) -> bool:
13
+ pass
14
+
15
+ @abstractmethod
16
+ def override_config_for_user(self, user: any, config: dict) -> dict:
17
+ pass
18
+
19
+ @abstractmethod
20
+ def login_form(self) -> str:
21
+ pass
22
+
23
+ @abstractmethod
24
+ def login_handler(self, flask_request) -> str:
25
+ pass
26
+
27
+ @abstractmethod
28
+ def callback_handler(self, flask_request) -> str:
29
+ pass
30
+
31
+ @abstractmethod
32
+ def logout_handler(self, flask_request) -> str:
33
+ pass
34
+
35
+ class NoAuth(AuthInterface):
36
+ def get_user(self, flask_request) -> any:
37
+ return {}
38
+
39
+ def is_logged_in(self, user: any) -> bool:
40
+ return True
41
+
42
+ def override_config_for_user(self, user: any, config: dict) -> dict:
43
+ return config
44
+
45
+ def login_form(self) -> str:
46
+ return ''
47
+
48
+ def login_handler(self, flask_request) -> str:
49
+ return 'No login required'
50
+
51
+ def callback_handler(self, flask_request) -> str:
52
+ return 'No login required'
53
+
54
+ def logout_handler(self, flask_request) -> str:
55
+ return 'No login required'
@@ -0,0 +1 @@
1
+ from .gemini_chat import GoogleGeminiChat
@@ -0,0 +1,52 @@
1
+ import os
2
+ from ..base import VannaBase
3
+
4
+
5
+ class GoogleGeminiChat(VannaBase):
6
+ def __init__(self, config=None):
7
+ VannaBase.__init__(self, config=config)
8
+
9
+ # default temperature - can be overrided using config
10
+ self.temperature = 0.7
11
+
12
+ if "temperature" in config:
13
+ self.temperature = config["temperature"]
14
+
15
+ if "model_name" in config:
16
+ model_name = config["model_name"]
17
+ else:
18
+ model_name = "gemini-1.0-pro"
19
+
20
+ self.google_api_key = None
21
+
22
+ if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
23
+ """
24
+ If Google api_key is provided through config
25
+ or set as an environment variable, assign it.
26
+ """
27
+ import google.generativeai as genai
28
+
29
+ genai.configure(api_key=config["api_key"])
30
+ self.chat_model = genai.GenerativeModel(model_name)
31
+ else:
32
+ # Authenticate using VertexAI
33
+ from vertexai.preview.generative_models import GenerativeModel
34
+ self.chat_model = GenerativeModel("gemini-pro")
35
+
36
+ def system_message(self, message: str) -> any:
37
+ return message
38
+
39
+ def user_message(self, message: str) -> any:
40
+ return message
41
+
42
+ def assistant_message(self, message: str) -> any:
43
+ return message
44
+
45
+ def submit_prompt(self, prompt, **kwargs) -> str:
46
+ response = self.chat_model.generate_content(
47
+ prompt,
48
+ generation_config={
49
+ "temperature": self.temperature,
50
+ },
51
+ )
52
+ return response.text
vanna/remote.py CHANGED
@@ -34,11 +34,13 @@ from .types import (
34
34
  UserOTP,
35
35
  Visibility,
36
36
  )
37
+ from .vannadb import VannaDB_VectorStore
37
38
 
38
39
 
39
- class VannaDefault(VannaBase):
40
+ class VannaDefault(VannaDB_VectorStore):
40
41
  def __init__(self, model: str, api_key: str, config=None):
41
42
  VannaBase.__init__(self, config=config)
43
+ VannaDB_VectorStore.__init__(self, vanna_model=model, vanna_api_key=api_key, config=config)
42
44
 
43
45
  self._model = model
44
46
  self._api_key = api_key
@@ -48,50 +50,6 @@ class VannaDefault(VannaBase):
48
50
  if config is None or "endpoint" not in config
49
51
  else config["endpoint"]
50
52
  )
51
- self._unauthenticated_endpoint = (
52
- "https://ask.vanna.ai/unauthenticated_rpc"
53
- if config is None or "unauthenticated_endpoint" not in config
54
- else config["unauthenticated_endpoint"]
55
- )
56
-
57
- def _unauthenticated_rpc_call(self, method, params):
58
- headers = {
59
- "Content-Type": "application/json",
60
- }
61
- data = {
62
- "method": method,
63
- "params": [self._dataclass_to_dict(obj) for obj in params],
64
- }
65
-
66
- response = requests.post(
67
- self._unauthenticated_endpoint, headers=headers, data=json.dumps(data)
68
- )
69
- return response.json()
70
-
71
- def _rpc_call(self, method, params):
72
- if method != "list_orgs":
73
- headers = {
74
- "Content-Type": "application/json",
75
- "Vanna-Key": self._api_key,
76
- "Vanna-Org": self._model,
77
- }
78
- else:
79
- headers = {
80
- "Content-Type": "application/json",
81
- "Vanna-Key": self._api_key,
82
- "Vanna-Org": "demo-tpc-h",
83
- }
84
-
85
- data = {
86
- "method": method,
87
- "params": [self._dataclass_to_dict(obj) for obj in params],
88
- }
89
-
90
- response = requests.post(self._endpoint, headers=headers, data=json.dumps(data))
91
- return response.json()
92
-
93
- def _dataclass_to_dict(self, obj):
94
- return dataclasses.asdict(obj)
95
53
 
96
54
  def system_message(self, message: str) -> any:
97
55
  return {"role": "system", "content": message}
@@ -102,299 +60,6 @@ class VannaDefault(VannaBase):
102
60
  def assistant_message(self, message: str) -> any:
103
61
  return {"role": "assistant", "content": message}
104
62
 
105
- def get_training_data(self, **kwargs) -> pd.DataFrame:
106
- """
107
- Get the training data for the current model
108
-
109
- **Example:**
110
- ```python
111
- training_data = vn.get_training_data()
112
- ```
113
-
114
- Returns:
115
- pd.DataFrame or None: The training data, or None if an error occurred.
116
-
117
- """
118
- params = []
119
-
120
- d = self._rpc_call(method="get_training_data", params=params)
121
-
122
- if "result" not in d:
123
- return None
124
-
125
- # Load the result into a dataclass
126
- training_data = DataFrameJSON(**d["result"])
127
-
128
- df = pd.read_json(StringIO(training_data.data))
129
-
130
- return df
131
-
132
- def remove_training_data(self, id: str, **kwargs) -> bool:
133
- """
134
- Remove training data from the model
135
-
136
- **Example:**
137
- ```python
138
- vn.remove_training_data(id="1-ddl")
139
- ```
140
-
141
- Args:
142
- id (str): The ID of the training data to remove.
143
- """
144
- params = [StringData(data=id)]
145
-
146
- d = self._rpc_call(method="remove_training_data", params=params)
147
-
148
- if "result" not in d:
149
- raise Exception(f"Error removing training data")
150
-
151
- status = Status(**d["result"])
152
-
153
- if not status.success:
154
- raise Exception(f"Error removing training data: {status.message}")
155
-
156
- return status.success
157
-
158
- def generate_questions(self) -> list[str]:
159
- """
160
- **Example:**
161
- ```python
162
- vn.generate_questions()
163
- # ['What is the average salary of employees?', 'What is the total salary of employees?', ...]
164
- ```
165
-
166
- Generate questions using the Vanna.AI API.
167
-
168
- Returns:
169
- List[str] or None: The questions, or None if an error occurred.
170
- """
171
- d = self._rpc_call(method="generate_questions", params=[])
172
-
173
- if "result" not in d:
174
- return None
175
-
176
- # Load the result into a dataclass
177
- question_string_list = QuestionStringList(**d["result"])
178
-
179
- return question_string_list.questions
180
-
181
- def add_ddl(self, ddl: str, **kwargs) -> str:
182
- """
183
- Adds a DDL statement to the model's training data
184
-
185
- **Example:**
186
- ```python
187
- vn.add_ddl(
188
- ddl="CREATE TABLE employees (id INT, name VARCHAR(255), salary INT)"
189
- )
190
- ```
191
-
192
- Args:
193
- ddl (str): The DDL statement to store.
194
-
195
- Returns:
196
- str: The ID of the DDL statement.
197
- """
198
- params = [StringData(data=ddl)]
199
-
200
- d = self._rpc_call(method="add_ddl", params=params)
201
-
202
- if "result" not in d:
203
- raise Exception("Error adding DDL", d)
204
-
205
- status = StatusWithId(**d["result"])
206
-
207
- return status.id
208
-
209
- def add_documentation(self, documentation: str, **kwargs) -> str:
210
- """
211
- Adds documentation to the model's training data
212
-
213
- **Example:**
214
- ```python
215
- vn.add_documentation(
216
- documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold."
217
- )
218
- ```
219
-
220
- Args:
221
- documentation (str): The documentation string to store.
222
-
223
- Returns:
224
- str: The ID of the documentation string.
225
- """
226
- params = [StringData(data=documentation)]
227
-
228
- d = self._rpc_call(method="add_documentation", params=params)
229
-
230
- if "result" not in d:
231
- raise Exception("Error adding documentation", d)
232
-
233
- status = StatusWithId(**d["result"])
234
-
235
- return status.id
236
-
237
- def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
238
- """
239
- Adds a question and its corresponding SQL query to the model's training data. The preferred way to call this is to use [`vn.train(sql=...)`][vanna.train].
240
-
241
- **Example:**
242
- ```python
243
- vn.add_sql(
244
- question="What is the average salary of employees?",
245
- sql="SELECT AVG(salary) FROM employees"
246
- )
247
- ```
248
-
249
- Args:
250
- question (str): The question to store.
251
- sql (str): The SQL query to store.
252
- tag (Union[str, None]): A tag to associate with the question and SQL query.
253
-
254
- Returns:
255
- str: The ID of the question and SQL query.
256
- """
257
- if "tag" in kwargs:
258
- tag = kwargs["tag"]
259
- else:
260
- tag = "Manually Trained"
261
-
262
- params = [QuestionSQLPair(question=question, sql=sql, tag=tag)]
263
-
264
- d = self._rpc_call(method="add_sql", params=params)
265
-
266
- if "result" not in d:
267
- raise Exception("Error adding question and SQL pair", d)
268
-
269
- status = StatusWithId(**d["result"])
270
-
271
- return status.id
272
-
273
- def generate_embedding(self, data: str, **kwargs) -> list[float]:
274
- """
275
- Not necessary for remote models as embeddings are generated on the server side.
276
- """
277
- pass
278
-
279
- def generate_plotly_code(
280
- self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
281
- ) -> str:
282
- """
283
- **Example:**
284
- ```python
285
- vn.generate_plotly_code(
286
- question="What is the average salary of employees?",
287
- sql="SELECT AVG(salary) FROM employees",
288
- df_metadata=df.dtypes
289
- )
290
- # fig = px.bar(df, x="name", y="salary")
291
- ```
292
- Generate Plotly code using the Vanna.AI API.
293
-
294
- Args:
295
- question (str): The question to generate Plotly code for.
296
- sql (str): The SQL query to generate Plotly code for.
297
- df (pd.DataFrame): The dataframe to generate Plotly code for.
298
- chart_instructions (str): Optional instructions for how to plot the chart.
299
-
300
- Returns:
301
- str or None: The Plotly code, or None if an error occurred.
302
- """
303
- if kwargs is not None and "chart_instructions" in kwargs:
304
- if question is not None:
305
- question = (
306
- question
307
- + " -- When plotting, follow these instructions: "
308
- + kwargs["chart_instructions"]
309
- )
310
- else:
311
- question = (
312
- "When plotting, follow these instructions: "
313
- + kwargs["chart_instructions"]
314
- )
315
-
316
- params = [
317
- DataResult(
318
- question=question,
319
- sql=sql,
320
- table_markdown=df_metadata,
321
- error=None,
322
- correction_attempts=0,
323
- )
324
- ]
325
-
326
- d = self._rpc_call(method="generate_plotly_code", params=params)
327
-
328
- if "result" not in d:
329
- return None
330
-
331
- # Load the result into a dataclass
332
- plotly_code = PlotlyResult(**d["result"])
333
-
334
- return plotly_code.plotly_code
335
-
336
- def generate_question(self, sql: str, **kwargs) -> str:
337
- """
338
-
339
- **Example:**
340
- ```python
341
- vn.generate_question(sql="SELECT * FROM students WHERE name = 'John Doe'")
342
- # 'What is the name of the student?'
343
- ```
344
-
345
- Generate a question from an SQL query using the Vanna.AI API.
346
-
347
- Args:
348
- sql (str): The SQL query to generate a question for.
349
-
350
- Returns:
351
- str or None: The question, or None if an error occurred.
352
-
353
- """
354
- params = [
355
- SQLAnswer(
356
- raw_answer="",
357
- prefix="",
358
- postfix="",
359
- sql=sql,
360
- )
361
- ]
362
-
363
- d = self._rpc_call(method="generate_question", params=params)
364
-
365
- if "result" not in d:
366
- return None
367
-
368
- # Load the result into a dataclass
369
- question = Question(**d["result"])
370
-
371
- return question.question
372
-
373
- def get_sql_prompt(
374
- self,
375
- question: str,
376
- question_sql_list: list,
377
- ddl_list: list,
378
- doc_list: list,
379
- **kwargs,
380
- ):
381
- """
382
- Not necessary for remote models as prompts are generated on the server side.
383
- """
384
-
385
- def get_followup_questions_prompt(
386
- self,
387
- question: str,
388
- df: pd.DataFrame,
389
- question_sql_list: list,
390
- ddl_list: list,
391
- doc_list: list,
392
- **kwargs,
393
- ):
394
- """
395
- Not necessary for remote models as prompts are generated on the server side.
396
- """
397
-
398
63
  def submit_prompt(self, prompt, **kwargs) -> str:
399
64
  # JSON-ify the prompt
400
65
  json_prompt = json.dumps(prompt)
@@ -410,46 +75,3 @@ class VannaDefault(VannaBase):
410
75
  results = StringData(**d["result"])
411
76
 
412
77
  return results.data
413
-
414
- def get_similar_question_sql(self, question: str, **kwargs) -> list:
415
- """
416
- Not necessary for remote models as similar questions are generated on the server side.
417
- """
418
-
419
- def get_related_ddl(self, question: str, **kwargs) -> list:
420
- """
421
- Not necessary for remote models as related DDL statements are generated on the server side.
422
- """
423
-
424
- def get_related_documentation(self, question: str, **kwargs) -> list:
425
- """
426
- Not necessary for remote models as related documentation is generated on the server side.
427
- """
428
-
429
- def generate_sql(self, question: str, **kwargs) -> str:
430
- """
431
- **Example:**
432
- ```python
433
- vn.generate_sql_from_question(question="What is the average salary of employees?")
434
- # SELECT AVG(salary) FROM employees
435
- ```
436
-
437
- Generate an SQL query using the Vanna.AI API.
438
-
439
- Args:
440
- question (str): The question to generate an SQL query for.
441
-
442
- Returns:
443
- str or None: The SQL query, or None if an error occurred.
444
- """
445
- params = [Question(question=question)]
446
-
447
- d = self._rpc_call(method="generate_sql_from_question", params=params)
448
-
449
- if "result" not in d:
450
- return None
451
-
452
- # Load the result into a dataclass
453
- sql_answer = SQLAnswer(**d["result"])
454
-
455
- return sql_answer.sql
@@ -7,14 +7,17 @@ import requests
7
7
 
8
8
  from ..base import VannaBase
9
9
  from ..types import (
10
- DataFrameJSON,
11
- Question,
12
- QuestionSQLPair,
13
- Status,
14
- StatusWithId,
15
- StringData,
16
- TrainingData,
10
+ DataFrameJSON,
11
+ NewOrganization,
12
+ OrganizationList,
13
+ Question,
14
+ QuestionSQLPair,
15
+ Status,
16
+ StatusWithId,
17
+ StringData,
18
+ TrainingData,
17
19
  )
20
+ from ..utils import sanitize_model_name
18
21
 
19
22
 
20
23
  class VannaDB_VectorStore(VannaBase):
@@ -29,27 +32,8 @@ class VannaDB_VectorStore(VannaBase):
29
32
  if config is None or "endpoint" not in config
30
33
  else config["endpoint"]
31
34
  )
32
- self._unauthenticated_endpoint = (
33
- "https://ask.vanna.ai/unauthenticated_rpc"
34
- if config is None or "unauthenticated_endpoint" not in config
35
- else config["unauthenticated_endpoint"]
36
- )
37
35
  self.related_training_data = {}
38
36
 
39
- def _unauthenticated_rpc_call(self, method, params):
40
- headers = {
41
- "Content-Type": "application/json",
42
- }
43
- data = {
44
- "method": method,
45
- "params": [self._dataclass_to_dict(obj) for obj in params],
46
- }
47
-
48
- response = requests.post(
49
- self._unauthenticated_endpoint, headers=headers, data=json.dumps(data)
50
- )
51
- return response.json()
52
-
53
37
  def _rpc_call(self, method, params):
54
38
  if method != "list_orgs":
55
39
  headers = {
@@ -75,6 +59,53 @@ class VannaDB_VectorStore(VannaBase):
75
59
  def _dataclass_to_dict(self, obj):
76
60
  return dataclasses.asdict(obj)
77
61
 
62
+ def create_model(self, model: str, **kwargs) -> bool:
63
+ """
64
+ **Example:**
65
+ ```python
66
+ success = vn.create_model("my_model")
67
+ ```
68
+ Create a new model.
69
+
70
+ Args:
71
+ model (str): The name of the model to create.
72
+
73
+ Returns:
74
+ bool: True if the model was created, False otherwise.
75
+ """
76
+ model = sanitize_model_name(model)
77
+ params = [NewOrganization(org_name=model, db_type="")]
78
+
79
+ d = self._rpc_call(method="create_org", params=params)
80
+
81
+ if "result" not in d:
82
+ return False
83
+
84
+ status = Status(**d["result"])
85
+
86
+ return status.success
87
+
88
+ def get_models(self) -> list:
89
+ """
90
+ **Example:**
91
+ ```python
92
+ models = vn.get_models()
93
+ ```
94
+
95
+ List the models that belong to the user.
96
+
97
+ Returns:
98
+ List[str]: A list of model names.
99
+ """
100
+ d = self._rpc_call(method="list_my_models", params=[])
101
+
102
+ if "result" not in d:
103
+ return []
104
+
105
+ orgs = OrganizationList(**d["result"])
106
+
107
+ return orgs.organizations
108
+
78
109
  def generate_embedding(self, data: str, **kwargs) -> list[float]:
79
110
  # This is done server-side
80
111
  pass
@@ -141,7 +172,7 @@ class VannaDB_VectorStore(VannaBase):
141
172
  d = self._rpc_call(method="remove_training_data", params=params)
142
173
 
143
174
  if "result" not in d:
144
- raise Exception(f"Error removing training data")
175
+ raise Exception("Error removing training data")
145
176
 
146
177
  status = Status(**d["result"])
147
178
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.3.3
3
+ Version: 0.4.0
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -28,11 +28,15 @@ Requires-Dist: chromadb ; extra == "all"
28
28
  Requires-Dist: anthropic ; extra == "all"
29
29
  Requires-Dist: zhipuai ; extra == "all"
30
30
  Requires-Dist: marqo ; extra == "all"
31
+ Requires-Dist: google-generativeai ; extra == "all"
32
+ Requires-Dist: google-cloud-aiplatform ; extra == "all"
31
33
  Requires-Dist: anthropic ; extra == "anthropic"
32
34
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
33
35
  Requires-Dist: chromadb ; extra == "chromadb"
34
36
  Requires-Dist: duckdb ; extra == "duckdb"
35
37
  Requires-Dist: google-generativeai ; extra == "gemini"
38
+ Requires-Dist: google-generativeai ; extra == "google"
39
+ Requires-Dist: google-cloud-aiplatform ; extra == "google"
36
40
  Requires-Dist: marqo ; extra == "marqo"
37
41
  Requires-Dist: mistralai ; extra == "mistralai"
38
42
  Requires-Dist: PyMySQL ; extra == "mysql"
@@ -50,6 +54,7 @@ Provides-Extra: bigquery
50
54
  Provides-Extra: chromadb
51
55
  Provides-Extra: duckdb
52
56
  Provides-Extra: gemini
57
+ Provides-Extra: google
53
58
  Provides-Extra: marqo
54
59
  Provides-Extra: mistralai
55
60
  Provides-Extra: mysql