vanna 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,7 @@ class GoogleGeminiChat(VannaBase):
7
7
  VannaBase.__init__(self, config=config)
8
8
 
9
9
  # default temperature - can be overrided using config
10
- self.temperature = 0.7
10
+ self.temperature = 0.7
11
11
 
12
12
  if "temperature" in config:
13
13
  self.temperature = config["temperature"]
@@ -31,7 +31,7 @@ class GoogleGeminiChat(VannaBase):
31
31
  else:
32
32
  # Authenticate using VertexAI
33
33
  from vertexai.preview.generative_models import GenerativeModel
34
- self.chat_model = GenerativeModel("gemini-pro")
34
+ self.chat_model = GenerativeModel(model_name)
35
35
 
36
36
  def system_message(self, message: str) -> any:
37
37
  return message
vanna/hf/hf.py CHANGED
@@ -6,13 +6,15 @@ from ..base import VannaBase
6
6
 
7
7
  class Hf(VannaBase):
8
8
  def __init__(self, config=None):
9
- model_name = self.config.get(
10
- "model_name", None
11
- ) # e.g. meta-llama/Meta-Llama-3-8B-Instruct
12
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model_name_or_path = self.config.get(
10
+ "model_name_or_path", None
11
+ ) # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
12
+ # list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
13
+ quantization_config = self.config.get("quantization_config", None)
14
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
13
15
  self.model = AutoModelForCausalLM.from_pretrained(
14
- model_name,
15
- torch_dtype="auto",
16
+ model_name_or_path,
17
+ quantization_config=quantization_config,
16
18
  device_map="auto",
17
19
  )
18
20
 
@@ -11,14 +11,10 @@ class OpenAI_Chat(VannaBase):
11
11
 
12
12
  # default parameters - can be overrided using config
13
13
  self.temperature = 0.7
14
- self.max_tokens = 500
15
14
 
16
15
  if "temperature" in config:
17
16
  self.temperature = config["temperature"]
18
17
 
19
- if "max_tokens" in config:
20
- self.max_tokens = config["max_tokens"]
21
-
22
18
  if "api_type" in config:
23
19
  raise Exception(
24
20
  "Passing api_type is now deprecated. Please pass an OpenAI client instead."
@@ -75,7 +71,6 @@ class OpenAI_Chat(VannaBase):
75
71
  response = self.client.chat.completions.create(
76
72
  model=model,
77
73
  messages=prompt,
78
- max_tokens=self.max_tokens,
79
74
  stop=None,
80
75
  temperature=self.temperature,
81
76
  )
@@ -87,7 +82,6 @@ class OpenAI_Chat(VannaBase):
87
82
  response = self.client.chat.completions.create(
88
83
  engine=engine,
89
84
  messages=prompt,
90
- max_tokens=self.max_tokens,
91
85
  stop=None,
92
86
  temperature=self.temperature,
93
87
  )
@@ -98,7 +92,6 @@ class OpenAI_Chat(VannaBase):
98
92
  response = self.client.chat.completions.create(
99
93
  engine=self.config["engine"],
100
94
  messages=prompt,
101
- max_tokens=self.max_tokens,
102
95
  stop=None,
103
96
  temperature=self.temperature,
104
97
  )
@@ -109,7 +102,6 @@ class OpenAI_Chat(VannaBase):
109
102
  response = self.client.chat.completions.create(
110
103
  model=self.config["model"],
111
104
  messages=prompt,
112
- max_tokens=self.max_tokens,
113
105
  stop=None,
114
106
  temperature=self.temperature,
115
107
  )
@@ -123,7 +115,6 @@ class OpenAI_Chat(VannaBase):
123
115
  response = self.client.chat.completions.create(
124
116
  model=model,
125
117
  messages=prompt,
126
- max_tokens=self.max_tokens,
127
118
  stop=None,
128
119
  temperature=self.temperature,
129
120
  )
@@ -0,0 +1 @@
1
+ from .weaviate_vector import WeaviateDatabase
@@ -0,0 +1,174 @@
1
+ import weaviate
2
+ import weaviate.classes as wvc
3
+ from fastembed import TextEmbedding
4
+
5
+ from vanna.base import VannaBase
6
+
7
+
8
+ class WeaviateDatabase(VannaBase):
9
+
10
+ def __init__(self, config=None):
11
+ """
12
+ Initialize the VannaEnhanced class with the provided configuration.
13
+
14
+ :param config: Dictionary containing configuration parameters.
15
+
16
+ params:
17
+ weaviate_url (str): Weaviate cluster URL while using weaviate cloud,
18
+ weaviate_api_key (str): Weaviate API key while using weaviate cloud,
19
+ weaviate_port (num): Weaviate port while using local weaviate,
20
+ weaviate_grpc (num): Weaviate gRPC port while using local weaviate,
21
+ fastembed_model (str): Fastembed model name for text embeddings. BAAI/bge-small-en-v1.5 by default.
22
+
23
+ """
24
+ super().__init__(config=config)
25
+
26
+ if config is None:
27
+ raise ValueError("config is required")
28
+
29
+ self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
30
+ self.weaviate_api_key = config.get("weaviate_api_key")
31
+ self.weaviate_url = config.get("weaviate_url")
32
+ self.weaviate_port = config.get("weaviate_port")
33
+ self.weaviate_grpc_port = config.get("weaviate_grpc", 50051)
34
+
35
+ if not self.weaviate_api_key and not self.weaviate_port:
36
+ raise ValueError("Add proper credentials to connect to weaviate")
37
+
38
+ self.weaviate_client = self._initialize_weaviate_client()
39
+ self.embeddings = TextEmbedding(model_name=self.fastembed_model)
40
+
41
+ self.training_data_cluster = {
42
+ "sql": "SQLTrainingDataEntry",
43
+ "ddl": "DDLEntry",
44
+ "doc": "DocumentationEntry"
45
+ }
46
+
47
+ self._create_collections_if_not_exist()
48
+
49
+ def _create_collections_if_not_exist(self):
50
+ properties_dict = {
51
+ self.training_data_cluster['ddl']: [
52
+ wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
53
+ ],
54
+ self.training_data_cluster['doc']: [
55
+ wvc.config.Property(name="description", data_type=wvc.config.DataType.TEXT),
56
+ ],
57
+ self.training_data_cluster['sql']: [
58
+ wvc.config.Property(name="sql", data_type=wvc.config.DataType.TEXT),
59
+ wvc.config.Property(name="natural_language_question", data_type=wvc.config.DataType.TEXT),
60
+ ]
61
+ }
62
+
63
+ for cluster, properties in properties_dict.items():
64
+ if not self.weaviate_client.collections.exists(cluster):
65
+ self.weaviate_client.collections.create(
66
+ name=cluster,
67
+ properties=properties
68
+ )
69
+
70
+ def _initialize_weaviate_client(self):
71
+ if self.weaviate_api_key:
72
+ return weaviate.connect_to_wcs(
73
+ cluster_url=self.weaviate_url,
74
+ auth_credentials=weaviate.auth.AuthApiKey(self.weaviate_api_key),
75
+ additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
76
+ skip_init_checks=True
77
+ )
78
+ else:
79
+ return weaviate.connect_to_local(
80
+ port=self.weaviate_port,
81
+ grpc_port=self.weaviate_grpc_port,
82
+ additional_config=weaviate.config.AdditionalConfig(timeout=(10, 300)),
83
+ skip_init_checks=True
84
+ )
85
+
86
+ def generate_embedding(self, data: str, **kwargs):
87
+ embedding_model = TextEmbedding(model_name=self.fastembed_model)
88
+ embedding = next(embedding_model.embed(data))
89
+ return embedding.tolist()
90
+
91
+
92
+ def _insert_data(self, cluster_key: str, data_object: dict, vector: list) -> str:
93
+ self.weaviate_client.connect()
94
+ response = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]).data.insert(
95
+ properties=data_object,
96
+ vector=vector
97
+ )
98
+ self.weaviate_client.close()
99
+ return response
100
+
101
+ def add_ddl(self, ddl: str, **kwargs) -> str:
102
+ data_object = {
103
+ "description": ddl,
104
+ }
105
+ response = self._insert_data('ddl', data_object, self.generate_embedding(ddl))
106
+ return f'{response}-ddl'
107
+
108
+ def add_documentation(self, doc: str, **kwargs) -> str:
109
+ data_object = {
110
+ "description": doc,
111
+ }
112
+ response = self._insert_data('doc', data_object, self.generate_embedding(doc))
113
+ return f'{response}-doc'
114
+
115
+ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
116
+ data_object = {
117
+ "sql": sql,
118
+ "natural_language_question": question,
119
+ }
120
+ response = self._insert_data('sql', data_object, self.generate_embedding(question))
121
+ return f'{response}-sql'
122
+
123
+ def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list, limit: int = 3) -> list:
124
+ self.weaviate_client.connect()
125
+ collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key])
126
+ response = collection.query.near_vector(
127
+ near_vector=vector_input,
128
+ limit=limit,
129
+ return_properties=return_properties
130
+ )
131
+ response_list = [item.properties for item in response.objects]
132
+ self.weaviate_client.close()
133
+ return response_list
134
+
135
+ def get_related_ddl(self, question: str, **kwargs) -> list:
136
+ vector_input = self.generate_embedding(question)
137
+ response_list = self._query_collection('ddl', vector_input, ["description"])
138
+ return [item["description"] for item in response_list]
139
+
140
+ def get_related_documentation(self, question: str, **kwargs) -> list:
141
+ vector_input = self.generate_embedding(question)
142
+ response_list = self._query_collection('doc', vector_input, ["description"])
143
+ return [item["description"] for item in response_list]
144
+
145
+ def get_similar_question_sql(self, question: str, **kwargs) -> list:
146
+ vector_input = self.generate_embedding(question)
147
+ response_list = self._query_collection('sql', vector_input, ["sql", "natural_language_question"])
148
+ return [{"question": item["natural_language_question"], "sql": item["sql"]} for item in response_list]
149
+
150
+ def get_training_data(self, **kwargs) -> list:
151
+ self.weaviate_client.connect()
152
+ combined_response_list = []
153
+ for collection_name in self.training_data_cluster.values():
154
+ if self.weaviate_client.collections.exists(collection_name):
155
+ collection = self.weaviate_client.collections.get(collection_name)
156
+ response_list = [item.properties for item in collection.iterator()]
157
+ combined_response_list.extend(response_list)
158
+ self.weaviate_client.close()
159
+ return combined_response_list
160
+
161
+ def remove_training_data(self, id: str, **kwargs) -> bool:
162
+ self.weaviate_client.connect()
163
+ success = False
164
+ if id.endswith("-sql"):
165
+ id = id.replace('-sql', '')
166
+ success = self.weaviate_client.collections.get(self.training_data_cluster['sql']).data.delete_by_id(id)
167
+ elif id.endswith("-ddl"):
168
+ id = id.replace('-ddl', '')
169
+ success = self.weaviate_client.collections.get(self.training_data_cluster['ddl']).data.delete_by_id(id)
170
+ elif id.endswith("-doc"):
171
+ id = id.replace('-doc', '')
172
+ success = self.weaviate_client.collections.get(self.training_data_cluster['doc']).data.delete_by_id(id)
173
+ self.weaviate_client.close()
174
+ return success
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vanna
3
- Version: 0.6.2
3
+ Version: 0.6.4
4
4
  Summary: Generate SQL queries from natural language
5
5
  Author-email: Zain Hoda <zain@vanna.ai>
6
6
  Requires-Python: >=3.9
@@ -16,6 +16,7 @@ Requires-Dist: sqlparse
16
16
  Requires-Dist: kaleido
17
17
  Requires-Dist: flask
18
18
  Requires-Dist: flask-sock
19
+ Requires-Dist: flasgger
19
20
  Requires-Dist: sqlalchemy
20
21
  Requires-Dist: psycopg2-binary ; extra == "all"
21
22
  Requires-Dist: db-dtypes ; extra == "all"
@@ -40,7 +41,10 @@ Requires-Dist: opensearch-dsl ; extra == "all"
40
41
  Requires-Dist: transformers ; extra == "all"
41
42
  Requires-Dist: pinecone-client ; extra == "all"
42
43
  Requires-Dist: pymilvus[model] ; extra == "all"
44
+ Requires-Dist: weaviate-client ; extra == "all"
43
45
  Requires-Dist: anthropic ; extra == "anthropic"
46
+ Requires-Dist: boto3 ; extra == "bedrock"
47
+ Requires-Dist: botocore ; extra == "bedrock"
44
48
  Requires-Dist: google-cloud-bigquery ; extra == "bigquery"
45
49
  Requires-Dist: chromadb ; extra == "chromadb"
46
50
  Requires-Dist: clickhouse_connect ; extra == "clickhouse"
@@ -67,11 +71,13 @@ Requires-Dist: fastembed ; extra == "qdrant"
67
71
  Requires-Dist: snowflake-connector-python ; extra == "snowflake"
68
72
  Requires-Dist: tox ; extra == "test"
69
73
  Requires-Dist: vllm ; extra == "vllm"
74
+ Requires-Dist: weaviate-client ; extra == "weaviate"
70
75
  Requires-Dist: zhipuai ; extra == "zhipuai"
71
76
  Project-URL: Bug Tracker, https://github.com/vanna-ai/vanna/issues
72
77
  Project-URL: Homepage, https://github.com/vanna-ai/vanna
73
78
  Provides-Extra: all
74
79
  Provides-Extra: anthropic
80
+ Provides-Extra: bedrock
75
81
  Provides-Extra: bigquery
76
82
  Provides-Extra: chromadb
77
83
  Provides-Extra: clickhouse
@@ -92,6 +98,7 @@ Provides-Extra: qdrant
92
98
  Provides-Extra: snowflake
93
99
  Provides-Extra: test
94
100
  Provides-Extra: vllm
101
+ Provides-Extra: weaviate
95
102
  Provides-Extra: zhipuai
96
103
 
97
104
 
@@ -7,19 +7,21 @@ vanna/ZhipuAI/ZhipuAI_embeddings.py,sha256=lUqzJg9fOx7rVFhjdkFjXcDeVGV4aAB5Ss0oE
7
7
  vanna/ZhipuAI/__init__.py,sha256=NlsijtcZp5Tj9jkOe9fNcOQND_QsGgu7otODsCLBPr0,116
8
8
  vanna/advanced/__init__.py,sha256=oDj9g1JbrbCfp4WWdlr_bhgdMqNleyHgr6VXX6DcEbo,658
9
9
  vanna/anthropic/__init__.py,sha256=85s_2mAyyPxc0T_0JEvYeAkEKWJwkwqoyUwSC5dw9Gk,43
10
- vanna/anthropic/anthropic_chat.py,sha256=Wk0o-NMW1uvR2fhSWxrR_2FqNh-dLprNG4uuVqpqAkY,2615
10
+ vanna/anthropic/anthropic_chat.py,sha256=7X3x8SYwDY28aGyBnt0YNRMG8YY1p_t-foMfKGj8_Oo,2627
11
11
  vanna/base/__init__.py,sha256=Sl-HM1RRYzAZoSqmL1CZQmF3ZF-byYTCFQP3JZ2A5MU,28
12
- vanna/base/base.py,sha256=l1H0TKsK9DN3n5XgDkUckdLois4dTCAUwrVsRa_6SlQ,70988
12
+ vanna/base/base.py,sha256=3Du70NrXQMn_LOif2YFPRRVKo4wH5-f6eZcLlXEX0X8,71705
13
+ vanna/bedrock/__init__.py,sha256=hRT2bgJbHEqViLdL-t9hfjSfFdIOkPU2ADBt-B1En-8,46
14
+ vanna/bedrock/bedrock_converse.py,sha256=Nx5kYm-diAfYmsWAnTP5xnv7V84Og69-AP9b3seIe0E,2869
13
15
  vanna/chromadb/__init__.py,sha256=-iL0nW_g4uM8nWKMuWnNePfN4nb9uk8P3WzGvezOqRg,50
14
16
  vanna/chromadb/chromadb_vector.py,sha256=eKyPck99Y6Jt-BNWojvxLG-zvAERzLSm-3zY-bKXvaA,8792
15
17
  vanna/exceptions/__init__.py,sha256=dJ65xxxZh1lqBeg6nz6Tq_r34jLVmjvBvPO9Q6hFaQ8,685
16
- vanna/flask/__init__.py,sha256=urPrHUqM1mpx96VHiQWVXCy3NQwDh6OsSkm4V4wqccY,30211
18
+ vanna/flask/__init__.py,sha256=r1ucQupb6wuTcjVVKpkdrg6R38FZe6KQoKw9AtcghDQ,42889
17
19
  vanna/flask/assets.py,sha256=_UoUr57sS0QL2BuTxAOe9k4yy8T7-fp2NpbRSVtW3IM,451769
18
20
  vanna/flask/auth.py,sha256=UpKxh7W5cd43W0LGch0VqhncKwB78L6dtOQkl1JY5T0,1246
19
21
  vanna/google/__init__.py,sha256=M-dCxCZcKL4bTQyMLj6r6VRs65YNX9Tl2aoPCuqGm-8,41
20
- vanna/google/gemini_chat.py,sha256=ps3A-afFbCo3HeFTLL_nMoQO1PsGvRUUPRUppbMcDew,1584
22
+ vanna/google/gemini_chat.py,sha256=j1szC2PamMLFrs0Z4lYPS69i017FYICe-mNObNYFBPQ,1576
21
23
  vanna/hf/__init__.py,sha256=vD0bIhfLkA1UsvVSF4MAz3Da8aQunkQo3wlDztmMuj0,19
22
- vanna/hf/hf.py,sha256=v1v6sZnbj5xcrjgmvLP_ytS9NM7E5d0GyMfXXtr6BMU,2703
24
+ vanna/hf/hf.py,sha256=N8N5g3xvKDBt3dez2r_U0qATxbl2pN8SVLTZK9CSRA0,3020
23
25
  vanna/marqo/__init__.py,sha256=GaAWtJ0B-H5rTY607iLCCrLD7T0zMYM5qWIomEB9gLk,37
24
26
  vanna/marqo/marqo.py,sha256=W7WTtzWp4RJjZVy6OaXHqncUBIPdI4Q7qH7BRCxZ1_A,5242
25
27
  vanna/milvus/__init__.py,sha256=VBasJG2eTKbJI6CEand7kPLNBrqYrn0QCAhSYVz814s,46
@@ -33,7 +35,7 @@ vanna/mock/vectordb.py,sha256=h45znfYMUnttE2BBC8v6TKeMaA58pFJL-5B3OGeRNFI,2681
33
35
  vanna/ollama/__init__.py,sha256=4xyu8aHPdnEHg5a-QAMwr5o0ns5wevsp_zkI-ndMO2k,27
34
36
  vanna/ollama/ollama.py,sha256=rXa7cfvdlO1E5SLysXIl3IZpIaA2r0RBvV5jX2-upiE,3794
35
37
  vanna/openai/__init__.py,sha256=tGkeQ7wTIPsando7QhoSHehtoQVdYLwFbKNlSmCmNeQ,86
36
- vanna/openai/openai_chat.py,sha256=lm-hUsQxu6Q1t06A2csC037zI4VkMk0wFbQ-_Lj74Wg,4764
38
+ vanna/openai/openai_chat.py,sha256=KU6ynOQ5v7vwrQQ13phXoUXeQUrH6_vmhfiPvWddTrQ,4427
37
39
  vanna/openai/openai_embeddings.py,sha256=g4pNh9LVcYP9wOoO8ecaccDFWmCUYMInebfHucAa2Gc,1260
38
40
  vanna/opensearch/__init__.py,sha256=0unDevWOTs7o8S79TOHUKF1mSiuQbBUVm-7k9jV5WW4,54
39
41
  vanna/opensearch/opensearch_vector.py,sha256=VhIcrSyNzWR9ZrqrJnyGFOyuQZs3swfbhr8QyVGI0eI,12226
@@ -46,6 +48,8 @@ vanna/vannadb/__init__.py,sha256=C6UkYocmO6dmzfPKZaWojN0mI5YlZZ9VIbdcquBE58A,48
46
48
  vanna/vannadb/vannadb_vector.py,sha256=N8poMYvAojoaOF5gI4STD5pZWK9lBKPvyIjbh9dPBa0,14189
47
49
  vanna/vllm/__init__.py,sha256=aNlUkF9tbURdeXAJ8ytuaaF1gYwcG3ny1MfNl_cwQYg,23
48
50
  vanna/vllm/vllm.py,sha256=oM_aA-1Chyl7T_Qc_yRKlL6oSX1etsijY9zQdjeMGMQ,2827
49
- vanna-0.6.2.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
50
- vanna-0.6.2.dist-info/METADATA,sha256=RVle66HeuhBS8iaO0vD8_iDQqk9NbeO1pZOgCgKwh54,11628
51
- vanna-0.6.2.dist-info/RECORD,,
51
+ vanna/weaviate/__init__.py,sha256=HL6PAl7ePBAkeG8uln-BmM7IUtWohyTPvDfcPzSGSCg,46
52
+ vanna/weaviate/weaviate_vector.py,sha256=GEiu4Vd9w-7j10aB-zTxJ8gefqe_F-LUUGvttFs1vlg,7539
53
+ vanna-0.6.4.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
54
+ vanna-0.6.4.dist-info/METADATA,sha256=LqIi4Hg1y_aTEH79PX48nnY1TM-u6ese9K8Os9Cqkg0,11889
55
+ vanna-0.6.4.dist-info/RECORD,,
File without changes