vanna 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vanna/__init__.py +92 -74
- vanna/types.py +1 -1
- {vanna-0.0.9.dist-info → vanna-0.0.11.dist-info}/METADATA +1 -1
- vanna-0.0.11.dist-info/RECORD +7 -0
- vanna-0.0.9.dist-info/RECORD +0 -7
- {vanna-0.0.9.dist-info → vanna-0.0.11.dist-info}/LICENSE +0 -0
- {vanna-0.0.9.dist-info → vanna-0.0.11.dist-info}/WHEEL +0 -0
- {vanna-0.0.9.dist-info → vanna-0.0.11.dist-info}/top_level.txt +0 -0
vanna/__init__.py
CHANGED
|
@@ -17,11 +17,31 @@ from typing import List, Dict, Any, Union, Optional, Callable, Tuple
|
|
|
17
17
|
import warnings
|
|
18
18
|
import traceback
|
|
19
19
|
|
|
20
|
-
"""Set the API key for Vanna.AI."""
|
|
21
20
|
api_key: Union[str, None] = None # API key for Vanna.AI
|
|
21
|
+
"""
|
|
22
|
+
## Example
|
|
23
|
+
```python
|
|
24
|
+
# Login to Vanna.AI
|
|
25
|
+
vn.login('user@example.com')
|
|
26
|
+
print(vn.api_key)
|
|
27
|
+
|
|
28
|
+
vn.api_key='my_api_key'
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
This is the API key for Vanna.AI. You can set it manually if you have it or use [`vn.login(...)`][vanna.login] to login and set it automatically.
|
|
32
|
+
|
|
33
|
+
"""
|
|
22
34
|
|
|
23
|
-
"""Set the SQL to DataFrame function for Vanna.AI."""
|
|
24
35
|
sql_to_df: Union[Callable[[str], pd.DataFrame], None] = None # Function to convert SQL to a Pandas DataFrame
|
|
36
|
+
"""
|
|
37
|
+
## Example
|
|
38
|
+
```python
|
|
39
|
+
vn.sql_to_df = lambda sql: pd.read_sql(sql, engine)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Set the SQL to DataFrame function for Vanna.AI. This is used in the [`vn.ask(...)`][vanna.ask] function.
|
|
43
|
+
|
|
44
|
+
"""
|
|
25
45
|
|
|
26
46
|
__org: Union[str, None] = None # Organization name for Vanna.AI
|
|
27
47
|
|
|
@@ -45,10 +65,10 @@ def __rpc_call(method, params):
|
|
|
45
65
|
global __org
|
|
46
66
|
|
|
47
67
|
if api_key is None:
|
|
48
|
-
raise Exception("API key not set")
|
|
68
|
+
raise Exception("API key not set. Use vn.login(...) to login.")
|
|
49
69
|
|
|
50
70
|
if __org is None and method != "list_orgs":
|
|
51
|
-
raise Exception("
|
|
71
|
+
raise Exception("Datasets not set. Use vn.use_datasets([...]) to set the datasets to use.")
|
|
52
72
|
|
|
53
73
|
if method != "list_orgs":
|
|
54
74
|
headers = {
|
|
@@ -124,17 +144,17 @@ def login(email: str, otp_code: Union[str, None] = None) -> bool:
|
|
|
124
144
|
|
|
125
145
|
return True
|
|
126
146
|
|
|
127
|
-
def
|
|
147
|
+
def list_datasets() -> List[str]:
|
|
128
148
|
"""
|
|
129
149
|
## Example
|
|
130
150
|
```python
|
|
131
|
-
|
|
151
|
+
datasets = vn.list_datasets()
|
|
132
152
|
```
|
|
133
153
|
|
|
134
|
-
List the
|
|
154
|
+
List the datasets that the user is a member of.
|
|
135
155
|
|
|
136
156
|
Returns:
|
|
137
|
-
List[str]: A list of
|
|
157
|
+
List[str]: A list of dataset names.
|
|
138
158
|
"""
|
|
139
159
|
d = __rpc_call(method="list_orgs", params=[])
|
|
140
160
|
|
|
@@ -145,23 +165,23 @@ def list_orgs() -> List[str]:
|
|
|
145
165
|
|
|
146
166
|
return orgs.organizations
|
|
147
167
|
|
|
148
|
-
def
|
|
168
|
+
def create_dataset(dataset: str, db_type: str) -> bool:
|
|
149
169
|
"""
|
|
150
170
|
## Example
|
|
151
171
|
```python
|
|
152
|
-
vn.
|
|
172
|
+
vn.create_dataset(dataset="my-dataset", db_type="postgres")
|
|
153
173
|
```
|
|
154
174
|
|
|
155
|
-
Create a new
|
|
175
|
+
Create a new dataset.
|
|
156
176
|
|
|
157
177
|
Args:
|
|
158
|
-
|
|
159
|
-
db_type (str): The type of database to use for the
|
|
178
|
+
dataset (str): The name of the dataset to create.
|
|
179
|
+
db_type (str): The type of database to use for the dataset. This can be "Snowflake", "BigQuery", "Postgres", or anything else.
|
|
160
180
|
|
|
161
181
|
Returns:
|
|
162
|
-
bool: True if the
|
|
182
|
+
bool: True if the dataset was created successfully, False otherwise.
|
|
163
183
|
"""
|
|
164
|
-
params = [NewOrganization(org_name=
|
|
184
|
+
params = [NewOrganization(org_name=dataset, db_type=db_type)]
|
|
165
185
|
|
|
166
186
|
d = __rpc_call(method="create_org", params=params)
|
|
167
187
|
|
|
@@ -172,24 +192,25 @@ def create_org(org: str, db_type: str) -> bool:
|
|
|
172
192
|
|
|
173
193
|
return status.success
|
|
174
194
|
|
|
175
|
-
def
|
|
195
|
+
def add_user_to_dataset(dataset: str, email: str, is_admin: bool) -> bool:
|
|
176
196
|
"""
|
|
177
197
|
## Example
|
|
178
198
|
```python
|
|
179
|
-
vn.
|
|
199
|
+
vn.add_user_to_dataset(dataset="my-dataset", email="user@example.com")
|
|
180
200
|
```
|
|
181
201
|
|
|
182
|
-
Add a user to an
|
|
202
|
+
Add a user to an dataset.
|
|
183
203
|
|
|
184
204
|
Args:
|
|
185
|
-
|
|
205
|
+
dataset (str): The name of the dataset to add the user to.
|
|
186
206
|
email (str): The email address of the user to add.
|
|
207
|
+
is_admin (bool): Whether or not the user should be an admin.
|
|
187
208
|
|
|
188
209
|
Returns:
|
|
189
210
|
bool: True if the user was added successfully, False otherwise.
|
|
190
211
|
"""
|
|
191
212
|
|
|
192
|
-
params = [NewOrganizationMember(org_name=
|
|
213
|
+
params = [NewOrganizationMember(org_name=dataset, email=email, is_admin=is_admin)]
|
|
193
214
|
|
|
194
215
|
d = __rpc_call(method="add_user_to_org", params=params)
|
|
195
216
|
|
|
@@ -203,21 +224,20 @@ def add_user_to_org(org: str, email: str, is_admin: bool) -> bool:
|
|
|
203
224
|
|
|
204
225
|
return status.success
|
|
205
226
|
|
|
206
|
-
def
|
|
227
|
+
def set_dataset_visibility(visibility: bool) -> bool:
|
|
207
228
|
"""
|
|
208
229
|
## Example
|
|
209
230
|
```python
|
|
210
|
-
vn.
|
|
231
|
+
vn.set_dataset_visibility(visibility=True)
|
|
211
232
|
```
|
|
212
233
|
|
|
213
|
-
Set the visibility of
|
|
234
|
+
Set the visibility of the current dataset. If a dataset is visible, anyone can see it. If it is not visible, only members of the dataset can see it.
|
|
214
235
|
|
|
215
236
|
Args:
|
|
216
|
-
|
|
217
|
-
visibility (bool): Whether or not the organization should be visible.
|
|
237
|
+
visibility (bool): Whether or not the dataset should be publicly visible.
|
|
218
238
|
|
|
219
239
|
Returns:
|
|
220
|
-
bool: True if the
|
|
240
|
+
bool: True if the dataset visibility was set successfully, False otherwise.
|
|
221
241
|
"""
|
|
222
242
|
params = [Visibility(visibility=visibility)]
|
|
223
243
|
|
|
@@ -230,69 +250,32 @@ def set_org_visibility(visibility: bool) -> bool:
|
|
|
230
250
|
|
|
231
251
|
return status.success
|
|
232
252
|
|
|
233
|
-
def set_org(org: str) -> None:
|
|
234
|
-
"""
|
|
235
|
-
DEPRECATED. Use [`use_datasets`][vanna.use_datasets] instead.
|
|
236
|
-
|
|
237
|
-
Args:
|
|
238
|
-
org (str): The organization name.
|
|
239
|
-
"""
|
|
240
|
-
global __org
|
|
241
|
-
print("vn.set_org is deprecated. Please use vn.use_datasets instead.")
|
|
242
|
-
warnings.warn("vn.set_org is deprecated. Please use vn.use_datasets instead.", DeprecationWarning)
|
|
243
|
-
|
|
244
|
-
my_orgs = list_orgs()
|
|
245
|
-
if org not in my_orgs:
|
|
246
|
-
# Check if org exists
|
|
247
|
-
d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)])
|
|
248
|
-
|
|
249
|
-
if 'result' not in d:
|
|
250
|
-
raise Exception("Failed to check if organization exists")
|
|
251
|
-
|
|
252
|
-
status = Status(**d['result'])
|
|
253
|
-
|
|
254
|
-
if status.success:
|
|
255
|
-
raise Exception(f"An organization with the name {org} already exists")
|
|
256
|
-
|
|
257
|
-
create = input(f"Would you like to create organization '{org}'? (y/n): ")
|
|
258
|
-
|
|
259
|
-
if create.lower() == 'y':
|
|
260
|
-
db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ")
|
|
261
|
-
__org = 'demo-tpc-h'
|
|
262
|
-
if create_org(org=org, db_type=db_type):
|
|
263
|
-
__org = org
|
|
264
|
-
else:
|
|
265
|
-
__org = None
|
|
266
|
-
raise Exception("Failed to create organization")
|
|
267
|
-
else:
|
|
268
|
-
__org = org
|
|
269
|
-
|
|
270
253
|
def _set_org(org: str) -> None:
|
|
271
254
|
global __org
|
|
272
255
|
|
|
273
|
-
my_orgs =
|
|
256
|
+
my_orgs = list_datasets()
|
|
274
257
|
if org not in my_orgs:
|
|
275
258
|
# Check if org exists
|
|
276
259
|
d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)])
|
|
277
260
|
|
|
278
261
|
if 'result' not in d:
|
|
279
|
-
raise Exception("Failed to check if
|
|
262
|
+
raise Exception("Failed to check if dataset exists")
|
|
280
263
|
|
|
281
264
|
status = Status(**d['result'])
|
|
282
265
|
|
|
283
266
|
if status.success:
|
|
284
267
|
raise Exception(f"An organization with the name {org} already exists")
|
|
285
268
|
|
|
286
|
-
create = input(f"Would you like to create
|
|
269
|
+
create = input(f"Would you like to create dataset '{org}'? (y/n): ")
|
|
287
270
|
|
|
288
271
|
if create.lower() == 'y':
|
|
289
272
|
db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ")
|
|
290
273
|
__org = 'demo-tpc-h'
|
|
291
|
-
if
|
|
274
|
+
if create_dataset(dataset=org, db_type=db_type):
|
|
292
275
|
__org = org
|
|
293
276
|
else:
|
|
294
277
|
__org = None
|
|
295
|
-
raise Exception("Failed to create
|
|
278
|
+
raise Exception("Failed to create dataset")
|
|
296
279
|
else:
|
|
297
280
|
__org = org
|
|
298
281
|
|
|
@@ -379,7 +362,7 @@ def store_documentation(documentation: str) -> bool:
|
|
|
379
362
|
## Example
|
|
380
363
|
```python
|
|
381
364
|
vn.store_documentation(
|
|
382
|
-
documentation="
|
|
365
|
+
documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold."
|
|
383
366
|
)
|
|
384
367
|
```
|
|
385
368
|
|
|
@@ -423,7 +406,7 @@ def flag_sql_for_review(question: str, sql: Union[str, None] = None, error_msg:
|
|
|
423
406
|
```python
|
|
424
407
|
vn.flag_sql_for_review(question="What is the average salary of employees?")
|
|
425
408
|
```
|
|
426
|
-
Flag a question and its corresponding SQL query for review. You can
|
|
409
|
+
Flag a question and its corresponding SQL query for review. You can see the tag show up in [`vn.get_all_questions()`][vanna.get_all_questions]
|
|
427
410
|
|
|
428
411
|
Args:
|
|
429
412
|
question (str): The question to flag.
|
|
@@ -527,6 +510,41 @@ def generate_sql(question: str) -> str:
|
|
|
527
510
|
|
|
528
511
|
return sql_answer.sql
|
|
529
512
|
|
|
513
|
+
def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]:
|
|
514
|
+
"""
|
|
515
|
+
## Example
|
|
516
|
+
```python
|
|
517
|
+
vn.generate_followup_questions(question="What is the average salary of employees?", df=df)
|
|
518
|
+
# ['What is the average salary of employees in the Sales department?', 'What is the average salary of employees in the Engineering department?', ...]
|
|
519
|
+
```
|
|
520
|
+
|
|
521
|
+
Generate follow-up questions using the Vanna.AI API.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
question (str): The question to generate follow-up questions for.
|
|
525
|
+
df (pd.DataFrame): The DataFrame to generate follow-up questions for.
|
|
526
|
+
|
|
527
|
+
Returns:
|
|
528
|
+
List[str] or None: The follow-up questions, or None if an error occurred.
|
|
529
|
+
"""
|
|
530
|
+
params = [DataResult(
|
|
531
|
+
question=question,
|
|
532
|
+
sql=None,
|
|
533
|
+
table_markdown=df.head().to_markdown(),
|
|
534
|
+
error=None,
|
|
535
|
+
correction_attempts=0,
|
|
536
|
+
)]
|
|
537
|
+
|
|
538
|
+
d = __rpc_call(method="generate_followup_questions", params=params)
|
|
539
|
+
|
|
540
|
+
if 'result' not in d:
|
|
541
|
+
return None
|
|
542
|
+
|
|
543
|
+
# Load the result into a dataclass
|
|
544
|
+
question_string_list = QuestionStringList(**d['result'])
|
|
545
|
+
|
|
546
|
+
return question_string_list.questions
|
|
547
|
+
|
|
530
548
|
def generate_questions() -> List[str]:
|
|
531
549
|
"""
|
|
532
550
|
## Example
|
|
@@ -591,15 +609,15 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
|
|
|
591
609
|
if print_results:
|
|
592
610
|
print(df.head().to_markdown())
|
|
593
611
|
|
|
612
|
+
if len(df) > 0 and auto_train:
|
|
613
|
+
store_sql(question=question, sql=sql, tag="SQL Ran")
|
|
614
|
+
|
|
594
615
|
try:
|
|
595
616
|
plotly_code = generate_plotly_code(question=question, sql=sql, df=df)
|
|
596
617
|
fig = get_plotly_figure(plotly_code=plotly_code, df=df)
|
|
597
618
|
if print_results:
|
|
598
619
|
fig.show()
|
|
599
620
|
|
|
600
|
-
if len(df) > 0 and auto_train:
|
|
601
|
-
store_sql(question=question, sql=sql, tag="Assumed Correct")
|
|
602
|
-
|
|
603
621
|
return sql, df, fig
|
|
604
622
|
|
|
605
623
|
except Exception as e:
|
|
@@ -687,7 +705,7 @@ def get_plotly_figure(plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
|
|
|
687
705
|
|
|
688
706
|
def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
|
|
689
707
|
"""
|
|
690
|
-
DEPRECATED. Use
|
|
708
|
+
DEPRECATED. Use `vn.sql_to_df` instead.
|
|
691
709
|
Run the SQL query and return the results as a pandas dataframe. This is just a helper function that does not use the Vanna.AI API.
|
|
692
710
|
|
|
693
711
|
Args:
|
vanna/types.py
CHANGED
|
@@ -83,7 +83,7 @@ class QuestionCategory:
|
|
|
83
83
|
NO_SQL_GENERATED = "No SQL Generated"
|
|
84
84
|
SQL_UNABLE_TO_RUN = "SQL Unable to Run"
|
|
85
85
|
BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query"
|
|
86
|
-
|
|
86
|
+
SQL_RAN = "SQL Ran Successfully"
|
|
87
87
|
FLAGGED_FOR_REVIEW = "Flagged for Review"
|
|
88
88
|
REVIEWED_AND_APPROVED = "Reviewed and Approved"
|
|
89
89
|
REVIEWED_AND_REJECTED = "Reviewed and Rejected"
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
vanna/__init__.py,sha256=cSDTlVo1vBIz0Svs-b_LMZJSQBvIEDzPW5mNVtUdOzA,22724
|
|
2
|
+
vanna/types.py,sha256=-vkO6_sc3qVhxt4KFY4uQ23rx_kgiKhUH5Ty2VEOAA0,2792
|
|
3
|
+
vanna-0.0.11.dist-info/LICENSE,sha256=4oFm5g_8bkN2Q-Xoo2z3Q-80BmoytRzgeYMSLuXzjoA,1065
|
|
4
|
+
vanna-0.0.11.dist-info/METADATA,sha256=Q5GlYiPK3mRcIXFYMX6Hs3CDcjY68y_nVa8D8InWPMA,4879
|
|
5
|
+
vanna-0.0.11.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
|
6
|
+
vanna-0.0.11.dist-info/top_level.txt,sha256=LA0zKJsqirV2m8gWffKnavodMj1NrHcMDRZcz3CElRs,6
|
|
7
|
+
vanna-0.0.11.dist-info/RECORD,,
|
vanna-0.0.9.dist-info/RECORD
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
vanna/__init__.py,sha256=mmhQchFaSGXl66Y3ZSJCVejajDTaSU-o7FnpMKOhY2s,22459
|
|
2
|
-
vanna/types.py,sha256=5PWlsV6wXbNJng06d17o2MHhKdZddhDoyzuzFTV2AK0,2795
|
|
3
|
-
vanna-0.0.9.dist-info/LICENSE,sha256=4oFm5g_8bkN2Q-Xoo2z3Q-80BmoytRzgeYMSLuXzjoA,1065
|
|
4
|
-
vanna-0.0.9.dist-info/METADATA,sha256=rHJJCakOmtsHp9fzsU-ZJtvByjsaTJYrNbC-TiW_HSo,4878
|
|
5
|
-
vanna-0.0.9.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
|
6
|
-
vanna-0.0.9.dist-info/top_level.txt,sha256=LA0zKJsqirV2m8gWffKnavodMj1NrHcMDRZcz3CElRs,6
|
|
7
|
-
vanna-0.0.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|