vanna 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/base/base.py CHANGED
@@ -182,7 +182,7 @@ class VannaBase(ABC):
          """

          # If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
-         sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)
+         sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
          if sqls:
              sql = sqls[-1]
              self.log(title="Extracted SQL", message=f"{sql}")
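The tightened pattern only matches a standalone WITH keyword followed by a space, so prose or identifiers that merely contain "WITH" are no longer mistaken for a CTE. A minimal sketch of the behavioural difference (the sample llm_response string is invented for illustration):

import re

llm_response = "Note: this runs WITHIN one transaction; SELECT id FROM users;"

re.findall(r"WITH.*?;", llm_response, re.DOTALL)
# -> ['WITHIN one transaction;']  (false positive: there is no CTE here)

re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
# -> []  (no match, so the `if sqls:` branch above is skipped)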
vanna/bedrock/__init__.py ADDED
@@ -0,0 +1 @@
+ from .bedrock_converse import Bedrock_Converse
vanna/bedrock/bedrock_converse.py ADDED
@@ -0,0 +1,85 @@
+ from ..base import VannaBase
+
+ try:
+     import boto3
+     from botocore.exceptions import ClientError
+ except ImportError:
+     raise ImportError("Please install boto3 and botocore to use Amazon Bedrock models")
+
+ class Bedrock_Converse(VannaBase):
+     def __init__(self, client=None, config=None):
+         VannaBase.__init__(self, config=config)
+
+         # default parameters
+         self.temperature = 0.0
+         self.max_tokens = 500
+
+         if client is None:
+             raise ValueError(
+                 "A valid Bedrock runtime client must be provided to invoke Bedrock models"
+             )
+         else:
+             self.client = client
+
+         if config is None:
+             raise ValueError(
+                 "Config is required with model_id and inference parameters"
+             )
+
+         if "modelId" not in config:
+             raise ValueError(
+                 "config must contain a modelId to invoke"
+             )
+         else:
+             self.model = config["modelId"]
+
+         if "temperature" in config:
+             self.temperature = config["temperature"]
+
+         if "max_tokens" in config:
+             self.max_tokens = config["max_tokens"]
+
+     def system_message(self, message: str) -> dict:
+         return {"role": "system", "content": message}
+
+     def user_message(self, message: str) -> dict:
+         return {"role": "user", "content": message}
+
+     def assistant_message(self, message: str) -> dict:
+         return {"role": "assistant", "content": message}
+
+     def submit_prompt(self, prompt, **kwargs) -> str:
+         inference_config = {
+             "temperature": self.temperature,
+             "maxTokens": self.max_tokens
+         }
+         additional_model_fields = {
+             "top_p": 1,  # setting top_p value for nucleus sampling
+         }
+
+         system_message = None
+         no_system_prompt = []
+         for prompt_message in prompt:
+             role = prompt_message["role"]
+             if role == "system":
+                 system_message = prompt_message["content"]
+             else:
+                 no_system_prompt.append({"role": role, "content": [{"text": prompt_message["content"]}]})
+
+         converse_api_params = {
+             "modelId": self.model,
+             "messages": no_system_prompt,
+             "inferenceConfig": inference_config,
+             "additionalModelRequestFields": additional_model_fields
+         }
+
+         if system_message:
+             converse_api_params["system"] = [{"text": system_message}]
+
+         try:
+             response = self.client.converse(**converse_api_params)
+             text_content = response["output"]["message"]["content"][0]["text"]
+             return text_content
+         except ClientError as err:
+             message = err.response["Error"]["Message"]
+             raise Exception(f"A Bedrock client error occurred: {message}")
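The new class maps Vanna's prompt format onto the Bedrock Converse API: any "system" message is passed in the top-level system field, while user/assistant turns go into messages. A minimal usage sketch, assuming the usual Vanna pattern of mixing an LLM class with a vector store; the ChromaDB_VectorStore pairing, the MyVanna name, the region, and the Anthropic model id are illustrative assumptions, not prescribed by this diff:

import boto3

from vanna.bedrock import Bedrock_Converse
from vanna.chromadb import ChromaDB_VectorStore  # assumption: any Vanna vector store could stand in here

class MyVanna(ChromaDB_VectorStore, Bedrock_Converse):
    def __init__(self, client=None, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Bedrock_Converse.__init__(self, client=client, config=config)

# Bedrock_Converse requires an explicit Bedrock runtime client and a "modelId" in config.
bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")  # example region
vn = MyVanna(
    client=bedrock_client,
    config={
        "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",  # example model id, substitute your own
        "temperature": 0.0,  # optional, defaults to 0.0
        "max_tokens": 500,   # optional, defaults to 500
    },
)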
vanna/google/gemini_chat.py CHANGED
@@ -7,7 +7,7 @@ class GoogleGeminiChat(VannaBase):
          VannaBase.__init__(self, config=config)

          # default temperature - can be overrided using config
-         self.temperature = 0.7
+         self.temperature = 0.7

          if "temperature" in config:
              self.temperature = config["temperature"]
@@ -31,7 +31,7 @@ class GoogleGeminiChat(VannaBase):
          else:
              # Authenticate using VertexAI
              from vertexai.preview.generative_models import GenerativeModel
-             self.chat_model = GenerativeModel("gemini-pro")
+             self.chat_model = GenerativeModel(model_name)

      def system_message(self, message: str) -> any:
          return message
vanna/hf/hf.py CHANGED
@@ -6,13 +6,15 @@ from ..base import VannaBase

  class Hf(VannaBase):
      def __init__(self, config=None):
-         model_name = self.config.get(
-             "model_name", None
-         )  # e.g. meta-llama/Meta-Llama-3-8B-Instruct
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model_name_or_path = self.config.get(
+             "model_name_or_path", None
+         )  # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
+         # list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
+         quantization_config = self.config.get("quantization_config", None)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
          self.model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype="auto",
+             model_name_or_path,
+             quantization_config=quantization_config,
              device_map="auto",
          )
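The renamed model_name_or_path key now also accepts a local checkpoint directory, and the new quantization_config value is forwarded verbatim to AutoModelForCausalLM.from_pretrained. A hedged sketch of how this might be used with 4-bit bitsandbytes quantization; the ChromaDB_VectorStore pairing, the MyVanna name, and the model id are illustrative assumptions:

from transformers import BitsAndBytesConfig

from vanna.chromadb import ChromaDB_VectorStore  # assumption: any Vanna vector store works here
from vanna.hf import Hf

class MyVanna(ChromaDB_VectorStore, Hf):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)  # sets self.config, which Hf.__init__ reads
        Hf.__init__(self, config=config)

vn = MyVanna(config={
    # Hub id or a local path to checkpoint files, matching the renamed key above
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    # Optional: any quantization config supported by transformers, e.g. 4-bit NF4 (needs bitsandbytes and a GPU)
    "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
})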
vanna/weaviate/__init__.py ADDED
@@ -0,0 +1 @@
+ from .weaviate_vector import WeaviateDatabase
vanna/weaviate/weaviate_vector.py ADDED
@@ -0,0 +1,174 @@
+ import weaviate
+ import weaviate.classes as wvc
+ from fastembed import TextEmbedding
+
+ from vanna.base import VannaBase
+
+
+ class WeaviateDatabase(VannaBase):
+
+     def __init__(self, config=None):
+         """
+         Initialize the VannaEnhanced class with the provided configuration.
+
+         :param config: Dictionary containing configuration parameters.
+
+         params:
+             weaviate_url (str): Weaviate cluster URL while using weaviate cloud,
+             weaviate_api_key (str): Weaviate API key while using weaviate cloud,
+             weaviate_port (num): Weaviate port while using local weaviate,
+             weaviate_grpc (num): Weaviate gRPC port while using local weaviate,
+             fastembed_model (str): Fastembed model name for text embeddings. BAAI/bge-small-en-v1.5 by default.
+
+         """
+         super().__init__(config=config)
+
+         if config is None:
+             raise ValueError("config is required")
+
+         self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
+         self.weaviate_api_key = config.get("weaviate_api_key")
+         self.weaviate_url = config.get("weaviate_url")
+         self.weaviate_port = config.get("weaviate_port")
+         self.weaviate_grpc_port = config.get("weaviate_grpc", 50051)
+
+         if not self.weaviate_api_key and not self.weaviate_port:
+             raise ValueError("Add proper credentials to connect to weaviate")
+
+         self.weaviate_client = self._initialize_weaviate_client()
+         self.embeddings = TextEmbedding(model_name=self.fastembed_model)
+
+         self.training_data_cluster = {
+             "sql": "SQLTrainingDataEntry",
+             "ddl": "DDLEntry",
+             "doc": "DocumentationEntry"
+         }
+
+         self._create_collections_if_not_exist()
+
+     def _create_collections_if_not_exist(self):
+         properties_dict = {
+             self.training_data_cluster['ddl']: [
+                 wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
+             ],
+             self.training_data_cluster['doc']: [
+                 wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
+             ],
+             self.training_data_cluster['sql']: [
+                 wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT),
+                 wvc.config.Property(name="natural_language_question", data_type=wvc.config.DataType.TEXT),
+             ]
+         }
+
+         for cluster, properties in properties_dict.items():
+             if not self.weaviate_client.collections.exists(cluster):
+                 self.weaviate_client.collections.create(
+                     name=cluster,
+                     properties=properties
+                 )
+
+     def _initialize_weaviate_client(self):
+         if self.weaviate_api_key:
+             return weaviate.connect_to_wcs(
+                 cluster_url=self.weaviate_url,
+                 auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key),
+                 additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
+                 skip_init_checks=True
+             )
+         else:
+             return weaviate.connect_to_local(
+                 port=self.weaviate_port,
+                 grpc_port=self.weaviate_grpc_port,
+                 additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
+                 skip_init_checks=True
+             )
+
+     def generate_embedding(self, data: str, **kwargs):
+         embedding_model = TextEmbedding(model_name=self.fastembed_model)
+         embedding = next(embedding_model.embed(data))
+         return embedding.tolist()
+
+
+     def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str:
+         self.weaviate_client.connect()
+         response = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]).data.insert(
+             properties=data_object,
+             vector=vector
+         )
+         self.weaviate_client.close()
+         return response
+
+     def add_ddl(self, ddl: str, **kwargs) -> str:
+         data_object = {
+             "description": ddl,
+         }
+         response = self._insert_data('ddl', data_object, self.generate_embedding(ddl))
+         return f'{response}-ddl'
+
+     def add_documentation(self, doc: str, **kwargs) -> str:
+         data_object = {
+             "description": doc,
+         }
+         response = self._insert_data('doc', data_object, self.generate_embedding(doc))
+         return f'{response}-doc'
+
+     def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+         data_object = {
+             "sql": sql,
+             "natural_language_question": question,
+         }
+         response = self._insert_data('sql', data_object, self.generate_embedding(question))
+         return f'{response}-sql'
+
+     def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list, limit: int = 3) -> list:
+         self.weaviate_client.connect()
+         collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key])
+         response = collection.query.near_vector(
+             near_vector=vector_input,
+             limit=limit,
+             return_properties=return_properties
+         )
+         response_list = [item.properties for item in response.objects]
+         self.weaviate_client.close()
+         return response_list
+
+     def get_related_ddl(self, question: str, **kwargs) -> list:
+         vector_input = self.generate_embedding(question)
+         response_list = self._query_collection('ddl', vector_input, ["description"])
+         return [item["description"] for item in response_list]
+
+     def get_related_documentation(self, question: str, **kwargs) -> list:
+         vector_input = self.generate_embedding(question)
+         response_list = self._query_collection('doc', vector_input, ["description"])
+         return [item["description"] for item in response_list]
+
+     def get_similar_question_sql(self, question: str, **kwargs) -> list:
+         vector_input = self.generate_embedding(question)
+         response_list = self._query_collection('sql', vector_input, ["sql", "natural_language_question"])
+         return [{"question": item["natural_language_question"], "sql": item["sql"]} for item in response_list]
+
+     def get_training_data(self, **kwargs) -> list:
+         self.weaviate_client.connect()
+         combined_response_list = []
+         for collection_name in self.training_data_cluster.values():
+             if self.weaviate_client.collections.exists(collection_name):
+                 collection = self.weaviate_client.collections.get(collection_name)
+                 response_list = [item.properties for item in collection.iterator()]
+                 combined_response_list.extend(response_list)
+         self.weaviate_client.close()
+         return combined_response_list
+
+     def remove_training_data(self, id: str, **kwargs) -> bool:
+         self.weaviate_client.connect()
+         success = False
+         if id.endswith("-sql"):
+             id = id.replace('-sql', '')
+             success = self.weaviate_client.collections.get(self.training_data_cluster['sql']).data.delete_by_id(id)
+         elif id.endswith("-ddl"):
+             id = id.replace('-ddl', '')
+             success = self.weaviate_client.collections.get(self.training_data_cluster['ddl']).data.delete_by_id(id)
+         elif id.endswith("-doc"):
+             id = id.replace('-doc', '')
+             success = self.weaviate_client.collections.get(self.training_data_cluster['doc']).data.delete_by_id(id)
+         self.weaviate_client.close()
+         return success
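One possible way to plug the new store into a Vanna model, again assuming the vector-store + LLM mixin pattern. The weaviate_* and fastembed_model keys come from the docstring above; the OpenAI_Chat pairing and its api_key/model settings are illustrative assumptions, and the values shown are placeholders:

from vanna.openai import OpenAI_Chat  # assumption: any Vanna LLM class works here
from vanna.weaviate import WeaviateDatabase

class MyVanna(WeaviateDatabase, OpenAI_Chat):
    def __init__(self, config=None):
        WeaviateDatabase.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

# Local Weaviate shown here; for Weaviate Cloud pass weaviate_url and weaviate_api_key instead.
vn = MyVanna(config={
    "weaviate_port": 8080,
    "weaviate_grpc": 50051,                       # optional, defaults to 50051
    "fastembed_model": "BAAI/bge-small-en-v1.5",  # optional, this is the default
    "api_key": "sk-...",                          # hypothetical OpenAI_Chat settings
    "model": "gpt-4o-mini",
})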
vanna-0.6.2.dist-info/METADATA → vanna-0.6.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vanna
- Version: 0.6.2
+ Version: 0.6.3
  Summary: Generate SQL queries from natural language
  Author-email: Zain Hoda <zain@vanna.ai>
  Requires-Python: >=3.9
@@ -40,7 +40,10 @@ Requires-Dist: opensearch-dsl ; extra == "all"
  Requires-Dist: transformers ; extra == "all"
  Requires-Dist: pinecone-client ; extra == "all"
  Requires-Dist: pymilvus[model] ; extra == "all"
+ Requires-Dist: weaviate-client ; extra == "all"
  Requires-Dist: anthropic ; extra == "anthropic"
+ Requires-Dist: boto3 ; extra == "bedrock"
+ Requires-Dist: botocore ; extra == "bedrock"
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
  Requires-Dist: chromadb ; extra == "chromadb"
  Requires-Dist: clickhouse_connect ; extra == "clickhouse"
@@ -67,11 +70,13 @@ Requires-Dist: fastembed ; extra == "qdrant"
  Requires-Dist: snowflake-connector-python ; extra == "snowflake"
  Requires-Dist: tox ; extra == "test"
  Requires-Dist: vllm ; extra == "vllm"
+ Requires-Dist: weaviate-client ; extra == "weaviate"
  Requires-Dist: zhipuai ; extra == "zhipuai"
  Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
  Project-URL: Homepage, https://github.com/vanna-ai/vanna
  Provides-Extra: all
  Provides-Extra: anthropic
+ Provides-Extra: bedrock
  Provides-Extra: bigquery
  Provides-Extra: chromadb
  Provides-Extra: clickhouse
@@ -92,6 +97,7 @@ Provides-Extra: qdrant
  Provides-Extra: snowflake
  Provides-Extra: test
  Provides-Extra: vllm
+ Provides-Extra: weaviate
  Provides-Extra: zhipuai
vanna-0.6.2.dist-info/RECORD → vanna-0.6.3.dist-info/RECORD RENAMED
@@ -9,7 +9,9 @@ vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,65
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
  vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
- vanna/base/base.py,sha256=l1H0TKsK9DN3n5XgDkUckdLois4dTCAUwrVsRa_6SlQ,70988
+ vanna/base/base.py,sha256=Cz5fW0Odg5GaRg2svvl6YPeBULKXKMNODAb3Zc1Q8bU,70993
+ vanna/bedrock/__init__.py,sha256=hRT2bgJbHEqViLdL-t9hfjSfFdIOkPU2ADBt-B1En-8,46
+ vanna/bedrock/bedrock_converse.py,sha256=Nx5kYm-diAfYmsWAnTP5xnv7V84Og69-AP9b3seIe0E,2869
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
  vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
  vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
@@ -17,9 +19,9 @@ vanna/flask/__init__.py,sha256=urPrHUqM1mpx96VHiQWVXCy3NQwDh6OsSkm4V4wqccY,30211
  vanna/flask/assets.py,sha256=_UoUr57sS0QL2BuTxAOe9k4yy8T7-fp2NpbRSVtW3IM,451769
  vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
  vanna/google/__init__.py,sha256=M-dCxCZcKL4bTQyMLj6r6VRs65YNX9Tl2aoPCuqGm-8,41
- vanna/google/gemini_chat.py,sha256=ps3A-afFbCo3HeFTLL_nMoQO1PsGvRUUPRUppbMcDew,1584
+ vanna/google/gemini_chat.py,sha256=j1szC2PamMLFrs0Z4lYPS69i017FYICe-mNObNYFBPQ,1576
  vanna/hf/__init__.py,sha256=vD0bIhfLkA1UsvVSF4MAz3Da8aQunkQo3wlDztmMuj0,19
- vanna/hf/hf.py,sha256=v1v6sZnbj5xcrjgmvLP_ytS9NM7E5d0GyMfXXtr6BMU,2703
+ vanna/hf/hf.py,sha256=N8N5g3xvKDBt3dez2r_U0qATxbl2pN8SVLTZK9CSRA0,3020
  vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
  vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
  vanna/milvus/__init__.py,sha256=VBasJG2eTKbJI6CEand7kPLNBrqYrn0QCAhSYVz814s,46
@@ -46,6 +48,8 @@ vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
  vanna/vannadb/vannadb_vector.py,sha256=N8poMYvAojoaOF5gI4STD5pZWK9lBKPvyIjbh9dPBa0,14189
  vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
  vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
- vanna-0.6.2.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
- vanna-0.6.2.dist-info/METADATA,sha256=RVle66HeuhBS8iaO0vD8_iDQqk9NbeO1pZOgCgKwh54,11628
- vanna-0.6.2.dist-info/RECORD,,
+ vanna/weaviate/__init__.py,sha256=HL6PAl7ePBAkeG8uln-BmM7IUtWohyTPvDfcPzSGSCg,46
+ vanna/weaviate/weaviate_vector.py,sha256=GEiu4Vd9w-7j10aB-zTxJ8gefqe_F-LUUGvttFs1vlg,7539
+ vanna-0.6.3.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
+ vanna-0.6.3.dist-info/METADATA,sha256=-CFXJd2aKy2tq2Gxipp-52YCPHJBlI0wIOPzcOV5s80,11865
+ vanna-0.6.3.dist-info/RECORD,,