vanna 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/__init__.py CHANGED
@@ -17,11 +17,31 @@ from typing import List, Dict, Any, Union, Optional, Callable, Tuple
17
17
  import warnings
18
18
  import traceback
19
19
 
20
- """Set the API key for Vanna.AI."""
21
20
  api_key: Union[str, None] = None # API key for Vanna.AI
21
+ """
22
+ ## Example
23
+ ```python
24
+ # Login to Vanna.AI
25
+ vn.login('user@example.com')
26
+ print(vn.api_key)
27
+
28
+ vn.api_key='my_api_key'
29
+ ```
30
+
31
+ This is the API key for Vanna.AI. You can set it manually if you already have one, or use [`vn.login(...)`][vanna.login] to log in and have it set automatically.
32
+
33
+ """
22
34
 
23
- """Set the SQL to DataFrame function for Vanna.AI."""
24
35
  sql_to_df: Union[Callable[[str], pd.DataFrame], None] = None # Function to convert SQL to a Pandas DataFrame
36
+ """
37
+ ## Example
38
+ ```python
39
+ vn.sql_to_df = lambda sql: pd.read_sql(sql, engine)
40
+ ```
41
+
42
+ Set the SQL to DataFrame function for Vanna.AI. This is used in the [`vn.ask(...)`][vanna.ask] function.
43
+
44
+ """
25
45
 
26
46
  __org: Union[str, None] = None # Organization name for Vanna.AI
27
47
 
@@ -45,10 +65,10 @@ def __rpc_call(method, params):
45
65
  global __org
46
66
 
47
67
  if api_key is None:
48
- raise Exception("API key not set")
68
+ raise Exception("API key not set. Use vn.login(...) to login.")
49
69
 
50
70
  if __org is None and method != "list_orgs":
51
- raise Exception("Organization name not set")
71
+ raise Exception("Datasets not set. Use vn.use_datasets([...]) to set the datasets to use.")
52
72
 
53
73
  if method != "list_orgs":
54
74
  headers = {
@@ -124,17 +144,17 @@ def login(email: str, otp_code: Union[str, None] = None) -> bool:
124
144
 
125
145
  return True
126
146
 
127
- def list_orgs() -> List[str]:
147
+ def list_datasets() -> List[str]:
128
148
  """
129
149
  ## Example
130
150
  ```python
131
- orgs = vn.list_orgs()
151
+ datasets = vn.list_datasets()
132
152
  ```
133
153
 
134
- List the organizations that the user is a member of.
154
+ List the datasets that the user is a member of.
135
155
 
136
156
  Returns:
137
- List[str]: A list of organization names.
157
+ List[str]: A list of dataset names.
138
158
  """
139
159
  d = __rpc_call(method="list_orgs", params=[])
140
160
 
@@ -145,23 +165,23 @@ def list_orgs() -> List[str]:
145
165
 
146
166
  return orgs.organizations
147
167
 
148
- def create_org(org: str, db_type: str) -> bool:
168
+ def create_dataset(dataset: str, db_type: str) -> bool:
149
169
  """
150
170
  ## Example
151
171
  ```python
152
- vn.create_org(org="my-org", db_type="postgres")
172
+ vn.create_dataset(dataset="my-dataset", db_type="postgres")
153
173
  ```
154
174
 
155
- Create a new organization.
175
+ Create a new dataset.
156
176
 
157
177
  Args:
158
- org (str): The name of the organization to create.
159
- db_type (str): The type of database to use for the organization. This can be "Snowflake", "BigQuery", "Postgres", or anything else.
178
+ dataset (str): The name of the dataset to create.
179
+ db_type (str): The type of database to use for the dataset. This can be "Snowflake", "BigQuery", "Postgres", or anything else.
160
180
 
161
181
  Returns:
162
- bool: True if the organization was created successfully, False otherwise.
182
+ bool: True if the dataset was created successfully, False otherwise.
163
183
  """
164
- params = [NewOrganization(org_name=org, db_type=db_type)]
184
+ params = [NewOrganization(org_name=dataset, db_type=db_type)]
165
185
 
166
186
  d = __rpc_call(method="create_org", params=params)
167
187
 
@@ -172,24 +192,25 @@ def create_org(org: str, db_type: str) -> bool:
172
192
 
173
193
  return status.success
174
194
 
175
- def add_user_to_org(org: str, email: str, is_admin: bool) -> bool:
195
+ def add_user_to_dataset(dataset: str, email: str, is_admin: bool) -> bool:
176
196
  """
177
197
  ## Example
178
198
  ```python
179
- vn.add_user_to_org(org="my-org", email="user@example.com")
199
+ vn.add_user_to_dataset(dataset="my-dataset", email="user@example.com", is_admin=False)
180
200
  ```
181
201
 
182
- Add a user to an organization.
202
+ Add a user to a dataset.
183
203
 
184
204
  Args:
185
- org (str): The name of the organization to add the user to.
205
+ dataset (str): The name of the dataset to add the user to.
186
206
  email (str): The email address of the user to add.
207
+ is_admin (bool): Whether or not the user should be an admin.
187
208
 
188
209
  Returns:
189
210
  bool: True if the user was added successfully, False otherwise.
190
211
  """
191
212
 
192
- params = [NewOrganizationMember(org_name=org, email=email, is_admin=is_admin)]
213
+ params = [NewOrganizationMember(org_name=dataset, email=email, is_admin=is_admin)]
193
214
 
194
215
  d = __rpc_call(method="add_user_to_org", params=params)
195
216
 
@@ -203,21 +224,20 @@ def add_user_to_org(org: str, email: str, is_admin: bool) -> bool:
203
224
 
204
225
  return status.success
205
226
 
206
- def set_org_visibility(visibility: bool) -> bool:
227
+ def set_dataset_visibility(visibility: bool) -> bool:
207
228
  """
208
229
  ## Example
209
230
  ```python
210
- vn.set_org_visibility(org="my-org", visibility=True)
231
+ vn.set_dataset_visibility(visibility=True)
211
232
  ```
212
233
 
213
- Set the visibility of an organization. If an organization is visible, anyone can see it. If it is not visible, only members of the organization can see it.
234
+ Set the visibility of the current dataset. If a dataset is visible, anyone can see it. If it is not visible, only members of the dataset can see it.
214
235
 
215
236
  Args:
216
- org (str): The name of the organization to set the visibility of.
217
- visibility (bool): Whether or not the organization should be visible.
237
+ visibility (bool): Whether or not the dataset should be publicly visible.
218
238
 
219
239
  Returns:
220
- bool: True if the organization visibility was set successfully, False otherwise.
240
+ bool: True if the dataset visibility was set successfully, False otherwise.
221
241
  """
222
242
  params = [Visibility(visibility=visibility)]
223
243
 
@@ -230,69 +250,32 @@ def set_org_visibility(visibility: bool) -> bool:
230
250
 
231
251
  return status.success
232
252
 
233
- def set_org(org: str) -> None:
234
- """
235
- DEPRECATED. Use [`use_datasets`][vanna.use_datasets] instead.
236
-
237
- Args:
238
- org (str): The organization name.
239
- """
240
- global __org
241
- print("vn.set_org is deprecated. Please use vn.use_datasets instead.")
242
- warnings.warn("vn.set_org is deprecated. Please use vn.use_datasets instead.", DeprecationWarning)
243
-
244
- my_orgs = list_orgs()
245
- if org not in my_orgs:
246
- # Check if org exists
247
- d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)])
248
-
249
- if 'result' not in d:
250
- raise Exception("Failed to check if organization exists")
251
-
252
- status = Status(**d['result'])
253
-
254
- if status.success:
255
- raise Exception(f"An organization with the name {org} already exists")
256
-
257
- create = input(f"Would you like to create organization '{org}'? (y/n): ")
258
-
259
- if create.lower() == 'y':
260
- db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ")
261
- __org = 'demo-tpc-h'
262
- if create_org(org=org, db_type=db_type):
263
- __org = org
264
- else:
265
- __org = None
266
- raise Exception("Failed to create organization")
267
- else:
268
- __org = org
269
-
270
253
  def _set_org(org: str) -> None:
271
254
  global __org
272
255
 
273
- my_orgs = list_orgs()
256
+ my_orgs = list_datasets()
274
257
  if org not in my_orgs:
275
258
  # Check if org exists
276
259
  d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)])
277
260
 
278
261
  if 'result' not in d:
279
- raise Exception("Failed to check if organization exists")
262
+ raise Exception("Failed to check if dataset exists")
280
263
 
281
264
  status = Status(**d['result'])
282
265
 
283
266
  if status.success:
284
267
  raise Exception(f"An organization with the name {org} already exists")
285
268
 
286
- create = input(f"Would you like to create organization '{org}'? (y/n): ")
269
+ create = input(f"Would you like to create dataset '{org}'? (y/n): ")
287
270
 
288
271
  if create.lower() == 'y':
289
272
  db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ")
290
273
  __org = 'demo-tpc-h'
291
- if create_org(org=org, db_type=db_type):
274
+ if create_dataset(dataset=org, db_type=db_type):
292
275
  __org = org
293
276
  else:
294
277
  __org = None
295
- raise Exception("Failed to create organization")
278
+ raise Exception("Failed to create dataset")
296
279
  else:
297
280
  __org = org
298
281
 
@@ -379,7 +362,7 @@ def store_documentation(documentation: str) -> bool:
379
362
  ## Example
380
363
  ```python
381
364
  vn.store_documentation(
382
- documentation="This is a documentation string for the employees table."
365
+ documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold."
383
366
  )
384
367
  ```
385
368
 
@@ -423,7 +406,7 @@ def flag_sql_for_review(question: str, sql: Union[str, None] = None, error_msg:
423
406
  ```python
424
407
  vn.flag_sql_for_review(question="What is the average salary of employees?")
425
408
  ```
426
- Flag a question and its corresponding SQL query for review. You can later retrieve the flagged questions using [`get_flagged_questions()`][vanna.get_flagged_questions].
409
+ Flag a question and its corresponding SQL query for review. The resulting tag is visible in [`vn.get_all_questions()`][vanna.get_all_questions].
427
410
 
428
411
  Args:
429
412
  question (str): The question to flag.
@@ -527,6 +510,41 @@ def generate_sql(question: str) -> str:
527
510
 
528
511
  return sql_answer.sql
529
512
 
513
+ def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]:
514
+ """
515
+ ## Example
516
+ ```python
517
+ vn.generate_followup_questions(question="What is the average salary of employees?", df=df)
518
+ # ['What is the average salary of employees in the Sales department?', 'What is the average salary of employees in the Engineering department?', ...]
519
+ ```
520
+
521
+ Generate follow-up questions using the Vanna.AI API.
522
+
523
+ Args:
524
+ question (str): The question to generate follow-up questions for.
525
+ df (pd.DataFrame): The DataFrame to generate follow-up questions for.
526
+
527
+ Returns:
528
+ List[str] or None: The follow-up questions, or None if an error occurred.
529
+ """
530
+ params = [DataResult(
531
+ question=question,
532
+ sql=None,
533
+ table_markdown=df.head().to_markdown(),
534
+ error=None,
535
+ correction_attempts=0,
536
+ )]
537
+
538
+ d = __rpc_call(method="generate_followup_questions", params=params)
539
+
540
+ if 'result' not in d:
541
+ return None
542
+
543
+ # Load the result into a dataclass
544
+ question_string_list = QuestionStringList(**d['result'])
545
+
546
+ return question_string_list.questions
547
+
530
548
  def generate_questions() -> List[str]:
531
549
  """
532
550
  ## Example
@@ -591,15 +609,15 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
591
609
  if print_results:
592
610
  print(df.head().to_markdown())
593
611
 
612
+ if len(df) > 0 and auto_train:
613
+ store_sql(question=question, sql=sql, tag="SQL Ran")
614
+
594
615
  try:
595
616
  plotly_code = generate_plotly_code(question=question, sql=sql, df=df)
596
617
  fig = get_plotly_figure(plotly_code=plotly_code, df=df)
597
618
  if print_results:
598
619
  fig.show()
599
620
 
600
- if len(df) > 0 and auto_train:
601
- store_sql(question=question, sql=sql, tag="Assumed Correct")
602
-
603
621
  return sql, df, fig
604
622
 
605
623
  except Exception as e:
@@ -687,7 +705,7 @@ def get_plotly_figure(plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
687
705
 
688
706
  def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
689
707
  """
690
- DEPRECATED. Use [`vanna.sql_to_df()`][vanna.sql_to_df] instead.
708
+ DEPRECATED. Use `vn.sql_to_df` instead.
691
709
  Run the SQL query and return the results as a pandas dataframe. This is just a helper function that does not use the Vanna.AI API.
692
710
 
693
711
  Args:
vanna/types.py CHANGED
@@ -83,7 +83,7 @@ class QuestionCategory:
83
83
  NO_SQL_GENERATED = "No SQL Generated"
84
84
  SQL_UNABLE_TO_RUN = "SQL Unable to Run"
85
85
  BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query"
86
- ASSUMED_CORRECT = "Assumed Correct"
86
+ SQL_RAN = "SQL Ran Successfully"
87
87
  FLAGGED_FOR_REVIEW = "Flagged for Review"
88
88
  REVIEWED_AND_APPROVED = "Reviewed and Approved"
89
89
  REVIEWED_AND_REJECTED = "Reviewed and Rejected"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.0.9
3
+ Version: 0.0.11
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Project-URL: Homepage, https://github.com/vanna-ai/vanna-py
@@ -0,0 +1,7 @@
1
+ vanna/__init__.py,sha256=cSDTlVo1vBIz0Svs-b_LMZJSQBvIEDzPW5mNVtUdOzA,22724
2
+ vanna/types.py,sha256=-vkO6_sc3qVhxt4KFY4uQ23rx_kgiKhUH5Ty2VEOAA0,2792
3
+ vanna-0.0.11.dist-info/LICENSE,sha256=4oFm5g_8bkN2Q-Xoo2z3Q-80BmoytRzgeYMSLuXzjoA,1065
4
+ vanna-0.0.11.dist-info/METADATA,sha256=Q5GlYiPK3mRcIXFYMX6Hs3CDcjY68y_nVa8D8InWPMA,4879
5
+ vanna-0.0.11.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
6
+ vanna-0.0.11.dist-info/top_level.txt,sha256=LA0zKJsqirV2m8gWffKnavodMj1NrHcMDRZcz3CElRs,6
7
+ vanna-0.0.11.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- vanna/__init__.py,sha256=mmhQchFaSGXl66Y3ZSJCVejajDTaSU-o7FnpMKOhY2s,22459
2
- vanna/types.py,sha256=5PWlsV6wXbNJng06d17o2MHhKdZddhDoyzuzFTV2AK0,2795
3
- vanna-0.0.9.dist-info/LICENSE,sha256=4oFm5g_8bkN2Q-Xoo2z3Q-80BmoytRzgeYMSLuXzjoA,1065
4
- vanna-0.0.9.dist-info/METADATA,sha256=rHJJCakOmtsHp9fzsU-ZJtvByjsaTJYrNbC-TiW_HSo,4878
5
- vanna-0.0.9.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
6
- vanna-0.0.9.dist-info/top_level.txt,sha256=LA0zKJsqirV2m8gWffKnavodMj1NrHcMDRZcz3CElRs,6
7
- vanna-0.0.9.dist-info/RECORD,,
File without changes