vanna 0.7.3__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/google/bigquery_vector.py CHANGED
@@ -2,6 +2,10 @@ import datetime
  import os
  import uuid
  from typing import List, Optional
+ from vertexai.language_models import (
+     TextEmbeddingInput,
+     TextEmbeddingModel
+ )

  import pandas as pd
  from google.cloud import bigquery
@@ -23,17 +27,15 @@ class BigQuery_VectorStore(VannaBase):
              or set as an environment variable, assign it.
              """
              print("Configuring genai")
+             self.type = "GEMINI"
              import google.generativeai as genai

              genai.configure(api_key=config["api_key"])

              self.genai = genai
          else:
+             self.type = "VERTEX_AI"
              # Authenticate using VertexAI
-             from vertexai.language_models import (
-                 TextEmbeddingInput,
-                 TextEmbeddingModel,
-             )

              if self.config.get("project_id"):
                  self.project_id = self.config.get("project_id")
@@ -139,25 +141,42 @@ class BigQuery_VectorStore(VannaBase):
          results = self.conn.query(query).result().to_dataframe()
          return results

-     def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
-         result = self.genai.embed_content(
+     def get_embeddings(self, data: str, task: str) -> List[float]:
+         embeddings = None
+
+         if self.type == "VERTEX_AI":
+             input = [TextEmbeddingInput(data, task)]
+             model = TextEmbeddingModel.from_pretrained("text-embedding-004")
+
+             result = model.get_embeddings(input)
+
+             if len(result) > 0:
+                 embeddings = result[0].values
+         else:
+             # Use Gemini Consumer API
+             result = self.genai.embed_content(
              model="models/text-embedding-004",
              content=data,
-             task_type="retrieval_query")
+             task_type=task)

-         if 'embedding' in result:
-             return result['embedding']
+             if 'embedding' in result:
+                 embeddings = result['embedding']
+
+         return embeddings
+
+     def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
+         result = self.get_embeddings(data, "RETRIEVAL_QUERY")
+
+         if result != None:
+             return result
          else:
              raise ValueError("No embeddings returned")

      def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
-         result = self.genai.embed_content(
-             model="models/text-embedding-004",
-             content=data,
-             task_type="retrieval_document")
+         result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")

-         if 'embedding' in result:
-             return result['embedding']
+         if result != None:
+             return result
          else:
              raise ValueError("No embeddings returned")

vanna/google/gemini_chat.py CHANGED
@@ -15,7 +15,7 @@ class GoogleGeminiChat(VannaBase):
          if "model_name" in config:
              model_name = config["model_name"]
          else:
-             model_name = "gemini-1.0-pro"
+             model_name = "gemini-1.5-pro"

          self.google_api_key = None

@@ -30,7 +30,7 @@ class GoogleGeminiChat(VannaBase):
              self.chat_model = genai.GenerativeModel(model_name)
          else:
              # Authenticate using VertexAI
-             from vertexai.preview.generative_models import GenerativeModel
+             from vertexai.generative_models import GenerativeModel
              self.chat_model = GenerativeModel(model_name)

      def system_message(self, message: str) -> any:
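Note on this hunk: the default chat model moves from gemini-1.0-pro to gemini-1.5-pro, and the Vertex AI GenerativeModel import leaves the preview namespace. A hedged sketch of pinning the model explicitly so the new default does not change behavior on upgrade (the import path is assumed from the package layout; the key is a placeholder):

    from vanna.google import GoogleGeminiChat  # assumed export

    chat = GoogleGeminiChat(config={
        "api_key": "YOUR_GOOGLE_API_KEY",  # placeholder
        "model_name": "gemini-1.0-pro",    # any explicit name bypasses the new default
    })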
vanna/pgvector/__init__.py ADDED
@@ -0,0 +1 @@
+ from .pgvector import PG_VectorStore
vanna/pgvector/pgvector.py ADDED
@@ -0,0 +1,253 @@
+ import ast
+ import json
+ import logging
+ import uuid
+
+ import pandas as pd
+ from langchain_core.documents import Document
+ from langchain_postgres.vectorstores import PGVector
+ from sqlalchemy import create_engine, text
+
+ from .. import ValidationError
+ from ..base import VannaBase
+ from ..types import TrainingPlan, TrainingPlanItem
+
+
+ class PG_VectorStore(VannaBase):
+     def __init__(self, config=None):
+         if not config or "connection_string" not in config:
+             raise ValueError(
+                 "A valid 'config' dictionary with a 'connection_string' is required.")
+
+         VannaBase.__init__(self, config=config)
+
+         if config and "connection_string" in config:
+             self.connection_string = config.get("connection_string")
+             self.n_results = config.get("n_results", 10)
+
+         if config and "embedding_function" in config:
+             self.embedding_function = config.get("embedding_function")
+         else:
+             from langchain_huggingface import HuggingFaceEmbeddings
+             self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+         self.sql_collection = PGVector(
+             embeddings=self.embedding_function,
+             collection_name="sql",
+             connection=self.connection_string,
+         )
+         self.ddl_collection = PGVector(
+             embeddings=self.embedding_function,
+             collection_name="ddl",
+             connection=self.connection_string,
+         )
+         self.documentation_collection = PGVector(
+             embeddings=self.embedding_function,
+             collection_name="documentation",
+             connection=self.connection_string,
+         )
+
+     def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+         question_sql_json = json.dumps(
+             {
+                 "question": question,
+                 "sql": sql,
+             },
+             ensure_ascii=False,
+         )
+         id = str(uuid.uuid4()) + "-sql"
+         createdat = kwargs.get("createdat")
+         doc = Document(
+             page_content=question_sql_json,
+             metadata={"id": id, "createdat": createdat},
+         )
+         self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
+
+         return id
+
+     def add_ddl(self, ddl: str, **kwargs) -> str:
+         _id = str(uuid.uuid4()) + "-ddl"
+         doc = Document(
+             page_content=ddl,
+             metadata={"id": _id},
+         )
+         self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
+         return _id
+
+     def add_documentation(self, documentation: str, **kwargs) -> str:
+         _id = str(uuid.uuid4()) + "-doc"
+         doc = Document(
+             page_content=documentation,
+             metadata={"id": _id},
+         )
+         self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
+         return _id
+
+     def get_collection(self, collection_name):
+         match collection_name:
+             case "sql":
+                 return self.sql_collection
+             case "ddl":
+                 return self.ddl_collection
+             case "documentation":
+                 return self.documentation_collection
+             case _:
+                 raise ValueError("Specified collection does not exist.")
+
+     def get_similar_question_sql(self, question: str) -> list:
+         documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
+         return [ast.literal_eval(document.page_content) for document in documents]
+
+     def get_related_ddl(self, question: str, **kwargs) -> list:
+         documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
+         return [document.page_content for document in documents]
+
+     def get_related_documentation(self, question: str, **kwargs) -> list:
+         documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
+         return [document.page_content for document in documents]
+
+     def train(
+         self,
+         question: str | None = None,
+         sql: str | None = None,
+         ddl: str | None = None,
+         documentation: str | None = None,
+         plan: TrainingPlan | None = None,
+         createdat: str | None = None,
+     ):
+         if question and not sql:
+             raise ValidationError("Please provide a SQL query.")
+
+         if documentation:
+             logging.info(f"Adding documentation: {documentation}")
+             return self.add_documentation(documentation)
+
+         if sql and question:
+             return self.add_question_sql(question=question, sql=sql, createdat=createdat)
+
+         if ddl:
+             logging.info(f"Adding ddl: {ddl}")
+             return self.add_ddl(ddl)
+
+         if plan:
+             for item in plan._plan:
+                 if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
+                     self.add_ddl(item.item_value)
+                 elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
+                     self.add_documentation(item.item_value)
+                 elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
+                     self.add_question_sql(question=item.item_name, sql=item.item_value)
+
+     def get_training_data(self, **kwargs) -> pd.DataFrame:
+         # Establishing the connection
+         engine = create_engine(self.connection_string)
+
+         # Querying the 'langchain_pg_embedding' table
+         query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
+         df_embedding = pd.read_sql(query_embedding, engine)
+
+         # List to accumulate the processed rows
+         processed_rows = []
+
+         # Process each row in the DataFrame
+         for _, row in df_embedding.iterrows():
+             custom_id = row["cmetadata"]["id"]
+             document = row["document"]
+             training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
+
+             if training_data_type == "sql":
+                 # Convert the document string to a dictionary
+                 try:
+                     doc_dict = ast.literal_eval(document)
+                     question = doc_dict.get("question")
+                     content = doc_dict.get("sql")
+                 except (ValueError, SyntaxError):
+                     logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
+                     continue
+             elif training_data_type in ["documentation", "ddl"]:
+                 question = None  # Default value for question
+                 content = document
+             else:
+                 # If the suffix is not recognized, skip this row
+                 logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
+                 continue
+
+             # Append the processed data to the list
+             processed_rows.append(
+                 {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
+             )
+
+         # Create a DataFrame from the list of processed rows
+         df_processed = pd.DataFrame(processed_rows)
+
+         return df_processed
+
+     def remove_training_data(self, id: str, **kwargs) -> bool:
+         # Create the database engine
+         engine = create_engine(self.connection_string)
+
+         # SQL DELETE statement
+         delete_statement = text(
+             """
+             DELETE FROM langchain_pg_embedding
+             WHERE cmetadata ->> 'id' = :id
+             """
+         )
+
+         # Connect to the database and execute the delete statement
+         with engine.connect() as connection:
+             # Start a transaction
+             with connection.begin() as transaction:
+                 try:
+                     result = connection.execute(delete_statement, {"id": id})
+                     # Commit the transaction if the delete was successful
+                     transaction.commit()
+                     # Check if any row was deleted and return True or False accordingly
+                     return result.rowcount > 0
+                 except Exception as e:
+                     # Rollback the transaction in case of error
+                     logging.error(f"An error occurred: {e}")
+                     transaction.rollback()
+                     return False
+
+     def remove_collection(self, collection_name: str) -> bool:
+         engine = create_engine(self.connection_string)
+
+         # Determine the suffix to look for based on the collection name
+         suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
+         suffix = suffix_map.get(collection_name)
+
+         if not suffix:
+             logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
+             return False
+
+         # SQL query to delete rows based on the condition
+         query = text(
+             f"""
+             DELETE FROM langchain_pg_embedding
+             WHERE cmetadata->>'id' LIKE '%{suffix}'
+             """
+         )
+
+         # Execute the deletion within a transaction block
+         with engine.connect() as connection:
+             with connection.begin() as transaction:
+                 try:
+                     result = connection.execute(query)
+                     transaction.commit()  # Explicitly commit the transaction
+                     if result.rowcount > 0:
+                         logging.info(
+                             f"Deleted {result.rowcount} rows from "
+                             f"langchain_pg_embedding where collection is {collection_name}."
+                         )
+                         return True
+                     else:
+                         logging.info(f"No rows deleted for collection {collection_name}.")
+                         return False
+                 except Exception as e:
+                     logging.error(f"An error occurred: {e}")
+                     transaction.rollback()  # Rollback in case of error
+                     return False
+
+     def generate_embedding(self, *args, **kwargs):
+         pass
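Note on this file: PG_VectorStore keeps three PGVector collections (sql, ddl, documentation) and tags every stored id with a -sql/-ddl/-doc suffix, which get_training_data and remove_collection later key off. A minimal composition sketch, not from the package; it assumes the usual Vanna pattern of mixing a vector store with an LLM class, and the DSN and API key are placeholders:

    from vanna.openai import OpenAI_Chat
    from vanna.pgvector import PG_VectorStore

    class MyVanna(PG_VectorStore, OpenAI_Chat):
        def __init__(self, config=None):
            PG_VectorStore.__init__(self, config=config)
            OpenAI_Chat.__init__(self, config=config)

    vn = MyVanna(config={
        "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",  # placeholder DSN
        "n_results": 5,       # optional; the store defaults to 10
        "api_key": "sk-...",  # placeholder OpenAI key
    })
    vn.train(ddl="CREATE TABLE customers (id INT, name TEXT)")  # stored under a "-ddl" id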
vanna/vllm/vllm.py CHANGED
@@ -22,6 +22,12 @@ class Vllm(VannaBase):
          else:
              self.auth_key = None

+         if "temperature" in config:
+             self.temperature = config["temperature"]
+         else:
+             # default temperature - can be overrided using config
+             self.temperature = 0.7
+
      def system_message(self, message: str) -> any:
          return {"role": "system", "content": message}

@@ -68,6 +74,7 @@ class Vllm(VannaBase):
          url = f"{self.host}/v1/chat/completions"
          data = {
              "model": self.model,
+             "temperature": self.temperature,
              "stream": False,
              "messages": prompt,
          }
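Note on this hunk: the sampling temperature is now read once from config (defaulting to 0.7) and sent with every /v1/chat/completions request. A hedged config sketch; the host and model values are placeholders and the remaining config keys are unchanged from 0.7.3:

    from vanna.vllm import Vllm  # assumed export

    llm = Vllm(config={
        "host": "http://localhost:8000",  # placeholder vLLM server URL
        "model": "my-served-model",       # placeholder served model name
        "temperature": 0.2,               # overrides the new 0.7 default
    })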
vanna-0.7.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vanna
- Version: 0.7.3
+ Version: 0.7.4
  Summary: Generate SQL queries from natural language
  Author-email: Zain Hoda <zain@vanna.ai>
  Requires-Python: >=3.9
@@ -46,6 +46,12 @@ Requires-Dist: weaviate-client ; extra == "all"
  Requires-Dist: azure-search-documents ; extra == "all"
  Requires-Dist: azure-identity ; extra == "all"
  Requires-Dist: azure-common ; extra == "all"
+ Requires-Dist: faiss-cpu ; extra == "all"
+ Requires-Dist: boto ; extra == "all"
+ Requires-Dist: boto3 ; extra == "all"
+ Requires-Dist: botocore ; extra == "all"
+ Requires-Dist: langchain_core ; extra == "all"
+ Requires-Dist: langchain_postgres ; extra == "all"
  Requires-Dist: anthropic ; extra == "anthropic"
  Requires-Dist: azure-search-documents ; extra == "azuresearch"
  Requires-Dist: azure-identity ; extra == "azuresearch"
@@ -57,6 +63,8 @@ Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
  Requires-Dist: chromadb ; extra == "chromadb"
  Requires-Dist: clickhouse_connect ; extra == "clickhouse"
  Requires-Dist: duckdb ; extra == "duckdb"
+ Requires-Dist: faiss-cpu ; extra == "faiss-cpu"
+ Requires-Dist: faiss-gpu ; extra == "faiss-gpu"
  Requires-Dist: google-generativeai ; extra == "gemini"
  Requires-Dist: google-generativeai ; extra == "google"
  Requires-Dist: google-cloud-aiplatform ; extra == "google"
@@ -70,6 +78,7 @@ Requires-Dist: httpx ; extra == "ollama"
  Requires-Dist: openai ; extra == "openai"
  Requires-Dist: opensearch-py ; extra == "opensearch"
  Requires-Dist: opensearch-dsl ; extra == "opensearch"
+ Requires-Dist: langchain-postgres>=0.0.12 ; extra == "pgvector"
  Requires-Dist: pinecone-client ; extra == "pinecone"
  Requires-Dist: fastembed ; extra == "pinecone"
  Requires-Dist: psycopg2-binary ; extra == "postgres"
@@ -92,6 +101,8 @@ Provides-Extra: bigquery
  Provides-Extra: chromadb
  Provides-Extra: clickhouse
  Provides-Extra: duckdb
+ Provides-Extra: faiss-cpu
+ Provides-Extra: faiss-gpu
  Provides-Extra: gemini
  Provides-Extra: google
  Provides-Extra: hf
@@ -102,6 +113,7 @@ Provides-Extra: mysql
  Provides-Extra: ollama
  Provides-Extra: openai
  Provides-Extra: opensearch
+ Provides-Extra: pgvector
  Provides-Extra: pinecone
  Provides-Extra: postgres
  Provides-Extra: qdrant
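Note on this section: the new optional dependencies follow the standard extras mechanism, so they can be pulled in with, for example, pip install "vanna[pgvector]" (which brings in langchain-postgres>=0.0.12) or pip install "vanna[faiss-cpu]"; the faiss, boto, and langchain dependencies are also folded into the catch-all "all" extra.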
vanna-0.7.4.dist-info/RECORD CHANGED
@@ -23,8 +23,8 @@ vanna/flask/__init__.py,sha256=jcdaau1tQ142nL1ZsDklk0ilMkEyRxgQZdmsl1IN4LQ,43866
  vanna/flask/assets.py,sha256=af-vact_5HSftltugBpPxzLkAI14Z0lVWcObyVe6eKE,453462
  vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
  vanna/google/__init__.py,sha256=6D8rDBjKJJm_jpVn9b4Vc2NR-R779ed_bnHhWmxCJXE,92
- vanna/google/bigquery_vector.py,sha256=rkP94Xd1lNYjU1x3MDLvqmGSPUYtDfQwvlqVmX44jyM,8839
- vanna/google/gemini_chat.py,sha256=j1szC2PamMLFrs0Z4lYPS69i017FYICe-mNObNYFBPQ,1576
+ vanna/google/bigquery_vector.py,sha256=mHggjvCsWMt4HK6Y4dAZUPgHi1uytxp2AEQ696TSsJA,9315
+ vanna/google/gemini_chat.py,sha256=9xHvxArxHr7OWXHnDRz7wX7KTGbuy6xXxHMLkhOMkis,1568
  vanna/hf/__init__.py,sha256=vD0bIhfLkA1UsvVSF4MAz3Da8aQunkQo3wlDztmMuj0,19
  vanna/hf/hf.py,sha256=N8N5g3xvKDBt3dez2r_U0qATxbl2pN8SVLTZK9CSRA0,3020
  vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
@@ -44,6 +44,8 @@ vanna/openai/openai_chat.py,sha256=KU6ynOQ5v7vwrQQ13phXoUXeQUrH6_vmhfiPvWddTrQ,4
  vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
  vanna/opensearch/__init__.py,sha256=0unDevWOTs7o8S79TOHUKF1mSiuQbBUVm-7k9jV5WW4,54
  vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8QyVGI0eI,12226
+ vanna/pgvector/__init__.py,sha256=7Wvu9qcNdNvZu26Dn53jhO9YXELm0_YsrwBab4BdgVM,37
+ vanna/pgvector/pgvector.py,sha256=dJfm8rswYZvbaIbnjmyRjL071iw4siE0INibsZtaLXY,9919
  vanna/pinecone/__init__.py,sha256=eO5l8aX8vKL6aIUMgAXGPt1jdqKxB_Hic6cmoVAUrD0,90
  vanna/pinecone/pinecone_vector.py,sha256=mpq1lzo3KRj2QfJEw8pwFclFQK1Oi_Nx-lDkx9Gp0mw,11448
  vanna/qdrant/__init__.py,sha256=PX_OsDOiPMvwCJ2iGER1drSdQ9AyM8iN5PEBhRb6qqY,73
@@ -58,9 +60,9 @@ vanna/types/__init__.py,sha256=Qhn_YscKtJh7mFPCyCDLa2K8a4ORLMGVnPpTbv9uB2U,4957
  vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
  vanna/vannadb/vannadb_vector.py,sha256=N8poMYvAojoaOF5gI4STD5pZWK9lBKPvyIjbh9dPBa0,14189
  vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
- vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
+ vanna/vllm/vllm.py,sha256=oCdEjT2KP7gbZk-N7G9bfxB15OTtbwJHvNAceXx_r8g,3077
  vanna/weaviate/__init__.py,sha256=HL6PAl7ePBAkeG8uln-BmM7IUtWohyTPvDfcPzSGSCg,46
  vanna/weaviate/weaviate_vector.py,sha256=tUJIZjEy2mda8CB6C8zeN2SKkEO-UJdLsIqy69skuF0,7584
- vanna-0.7.3.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
- vanna-0.7.3.dist-info/METADATA,sha256=BOfBtwy1ENcdHApatLWXjqvKj8Zl3bti1hlueVoplR8,12407
- vanna-0.7.3.dist-info/RECORD,,
+ vanna-0.7.4.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
+ vanna-0.7.4.dist-info/METADATA,sha256=rGFflMIIcAqHlmt6g0gvHkgl6nBT-w9Y8jhwy3pRaYo,12900
+ vanna-0.7.4.dist-info/RECORD,,