vanna 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -220,19 +220,14 @@ class ZhipuAI_Chat(VannaBase):
220
220
  if len(prompt) == 0:
221
221
  raise Exception("Prompt is empty")
222
222
 
223
- client = ZhipuAI(api_key=self.api_key) # 填写您自己的APIKey
223
+ client = ZhipuAI(api_key=self.api_key)
224
224
  response = client.chat.completions.create(
225
- model="glm-4", # 填写需要调用的模型名称
225
+ model="glm-4",
226
226
  max_tokens=max_tokens,
227
227
  temperature=temperature,
228
228
  top_p=top_p,
229
229
  stop=stop,
230
230
  messages=prompt,
231
231
  )
232
- # print(prompt)
233
-
234
- # print(response)
235
-
236
- # print(f"Cost {response.usage.total_tokens} token")
237
232
 
238
233
  return response.choices[0].message.content
vanna/base/base.py CHANGED
@@ -147,19 +147,21 @@ class VannaBase(ABC):
147
147
  return False
148
148
 
149
149
  def generate_followup_questions(
150
- self, question: str, sql: str, df: pd.DataFrame, **kwargs
150
+ self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs
151
151
  ) -> list:
152
152
  """
153
153
  **Example:**
154
154
  ```python
155
- vn.generate_followup_questions("What are the top 10 customers by sales?", df)
155
+ vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df)
156
156
  ```
157
157
 
158
158
  Generate a list of followup questions that you can ask Vanna.AI.
159
159
 
160
160
  Args:
161
161
  question (str): The question that was asked.
162
+ sql (str): The LLM-generated SQL query.
162
163
  df (pd.DataFrame): The results of the SQL query.
164
+ n_questions (int): Number of follow-up questions to generate.
163
165
 
164
166
  Returns:
165
167
  list: A list of followup questions that you can ask Vanna.AI.
@@ -170,7 +172,7 @@ class VannaBase(ABC):
170
172
  f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
171
173
  ),
172
174
  self.user_message(
173
- "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
175
+ f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
174
176
  ),
175
177
  ]
176
178
 
@@ -1304,12 +1306,14 @@ class VannaBase(ABC):
1304
1306
  table_column = df.columns[
1305
1307
  df.columns.str.lower().str.contains("table_name")
1306
1308
  ].to_list()[0]
1307
- column_column = df.columns[
1308
- df.columns.str.lower().str.contains("column_name")
1309
- ].to_list()[0]
1310
- data_type_column = df.columns[
1311
- df.columns.str.lower().str.contains("data_type")
1312
- ].to_list()[0]
1309
+ columns = [database_column,
1310
+ schema_column,
1311
+ table_column]
1312
+ candidates = ["column_name",
1313
+ "data_type",
1314
+ "comment"]
1315
+ matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True)
1316
+ columns += df.columns[matches].to_list()
1313
1317
 
1314
1318
  plan = TrainingPlan([])
1315
1319
 
@@ -1330,15 +1334,7 @@ class VannaBase(ABC):
1330
1334
  f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"'
1331
1335
  )
1332
1336
  doc = f"The following columns are in the {table} table in the {database} database:\n\n"
1333
- doc += df_columns_filtered_to_table[
1334
- [
1335
- database_column,
1336
- schema_column,
1337
- table_column,
1338
- column_column,
1339
- data_type_column,
1340
- ]
1341
- ].to_markdown()
1337
+ doc += df_columns_filtered_to_table[columns].to_markdown()
1342
1338
 
1343
1339
  plan._plan.append(
1344
1340
  TrainingPlanItem(
@@ -8,6 +8,7 @@ from chromadb.config import Settings
8
8
  from chromadb.utils import embedding_functions
9
9
 
10
10
  from ..base import VannaBase
11
+ from ..utils import deterministic_uuid
11
12
 
12
13
  default_ef = embedding_functions.DefaultEmbeddingFunction()
13
14
 
@@ -65,7 +66,7 @@ class ChromaDB_VectorStore(VannaBase):
65
66
  },
66
67
  ensure_ascii=False,
67
68
  )
68
- id = str(uuid.uuid4()) + "-sql"
69
+ id = deterministic_uuid(question_sql_json) + "-sql"
69
70
  self.sql_collection.add(
70
71
  documents=question_sql_json,
71
72
  embeddings=self.generate_embedding(question_sql_json),
@@ -75,7 +76,7 @@ class ChromaDB_VectorStore(VannaBase):
75
76
  return id
76
77
 
77
78
  def add_ddl(self, ddl: str, **kwargs) -> str:
78
- id = str(uuid.uuid4()) + "-ddl"
79
+ id = deterministic_uuid(ddl) + "-ddl"
79
80
  self.ddl_collection.add(
80
81
  documents=ddl,
81
82
  embeddings=self.generate_embedding(ddl),
@@ -84,7 +85,7 @@ class ChromaDB_VectorStore(VannaBase):
84
85
  return id
85
86
 
86
87
  def add_documentation(self, documentation: str, **kwargs) -> str:
87
- id = str(uuid.uuid4()) + "-doc"
88
+ id = deterministic_uuid(documentation) + "-doc"
88
89
  self.documentation_collection.add(
89
90
  documents=documentation,
90
91
  embeddings=self.generate_embedding(documentation),