vanna 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@ class ZhipuAI_Chat(VannaBase):
40
40
  initial_prompt: str, ddl_list: List[str], max_tokens: int = 14000
41
41
  ) -> str:
42
42
  if len(ddl_list) > 0:
43
- initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
43
+ initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
44
44
 
45
45
  for ddl in ddl_list:
46
46
  if (
@@ -57,7 +57,7 @@ class ZhipuAI_Chat(VannaBase):
57
57
  initial_prompt: str, documentation_List: List[str], max_tokens: int = 14000
58
58
  ) -> str:
59
59
  if len(documentation_List) > 0:
60
- initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
60
+ initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
61
61
 
62
62
  for documentation in documentation_List:
63
63
  if (
@@ -74,7 +74,7 @@ class ZhipuAI_Chat(VannaBase):
74
74
  initial_prompt: str, sql_List: List[str], max_tokens: int = 14000
75
75
  ) -> str:
76
76
  if len(sql_List) > 0:
77
- initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
77
+ initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
78
78
 
79
79
  for question in sql_List:
80
80
  if (
@@ -220,19 +220,14 @@ class ZhipuAI_Chat(VannaBase):
220
220
  if len(prompt) == 0:
221
221
  raise Exception("Prompt is empty")
222
222
 
223
- client = ZhipuAI(api_key=self.api_key) # 填写您自己的APIKey
223
+ client = ZhipuAI(api_key=self.api_key)
224
224
  response = client.chat.completions.create(
225
- model="glm-4", # 填写需要调用的模型名称
225
+ model="glm-4",
226
226
  max_tokens=max_tokens,
227
227
  temperature=temperature,
228
228
  top_p=top_p,
229
229
  stop=stop,
230
230
  messages=prompt,
231
231
  )
232
- # print(prompt)
233
-
234
- # print(response)
235
-
236
- # print(f"Cost {response.usage.total_tokens} token")
237
232
 
238
233
  return response.choices[0].message.content
vanna/base/base.py CHANGED
@@ -124,6 +124,18 @@ class VannaBase(ABC):
124
124
  return self.extract_sql(llm_response)
125
125
 
126
126
  def extract_sql(self, llm_response: str) -> str:
127
+ # If the llm_response is not markdown formatted, extract sql by finding select and ; in the response
128
+ sql = re.search(r"SELECT.*?;", llm_response, re.DOTALL)
129
+ if sql:
130
+ self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}"
131
+ )
132
+ return sql.group(0)
133
+
134
+ # If the llm_response contains a CTE (with clause), extract the sql bewteen WITH and ;
135
+ sql = re.search(r"WITH.*?;", llm_response, re.DOTALL)
136
+ if sql:
137
+ self.log(f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(0)}")
138
+ return sql.group(0)
127
139
  # If the llm_response contains a markdown code block, with or without the sql tag, extract the sql from it
128
140
  sql = re.search(r"```sql\n(.*)```", llm_response, re.DOTALL)
129
141
  if sql:
@@ -147,19 +159,21 @@ class VannaBase(ABC):
147
159
  return False
148
160
 
149
161
  def generate_followup_questions(
150
- self, question: str, sql: str, df: pd.DataFrame, **kwargs
162
+ self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
151
163
  ) -> list:
152
164
  """
153
165
  **Example:**
154
166
  ```python
155
- vn.generate_followup_questions("What are the top 10 customers by sales?", df)
167
+ vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df)
156
168
  ```
157
169
 
158
170
  Generate a list of followup questions that you can ask Vanna.AI.
159
171
 
160
172
  Args:
161
173
  question (str): The question that was asked.
174
+ sql (str): The LLM-generated SQL query.
162
175
  df (pd.DataFrame): The results of the SQL query.
176
+ n_questions (int): Number of follow-up questions to generate.
163
177
 
164
178
  Returns:
165
179
  list: A list of followup questions that you can ask Vanna.AI.
@@ -170,7 +184,7 @@ class VannaBase(ABC):
170
184
  f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
171
185
  ),
172
186
  self.user_message(
173
- "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
187
+ f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
174
188
  ),
175
189
  ]
176
190
 
@@ -361,7 +375,7 @@ class VannaBase(ABC):
361
375
  self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
362
376
  ) -> str:
363
377
  if len(ddl_list) > 0:
364
- initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
378
+ initial_prompt += "\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
365
379
 
366
380
  for ddl in ddl_list:
367
381
  if (
@@ -380,7 +394,7 @@ class VannaBase(ABC):
380
394
  max_tokens: int = 14000,
381
395
  ) -> str:
382
396
  if len(documentation_list) > 0:
383
- initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
397
+ initial_prompt += "\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
384
398
 
385
399
  for documentation in documentation_list:
386
400
  if (
@@ -396,7 +410,7 @@ class VannaBase(ABC):
396
410
  self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
397
411
  ) -> str:
398
412
  if len(sql_list) > 0:
399
- initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
413
+ initial_prompt += "\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
400
414
 
401
415
  for question in sql_list:
402
416
  if (
@@ -888,6 +902,94 @@ class VannaBase(ABC):
888
902
  self.run_sql_is_set = True
889
903
  self.run_sql = run_sql_mysql
890
904
 
905
+ def connect_to_oracle(
906
+ self,
907
+ user: str = None,
908
+ password: str = None,
909
+ dsn: str = None,
910
+ ):
911
+
912
+ """
913
+ Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
914
+ **Example:**
915
+ ```python
916
+ vn.connect_to_oracle(
917
+ user="username",
918
+ password="password",
919
+ dns="host:port/sid",
920
+ )
921
+ ```
922
+ Args:
923
+ USER (str): Oracle db user name.
924
+ PASSWORD (str): Oracle db user password.
925
+ DSN (str): Oracle db host ip - host:port/sid.
926
+ """
927
+
928
+ try:
929
+ import oracledb
930
+ except ImportError:
931
+
932
+ raise DependencyError(
933
+ "You need to install required dependencies to execute this method,"
934
+ " run command: \npip install oracledb"
935
+ )
936
+
937
+ if not dsn:
938
+ dsn = os.getenv("DSN")
939
+
940
+ if not dsn:
941
+ raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid")
942
+
943
+ if not user:
944
+ user = os.getenv("USER")
945
+
946
+ if not user:
947
+ raise ImproperlyConfigured("Please set your Oracle db user")
948
+
949
+ if not password:
950
+ password = os.getenv("PASSWORD")
951
+
952
+ if not password:
953
+ raise ImproperlyConfigured("Please set your Oracle db password")
954
+
955
+ conn = None
956
+
957
+ try:
958
+ conn = oracledb.connect(
959
+ user=user,
960
+ password=password,
961
+ dsn=dsn,
962
+ )
963
+ except oracledb.Error as e:
964
+ raise ValidationError(e)
965
+
966
+ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]:
967
+ if conn:
968
+ try:
969
+ sql = sql.rstrip()
970
+ if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error.
971
+ sql = sql[:-1]
972
+
973
+ cs = conn.cursor()
974
+ cs.execute(sql)
975
+ results = cs.fetchall()
976
+
977
+ # Create a pandas dataframe from the results
978
+ df = pd.DataFrame(
979
+ results, columns=[desc[0] for desc in cs.description]
980
+ )
981
+ return df
982
+
983
+ except oracledb.Error as e:
984
+ conn.rollback()
985
+ raise ValidationError(e)
986
+
987
+ except Exception as e:
988
+ conn.rollback()
989
+ raise e
990
+
991
+ self.run_sql_is_set = True
992
+ self.run_sql = run_sql_oracle
891
993
 
892
994
  def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
893
995
  """
@@ -1236,7 +1338,7 @@ class VannaBase(ABC):
1236
1338
  """
1237
1339
 
1238
1340
  if question and not sql:
1239
- raise ValidationError(f"Please also provide a SQL query")
1341
+ raise ValidationError("Please also provide a SQL query")
1240
1342
 
1241
1343
  if documentation:
1242
1344
  print("Adding documentation....")
@@ -1304,12 +1406,14 @@ class VannaBase(ABC):
1304
1406
  table_column = df.columns[
1305
1407
  df.columns.str.lower().str.contains("table_name")
1306
1408
  ].to_list()[0]
1307
- column_column = df.columns[
1308
- df.columns.str.lower().str.contains("column_name")
1309
- ].to_list()[0]
1310
- data_type_column = df.columns[
1311
- df.columns.str.lower().str.contains("data_type")
1312
- ].to_list()[0]
1409
+ columns = [database_column,
1410
+ schema_column,
1411
+ table_column]
1412
+ candidates = ["column_name",
1413
+ "data_type",
1414
+ "comment"]
1415
+ matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True)
1416
+ columns += df.columns[matches].to_list()
1313
1417
 
1314
1418
  plan = TrainingPlan([])
1315
1419
 
@@ -1330,15 +1434,7 @@ class VannaBase(ABC):
1330
1434
  f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"'
1331
1435
  )
1332
1436
  doc = f"The following columns are in the {table} table in the {database} database:\n\n"
1333
- doc += df_columns_filtered_to_table[
1334
- [
1335
- database_column,
1336
- schema_column,
1337
- table_column,
1338
- column_column,
1339
- data_type_column,
1340
- ]
1341
- ].to_markdown()
1437
+ doc += df_columns_filtered_to_table[columns].to_markdown()
1342
1438
 
1343
1439
  plan._plan.append(
1344
1440
  TrainingPlanItem(
@@ -1,5 +1,4 @@
1
1
  import json
2
- import uuid
3
2
  from typing import List
4
3
 
5
4
  import chromadb
@@ -8,6 +7,7 @@ from chromadb.config import Settings
8
7
  from chromadb.utils import embedding_functions
9
8
 
10
9
  from ..base import VannaBase
10
+ from ..utils import deterministic_uuid
11
11
 
12
12
  default_ef = embedding_functions.DefaultEmbeddingFunction()
13
13
 
@@ -15,17 +15,14 @@ default_ef = embedding_functions.DefaultEmbeddingFunction()
15
15
  class ChromaDB_VectorStore(VannaBase):
16
16
  def __init__(self, config=None):
17
17
  VannaBase.__init__(self, config=config)
18
+ if config is None:
19
+ config = {}
18
20
 
19
- if config is not None:
20
- path = config.get("path", ".")
21
- self.embedding_function = config.get("embedding_function", default_ef)
22
- curr_client = config.get("client", "persistent")
23
- self.n_results = config.get("n_results", 10)
24
- else:
25
- path = "."
26
- self.embedding_function = default_ef
27
- curr_client = "persistent" # defaults to persistent storage
28
- self.n_results = 10 # defaults to 10 documents
21
+ path = config.get("path", ".")
22
+ self.embedding_function = config.get("embedding_function", default_ef)
23
+ curr_client = config.get("client", "persistent")
24
+ collection_metadata = config.get("collection_metadata", None)
25
+ self.n_results = config.get("n_results", 10)
29
26
 
30
27
  if curr_client == "persistent":
31
28
  self.chroma_client = chromadb.PersistentClient(
@@ -42,13 +39,19 @@ class ChromaDB_VectorStore(VannaBase):
42
39
  raise ValueError(f"Unsupported client was set in config: {curr_client}")
43
40
 
44
41
  self.documentation_collection = self.chroma_client.get_or_create_collection(
45
- name="documentation", embedding_function=self.embedding_function
42
+ name="documentation",
43
+ embedding_function=self.embedding_function,
44
+ metadata=collection_metadata,
46
45
  )
47
46
  self.ddl_collection = self.chroma_client.get_or_create_collection(
48
- name="ddl", embedding_function=self.embedding_function
47
+ name="ddl",
48
+ embedding_function=self.embedding_function,
49
+ metadata=collection_metadata,
49
50
  )
50
51
  self.sql_collection = self.chroma_client.get_or_create_collection(
51
- name="sql", embedding_function=self.embedding_function
52
+ name="sql",
53
+ embedding_function=self.embedding_function,
54
+ metadata=collection_metadata,
52
55
  )
53
56
 
54
57
  def generate_embedding(self, data: str, **kwargs) -> List[float]:
@@ -65,7 +68,7 @@ class ChromaDB_VectorStore(VannaBase):
65
68
  },
66
69
  ensure_ascii=False,
67
70
  )
68
- id = str(uuid.uuid4()) + "-sql"
71
+ id = deterministic_uuid(question_sql_json) + "-sql"
69
72
  self.sql_collection.add(
70
73
  documents=question_sql_json,
71
74
  embeddings=self.generate_embedding(question_sql_json),
@@ -75,7 +78,7 @@ class ChromaDB_VectorStore(VannaBase):
75
78
  return id
76
79
 
77
80
  def add_ddl(self, ddl: str, **kwargs) -> str:
78
- id = str(uuid.uuid4()) + "-ddl"
81
+ id = deterministic_uuid(ddl) + "-ddl"
79
82
  self.ddl_collection.add(
80
83
  documents=ddl,
81
84
  embeddings=self.generate_embedding(ddl),
@@ -84,7 +87,7 @@ class ChromaDB_VectorStore(VannaBase):
84
87
  return id
85
88
 
86
89
  def add_documentation(self, documentation: str, **kwargs) -> str:
87
- id = str(uuid.uuid4()) + "-doc"
90
+ id = deterministic_uuid(documentation) + "-doc"
88
91
  self.documentation_collection.add(
89
92
  documents=documentation,
90
93
  embeddings=self.generate_embedding(documentation),
vanna/flask/__init__.py CHANGED
@@ -525,7 +525,7 @@ class VannaFlaskApp:
525
525
  # Proxy the /vanna.svg file to the remote server
526
526
  @self.flask_app.route("/vanna.svg")
527
527
  def proxy_vanna_svg():
528
- remote_url = f"https://vanna.ai/img/vanna.svg"
528
+ remote_url = "https://vanna.ai/img/vanna.svg"
529
529
  response = requests.get(remote_url, stream=True)
530
530
 
531
531
  # Check if the request to the remote URL was successful
@@ -41,7 +41,7 @@ class OpenAI_Chat(VannaBase):
41
41
  if config is None and client is None:
42
42
  self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
43
43
  return
44
-
44
+
45
45
  if "api_key" in config:
46
46
  self.client = OpenAI(api_key=config["api_key"])
47
47
 
@@ -67,7 +67,31 @@ class OpenAI_Chat(VannaBase):
67
67
  for message in prompt:
68
68
  num_tokens += len(message["content"]) / 4
69
69
 
70
- if self.config is not None and "engine" in self.config:
70
+ if kwargs.get("model", None) is not None:
71
+ model = kwargs.get("model", None)
72
+ print(
73
+ f"Using model {model} for {num_tokens} tokens (approx)"
74
+ )
75
+ response = self.client.chat.completions.create(
76
+ model=model,
77
+ messages=prompt,
78
+ max_tokens=self.max_tokens,
79
+ stop=None,
80
+ temperature=self.temperature,
81
+ )
82
+ elif kwargs.get("engine", None) is not None:
83
+ engine = kwargs.get("engine", None)
84
+ print(
85
+ f"Using model {engine} for {num_tokens} tokens (approx)"
86
+ )
87
+ response = self.client.chat.completions.create(
88
+ engine=engine,
89
+ messages=prompt,
90
+ max_tokens=self.max_tokens,
91
+ stop=None,
92
+ temperature=self.temperature,
93
+ )
94
+ elif self.config is not None and "engine" in self.config:
71
95
  print(
72
96
  f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
73
97
  )
vanna/remote.py CHANGED
@@ -34,11 +34,13 @@ from .types import (
34
34
  UserOTP,
35
35
  Visibility,
36
36
  )
37
+ from .vannadb import VannaDB_VectorStore
37
38
 
38
39
 
39
- class VannaDefault(VannaBase):
40
+ class VannaDefault(VannaDB_VectorStore):
40
41
  def __init__(self, model: str, api_key: str, config=None):
41
42
  VannaBase.__init__(self, config=config)
43
+ VannaDB_VectorStore.__init__(self, vanna_model=model, vanna_api_key=api_key, config=config)
42
44
 
43
45
  self._model = model
44
46
  self._api_key = api_key
@@ -48,50 +50,6 @@ class VannaDefault(VannaBase):
48
50
  if config is None or "endpoint" not in config
49
51
  else config["endpoint"]
50
52
  )
51
- self._unauthenticated_endpoint = (
52
- "https://ask.vanna.ai/unauthenticated_rpc"
53
- if config is None or "unauthenticated_endpoint" not in config
54
- else config["unauthenticated_endpoint"]
55
- )
56
-
57
- def _unauthenticated_rpc_call(self, method, params):
58
- headers = {
59
- "Content-Type": "application/json",
60
- }
61
- data = {
62
- "method": method,
63
- "params": [self._dataclass_to_dict(obj) for obj in params],
64
- }
65
-
66
- response = requests.post(
67
- self._unauthenticated_endpoint, headers=headers, data=json.dumps(data)
68
- )
69
- return response.json()
70
-
71
- def _rpc_call(self, method, params):
72
- if method != "list_orgs":
73
- headers = {
74
- "Content-Type": "application/json",
75
- "Vanna-Key": self._api_key,
76
- "Vanna-Org": self._model,
77
- }
78
- else:
79
- headers = {
80
- "Content-Type": "application/json",
81
- "Vanna-Key": self._api_key,
82
- "Vanna-Org": "demo-tpc-h",
83
- }
84
-
85
- data = {
86
- "method": method,
87
- "params": [self._dataclass_to_dict(obj) for obj in params],
88
- }
89
-
90
- response = requests.post(self._endpoint, headers=headers, data=json.dumps(data))
91
- return response.json()
92
-
93
- def _dataclass_to_dict(self, obj):
94
- return dataclasses.asdict(obj)
95
53
 
96
54
  def system_message(self, message: str) -> any:
97
55
  return {"role": "system", "content": message}
@@ -102,299 +60,6 @@ class VannaDefault(VannaBase):
102
60
  def assistant_message(self, message: str) -> any:
103
61
  return {"role": "assistant", "content": message}
104
62
 
105
- def get_training_data(self, **kwargs) -> pd.DataFrame:
106
- """
107
- Get the training data for the current model
108
-
109
- **Example:**
110
- ```python
111
- training_data = vn.get_training_data()
112
- ```
113
-
114
- Returns:
115
- pd.DataFrame or None: The training data, or None if an error occurred.
116
-
117
- """
118
- params = []
119
-
120
- d = self._rpc_call(method="get_training_data", params=params)
121
-
122
- if "result" not in d:
123
- return None
124
-
125
- # Load the result into a dataclass
126
- training_data = DataFrameJSON(**d["result"])
127
-
128
- df = pd.read_json(StringIO(training_data.data))
129
-
130
- return df
131
-
132
- def remove_training_data(self, id: str, **kwargs) -> bool:
133
- """
134
- Remove training data from the model
135
-
136
- **Example:**
137
- ```python
138
- vn.remove_training_data(id="1-ddl")
139
- ```
140
-
141
- Args:
142
- id (str): The ID of the training data to remove.
143
- """
144
- params = [StringData(data=id)]
145
-
146
- d = self._rpc_call(method="remove_training_data", params=params)
147
-
148
- if "result" not in d:
149
- raise Exception(f"Error removing training data")
150
-
151
- status = Status(**d["result"])
152
-
153
- if not status.success:
154
- raise Exception(f"Error removing training data: {status.message}")
155
-
156
- return status.success
157
-
158
- def generate_questions(self) -> list[str]:
159
- """
160
- **Example:**
161
- ```python
162
- vn.generate_questions()
163
- # ['What is the average salary of employees?', 'What is the total salary of employees?', ...]
164
- ```
165
-
166
- Generate questions using the Vanna.AI API.
167
-
168
- Returns:
169
- List[str] or None: The questions, or None if an error occurred.
170
- """
171
- d = self._rpc_call(method="generate_questions", params=[])
172
-
173
- if "result" not in d:
174
- return None
175
-
176
- # Load the result into a dataclass
177
- question_string_list = QuestionStringList(**d["result"])
178
-
179
- return question_string_list.questions
180
-
181
- def add_ddl(self, ddl: str, **kwargs) -> str:
182
- """
183
- Adds a DDL statement to the model's training data
184
-
185
- **Example:**
186
- ```python
187
- vn.add_ddl(
188
- ddl="CREATE TABLE employees (id INT, name VARCHAR(255), salary INT)"
189
- )
190
- ```
191
-
192
- Args:
193
- ddl (str): The DDL statement to store.
194
-
195
- Returns:
196
- str: The ID of the DDL statement.
197
- """
198
- params = [StringData(data=ddl)]
199
-
200
- d = self._rpc_call(method="add_ddl", params=params)
201
-
202
- if "result" not in d:
203
- raise Exception("Error adding DDL", d)
204
-
205
- status = StatusWithId(**d["result"])
206
-
207
- return status.id
208
-
209
- def add_documentation(self, documentation: str, **kwargs) -> str:
210
- """
211
- Adds documentation to the model's training data
212
-
213
- **Example:**
214
- ```python
215
- vn.add_documentation(
216
- documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold."
217
- )
218
- ```
219
-
220
- Args:
221
- documentation (str): The documentation string to store.
222
-
223
- Returns:
224
- str: The ID of the documentation string.
225
- """
226
- params = [StringData(data=documentation)]
227
-
228
- d = self._rpc_call(method="add_documentation", params=params)
229
-
230
- if "result" not in d:
231
- raise Exception("Error adding documentation", d)
232
-
233
- status = StatusWithId(**d["result"])
234
-
235
- return status.id
236
-
237
- def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
238
- """
239
- Adds a question and its corresponding SQL query to the model's training data. The preferred way to call this is to use [`vn.train(sql=...)`][vanna.train].
240
-
241
- **Example:**
242
- ```python
243
- vn.add_sql(
244
- question="What is the average salary of employees?",
245
- sql="SELECT AVG(salary) FROM employees"
246
- )
247
- ```
248
-
249
- Args:
250
- question (str): The question to store.
251
- sql (str): The SQL query to store.
252
- tag (Union[str, None]): A tag to associate with the question and SQL query.
253
-
254
- Returns:
255
- str: The ID of the question and SQL query.
256
- """
257
- if "tag" in kwargs:
258
- tag = kwargs["tag"]
259
- else:
260
- tag = "Manually Trained"
261
-
262
- params = [QuestionSQLPair(question=question, sql=sql, tag=tag)]
263
-
264
- d = self._rpc_call(method="add_sql", params=params)
265
-
266
- if "result" not in d:
267
- raise Exception("Error adding question and SQL pair", d)
268
-
269
- status = StatusWithId(**d["result"])
270
-
271
- return status.id
272
-
273
- def generate_embedding(self, data: str, **kwargs) -> list[float]:
274
- """
275
- Not necessary for remote models as embeddings are generated on the server side.
276
- """
277
- pass
278
-
279
- def generate_plotly_code(
280
- self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
281
- ) -> str:
282
- """
283
- **Example:**
284
- ```python
285
- vn.generate_plotly_code(
286
- question="What is the average salary of employees?",
287
- sql="SELECT AVG(salary) FROM employees",
288
- df_metadata=df.dtypes
289
- )
290
- # fig = px.bar(df, x="name", y="salary")
291
- ```
292
- Generate Plotly code using the Vanna.AI API.
293
-
294
- Args:
295
- question (str): The question to generate Plotly code for.
296
- sql (str): The SQL query to generate Plotly code for.
297
- df (pd.DataFrame): The dataframe to generate Plotly code for.
298
- chart_instructions (str): Optional instructions for how to plot the chart.
299
-
300
- Returns:
301
- str or None: The Plotly code, or None if an error occurred.
302
- """
303
- if kwargs is not None and "chart_instructions" in kwargs:
304
- if question is not None:
305
- question = (
306
- question
307
- + " -- When plotting, follow these instructions: "
308
- + kwargs["chart_instructions"]
309
- )
310
- else:
311
- question = (
312
- "When plotting, follow these instructions: "
313
- + kwargs["chart_instructions"]
314
- )
315
-
316
- params = [
317
- DataResult(
318
- question=question,
319
- sql=sql,
320
- table_markdown=df_metadata,
321
- error=None,
322
- correction_attempts=0,
323
- )
324
- ]
325
-
326
- d = self._rpc_call(method="generate_plotly_code", params=params)
327
-
328
- if "result" not in d:
329
- return None
330
-
331
- # Load the result into a dataclass
332
- plotly_code = PlotlyResult(**d["result"])
333
-
334
- return plotly_code.plotly_code
335
-
336
- def generate_question(self, sql: str, **kwargs) -> str:
337
- """
338
-
339
- **Example:**
340
- ```python
341
- vn.generate_question(sql="SELECT * FROM students WHERE name = 'John Doe'")
342
- # 'What is the name of the student?'
343
- ```
344
-
345
- Generate a question from an SQL query using the Vanna.AI API.
346
-
347
- Args:
348
- sql (str): The SQL query to generate a question for.
349
-
350
- Returns:
351
- str or None: The question, or None if an error occurred.
352
-
353
- """
354
- params = [
355
- SQLAnswer(
356
- raw_answer="",
357
- prefix="",
358
- postfix="",
359
- sql=sql,
360
- )
361
- ]
362
-
363
- d = self._rpc_call(method="generate_question", params=params)
364
-
365
- if "result" not in d:
366
- return None
367
-
368
- # Load the result into a dataclass
369
- question = Question(**d["result"])
370
-
371
- return question.question
372
-
373
- def get_sql_prompt(
374
- self,
375
- question: str,
376
- question_sql_list: list,
377
- ddl_list: list,
378
- doc_list: list,
379
- **kwargs,
380
- ):
381
- """
382
- Not necessary for remote models as prompts are generated on the server side.
383
- """
384
-
385
- def get_followup_questions_prompt(
386
- self,
387
- question: str,
388
- df: pd.DataFrame,
389
- question_sql_list: list,
390
- ddl_list: list,
391
- doc_list: list,
392
- **kwargs,
393
- ):
394
- """
395
- Not necessary for remote models as prompts are generated on the server side.
396
- """
397
-
398
63
  def submit_prompt(self, prompt, **kwargs) -> str:
399
64
  # JSON-ify the prompt
400
65
  json_prompt = json.dumps(prompt)
@@ -410,46 +75,3 @@ class VannaDefault(VannaBase):
410
75
  results = StringData(**d["result"])
411
76
 
412
77
  return results.data
413
-
414
- def get_similar_question_sql(self, question: str, **kwargs) -> list:
415
- """
416
- Not necessary for remote models as similar questions are generated on the server side.
417
- """
418
-
419
- def get_related_ddl(self, question: str, **kwargs) -> list:
420
- """
421
- Not necessary for remote models as related DDL statements are generated on the server side.
422
- """
423
-
424
- def get_related_documentation(self, question: str, **kwargs) -> list:
425
- """
426
- Not necessary for remote models as related documentation is generated on the server side.
427
- """
428
-
429
- def generate_sql(self, question: str, **kwargs) -> str:
430
- """
431
- **Example:**
432
- ```python
433
- vn.generate_sql_from_question(question="What is the average salary of employees?")
434
- # SELECT AVG(salary) FROM employees
435
- ```
436
-
437
- Generate an SQL query using the Vanna.AI API.
438
-
439
- Args:
440
- question (str): The question to generate an SQL query for.
441
-
442
- Returns:
443
- str or None: The SQL query, or None if an error occurred.
444
- """
445
- params = [Question(question=question)]
446
-
447
- d = self._rpc_call(method="generate_sql_from_question", params=params)
448
-
449
- if "result" not in d:
450
- return None
451
-
452
- # Load the result into a dataclass
453
- sql_answer = SQLAnswer(**d["result"])
454
-
455
- return sql_answer.sql
vanna/utils.py CHANGED
@@ -1,5 +1,8 @@
1
+ import hashlib
1
2
  import os
2
3
  import re
4
+ import uuid
5
+ from typing import Union
3
6
 
4
7
  from .exceptions import ImproperlyConfigured, ValidationError
5
8
 
@@ -48,3 +51,27 @@ def sanitize_model_name(model_name):
48
51
  return model_name
49
52
  except Exception as e:
50
53
  raise ValidationError(e)
54
+
55
+
56
+ def deterministic_uuid(content: Union[str, bytes]) -> str:
57
+ """Creates deterministic UUID on hash value of string or byte content.
58
+
59
+ Args:
60
+ content: String or byte representation of data.
61
+
62
+ Returns:
63
+ UUID of the content.
64
+ """
65
+ if isinstance(content, str):
66
+ content_bytes = content.encode("utf-8")
67
+ elif isinstance(content, bytes):
68
+ content_bytes = content
69
+ else:
70
+ raise ValueError(f"Content type {type(content)} not supported !")
71
+
72
+ hash_object = hashlib.sha256(content_bytes)
73
+ hash_hex = hash_object.hexdigest()
74
+ namespace = uuid.UUID("00000000-0000-0000-0000-000000000000")
75
+ content_uuid = str(uuid.uuid5(namespace, hash_hex))
76
+
77
+ return content_uuid
@@ -7,14 +7,17 @@ import requests
7
7
 
8
8
  from ..base import VannaBase
9
9
  from ..types import (
10
- DataFrameJSON,
11
- Question,
12
- QuestionSQLPair,
13
- Status,
14
- StatusWithId,
15
- StringData,
16
- TrainingData,
10
+ DataFrameJSON,
11
+ NewOrganization,
12
+ OrganizationList,
13
+ Question,
14
+ QuestionSQLPair,
15
+ Status,
16
+ StatusWithId,
17
+ StringData,
18
+ TrainingData,
17
19
  )
20
+ from ..utils import sanitize_model_name
18
21
 
19
22
 
20
23
  class VannaDB_VectorStore(VannaBase):
@@ -29,27 +32,8 @@ class VannaDB_VectorStore(VannaBase):
29
32
  if config is None or "endpoint" not in config
30
33
  else config["endpoint"]
31
34
  )
32
- self._unauthenticated_endpoint = (
33
- "https://ask.vanna.ai/unauthenticated_rpc"
34
- if config is None or "unauthenticated_endpoint" not in config
35
- else config["unauthenticated_endpoint"]
36
- )
37
35
  self.related_training_data = {}
38
36
 
39
- def _unauthenticated_rpc_call(self, method, params):
40
- headers = {
41
- "Content-Type": "application/json",
42
- }
43
- data = {
44
- "method": method,
45
- "params": [self._dataclass_to_dict(obj) for obj in params],
46
- }
47
-
48
- response = requests.post(
49
- self._unauthenticated_endpoint, headers=headers, data=json.dumps(data)
50
- )
51
- return response.json()
52
-
53
37
  def _rpc_call(self, method, params):
54
38
  if method != "list_orgs":
55
39
  headers = {
@@ -75,6 +59,53 @@ class VannaDB_VectorStore(VannaBase):
75
59
  def _dataclass_to_dict(self, obj):
76
60
  return dataclasses.asdict(obj)
77
61
 
62
+ def create_model(self, model: str, **kwargs) -> bool:
63
+ """
64
+ **Example:**
65
+ ```python
66
+ success = vn.create_model("my_model")
67
+ ```
68
+ Create a new model.
69
+
70
+ Args:
71
+ model (str): The name of the model to create.
72
+
73
+ Returns:
74
+ bool: True if the model was created, False otherwise.
75
+ """
76
+ model = sanitize_model_name(model)
77
+ params = [NewOrganization(org_name=model, db_type="")]
78
+
79
+ d = self._rpc_call(method="create_org", params=params)
80
+
81
+ if "result" not in d:
82
+ return False
83
+
84
+ status = Status(**d["result"])
85
+
86
+ return status.success
87
+
88
+ def get_models(self) -> list:
89
+ """
90
+ **Example:**
91
+ ```python
92
+ models = vn.get_models()
93
+ ```
94
+
95
+ List the models that belong to the user.
96
+
97
+ Returns:
98
+ List[str]: A list of model names.
99
+ """
100
+ d = self._rpc_call(method="list_my_models", params=[])
101
+
102
+ if "result" not in d:
103
+ return []
104
+
105
+ orgs = OrganizationList(**d["result"])
106
+
107
+ return orgs.organizations
108
+
78
109
  def generate_embedding(self, data: str, **kwargs) -> list[float]:
79
110
  # This is done server-side
80
111
  pass
@@ -141,7 +172,7 @@ class VannaDB_VectorStore(VannaBase):
141
172
  d = self._rpc_call(method="remove_training_data", params=params)
142
173
 
143
174
  if "result" not in d:
144
- raise Exception(f"Error removing training data")
175
+ raise Exception("Error removing training data")
145
176
 
146
177
  status = Status(**d["result"])
147
178
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -26,6 +26,8 @@ Requires-Dist: openai ; extra == "all"
26
26
  Requires-Dist: mistralai ; extra == "all"
27
27
  Requires-Dist: chromadb ; extra == "all"
28
28
  Requires-Dist: anthropic ; extra == "all"
29
+ Requires-Dist: zhipuai ; extra == "all"
30
+ Requires-Dist: marqo ; extra == "all"
29
31
  Requires-Dist: anthropic ; extra == "anthropic"
30
32
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
31
33
  Requires-Dist: chromadb ; extra == "chromadb"
@@ -39,6 +41,7 @@ Requires-Dist: psycopg2-binary ; extra == "postgres"
39
41
  Requires-Dist: db-dtypes ; extra == "postgres"
40
42
  Requires-Dist: snowflake-connector-python ; extra == "snowflake"
41
43
  Requires-Dist: tox ; extra == "test"
44
+ Requires-Dist: zhipuai ; extra == "zhipuai"
42
45
  Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
43
46
  Project-URL: Homepage, https://github.com/vanna-ai/vanna
44
47
  Provides-Extra: all
@@ -54,6 +57,7 @@ Provides-Extra: openai
54
57
  Provides-Extra: postgres
55
58
  Provides-Extra: snowflake
56
59
  Provides-Extra: test
60
+ Provides-Extra: zhipuai
57
61
 
58
62
 
59
63
 
@@ -1,18 +1,18 @@
1
1
  vanna/__init__.py,sha256=4zz2kSkVZenjwJQg-ETWsIVYdz3gio275i9DMa_aHxM,9248
2
2
  vanna/local.py,sha256=U5s8ybCRQhBUizi8I69o3jqOpTeu_6KGYY6DMwZxjG4,313
3
- vanna/remote.py,sha256=qqaTA4l-ikVH_1aQInae4DyTAfCCmFokQme6Jq_i1us,12803
4
- vanna/utils.py,sha256=Q0H4eugPYg9SVpEoTWgvmuoJZZxOVRhNzrP97E5lyak,1472
5
- vanna/ZhipuAI/ZhipuAI_Chat.py,sha256=C2zkK9R_HZnhOTN_lAvjiFnHSH2s_lFNfu-FEjU2ycs,8906
3
+ vanna/remote.py,sha256=CcScbeEkqYlzBZJMeMLDlLgmaFcG0QSUQUZybFI3Y28,1856
4
+ vanna/utils.py,sha256=cs0B_0MwhmPI2nWjVHifDYCmCR0kkddylQ2vloaPDSw,2247
5
+ vanna/ZhipuAI/ZhipuAI_Chat.py,sha256=WtZKUBIwlNH0BGbb4lZbVR7pTWIrn7b4RLIk-7u0SuQ,8725
6
6
  vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oERsa8pE,2849
7
7
  vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
8
8
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
9
9
  vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
10
10
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
11
- vanna/base/base.py,sha256=25tHTMsCnoek6X9C0hqpPyNM3H0mihcL4I8Mlq-74c0,54890
11
+ vanna/base/base.py,sha256=89XPWy97YVx6090mNmu1zvn4k8X1pusCuAIypHHexNc,58100
12
12
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
13
- vanna/chromadb/chromadb_vector.py,sha256=fa7uj_knzSfzsVLvpSunwwu1ZJNC3GbiNZ4Yy09v4l4,8372
13
+ vanna/chromadb/chromadb_vector.py,sha256=1n4U4XpXThCFqyJf0zAYVA7mQu9rUkjOFtYn9e04JAo,8461
14
14
  vanna/exceptions/__init__.py,sha256=N76unE7sjbGGBz6LmCrPQAugFWr9cUFv8ErJxBrCTts,717
15
- vanna/flask/__init__.py,sha256=UgM0Ce5pGDdadWV6ZEAXj7RXDE1E420DW1wtR-juBMw,21212
15
+ vanna/flask/__init__.py,sha256=tpwpA8596Uyn60FAy7I5oJ81c7kgCB2JG9X044P0_SA,21211
16
16
  vanna/flask/assets.py,sha256=pOOtPV8aWtFsTuxJneFHcfrXhXh6cOSvS-Y8JO2HYrY,180924
17
17
  vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
18
18
  vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
@@ -21,11 +21,11 @@ vanna/mistral/mistral.py,sha256=DAEqAT9SzC91rfMM_S3SuzBZ34MrKHw9qAj6EP2MGVk,1508
21
21
  vanna/ollama/__init__.py,sha256=4xyu8aHPdnEHg5a-QAMwr5o0ns5wevsp_zkI-ndMO2k,27
22
22
  vanna/ollama/ollama.py,sha256=jfW9VQHAcmzDeo4jF3HJjOMYwAWmptknKqEJaQ0MTno,2418
23
23
  vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
24
- vanna/openai/openai_chat.py,sha256=Y3-Fhz9c6D-5vrMR1zGibavAPNPAu-hMTwqGfKhAg3Q,3852
24
+ vanna/openai/openai_chat.py,sha256=lm-hUsQxu6Q1t06A2csC037zI4VkMk0wFbQ-_Lj74Wg,4764
25
25
  vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
26
26
  vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
27
27
  vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
28
- vanna/vannadb/vannadb_vector.py,sha256=f4kddaJgTpZync7wnQi09QdODUuMtiHsK7WfKBUAmSo,5644
29
- vanna-0.3.2.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
30
- vanna-0.3.2.dist-info/METADATA,sha256=Suj6A5Wqxfyrs7d91R6g7y2lCusWtDUSctv8pUaueqg,9961
31
- vanna-0.3.2.dist-info/RECORD,,
28
+ vanna/vannadb/vannadb_vector.py,sha256=9YwTO3Lh5owWQE7KPMBqLp2EkiGV0RC1sEYhslzJzgI,6168
29
+ vanna-0.3.4.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
30
+ vanna-0.3.4.dist-info/METADATA,sha256=FEg4vs5ZiSAvd5YkF5oEfFqod9n3UoNfi51Q_2WKotA,10107
31
+ vanna-0.3.4.dist-info/RECORD,,
File without changes