vanna 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vanna/hf/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .hf import Hf
vanna/hf/hf.py ADDED
@@ -0,0 +1,79 @@
1
+ import re
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
+ from ..base import VannaBase
5
+
6
+
7
class Hf(VannaBase):
    """LLM backend for Vanna using a local Hugging Face transformers model."""

    def __init__(self, config=None):
        """Load tokenizer and causal-LM weights named by ``config["model_name"]``.

        Args:
            config (dict | None): configuration; must contain "model_name",
                e.g. "meta-llama/Meta-Llama-3-8B-Instruct".

        Bug fix: the original read ``self.config.get(...)`` without ever
        assigning ``self.config`` (and ignored the ``config`` argument), so
        construction always raised AttributeError.
        """
        self.config = config or {}
        model_name = self.config.get(
            "model_name", None
        )  # e.g. meta-llama/Meta-Llama-3-8B-Instruct
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )

    def system_message(self, message: str) -> any:
        """Wrap *message* as a chat-template system message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap *message* as a chat-template user message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap *message* as a chat-template assistant message."""
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, with three backticks removed, or
          the original text if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace("```", "")
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL for *question*, then strip model-added escapes."""
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Models sometimes markdown-escape underscores; replace "\_" with "_"
        sql = sql.replace("\\_", "_")

        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Run the chat *prompt* through the local model and return the reply text."""
        input_ids = self.tokenizer.apply_chat_template(
            prompt, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        outputs = self.model.generate(
            input_ids,
            max_new_tokens=512,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            temperature=1,
            top_p=0.9,
        )
        # Slice off the prompt tokens; decode only the newly generated tail.
        response = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(response, skip_special_tokens=True)
        self.log(response)

        return response
vanna/ollama/ollama.py CHANGED
@@ -1,76 +1,101 @@
1
+ import json
1
2
  import re
2
3
 
3
- import requests
4
+ from httpx import Timeout
4
5
 
5
6
  from ..base import VannaBase
7
+ from ..exceptions import DependencyError
6
8
 
7
9
 
8
10
class Ollama(VannaBase):
    """LLM backend for Vanna that talks to a local/remote Ollama server."""

    def __init__(self, config=None):
        """Connect to Ollama and make sure the configured model is available.

        Args:
            config (dict): must contain "model"; may contain "ollama_host"
                (default "http://localhost:11434"), "keep_alive", and
                "options" (passed through to ollama, e.g. {"num_ctx": ...}).

        Raises:
            DependencyError: if the ``ollama`` package is not installed.
            ValueError: if *config* is missing or has no "model".
        """
        try:
            ollama = __import__("ollama")
        except ImportError:
            raise DependencyError(
                "You need to install required dependencies to execute this method, run command:"
                " \npip install ollama"
            )

        if not config:
            raise ValueError("config must contain at least Ollama model")
        if "model" not in config:
            raise ValueError("config must contain at least Ollama model")
        self.host = config.get("ollama_host", "http://localhost:11434")
        self.model = config["model"]
        # Bug fix: the original condition was inverted (`":" in self.model`),
        # which appended ":latest" only when a tag was ALREADY present and
        # never when it was missing. Add the default tag only when absent.
        if ":" not in self.model:
            self.model += ":latest"

        # Long timeout: local generation can take minutes.
        self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0))
        self.keep_alive = config.get('keep_alive', None)
        self.ollama_options = config.get('options', {})
        self.num_ctx = self.ollama_options.get('num_ctx', 2048)
        self.__pull_model_if_ne(self.ollama_client, self.model)

    @staticmethod
    def __pull_model_if_ne(ollama_client, model):
        # Pull the model from the registry if the server does not have it yet.
        model_response = ollama_client.list()
        model_lists = [model_element['model'] for model_element in
                       model_response.get('models', [])]
        if model not in model_lists:
            ollama_client.pull(model)

    def system_message(self, message: str) -> any:
        """Wrap *message* as a chat system message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """Wrap *message* as a chat user message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """Wrap *message* as a chat assistant message."""
        return {"role": "assistant", "content": message}

    def extract_sql(self, llm_response):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - llm_response (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, or the raw response if none matched.
        """
        # Remove ollama-generated extra characters
        llm_response = llm_response.replace("\\_", "_")
        llm_response = llm_response.replace("\\", "")

        # Regular expression to find ```sql' and capture until '```'
        sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
        # Regular expression to find 'select, with (ignoring case) and capture until ';', [ (this happens in case of mistral) or end of string
        select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)',
                                llm_response,
                                re.IGNORECASE | re.DOTALL)
        if sql:
            self.log(
                f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
            return sql.group(1).replace("```", "")
        elif select_with:
            self.log(
                f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}")
            return select_with.group(0)
        else:
            return llm_response

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send the chat *prompt* to Ollama and return the reply content."""
        self.log(
            f"Ollama parameters:\n"
            f"model={self.model},\n"
            f"options={self.ollama_options},\n"
            f"keep_alive={self.keep_alive}")
        self.log(f"Prompt Content:\n{json.dumps(prompt)}")
        response_dict = self.ollama_client.chat(model=self.model,
                                                messages=prompt,
                                                stream=False,
                                                options=self.ollama_options,
                                                keep_alive=self.keep_alive)

        self.log(f"Ollama Response:\n{str(response_dict)}")

        return response_dict['message']['content']
@@ -0,0 +1 @@
1
+ from .opensearch_vector import OpenSearch_VectorStore
@@ -0,0 +1,289 @@
1
+ import base64
2
+ import uuid
3
+ from typing import List
4
+
5
+ import pandas as pd
6
+ from opensearchpy import OpenSearch
7
+
8
+ from ..base import VannaBase
9
+
10
+
11
class OpenSearch_VectorStore(VannaBase):
    """Vanna training-data store backed by OpenSearch full-text indices.

    Keeps three indices — DDL statements, documentation snippets, and
    question/SQL pairs — and retrieves "related" entries with plain
    ``match`` queries (no embeddings are computed or stored).
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        # Index names, overridable via config.
        document_index = "vanna_document_index"
        ddl_index = "vanna_ddl_index"
        question_sql_index = "vanna_questions_sql_index"
        if config is not None and "es_document_index" in config:
            document_index = config["es_document_index"]
        if config is not None and "es_ddl_index" in config:
            ddl_index = config["es_ddl_index"]
        if config is not None and "es_question_sql_index" in config:
            question_sql_index = config["es_question_sql_index"]

        self.document_index = document_index
        self.ddl_index = ddl_index
        self.question_sql_index = question_sql_index
        print("OpenSearch_VectorStore initialized with document_index: ", document_index, " ddl_index: ", ddl_index, " question_sql_index: ", question_sql_index)

        # Full connection URLs take precedence over host/port below.
        es_urls = None
        if config is not None and "es_urls" in config:
            es_urls = config["es_urls"]

        # Host and port
        if config is not None and "es_host" in config:
            host = config["es_host"]
        else:
            host = "localhost"

        if config is not None and "es_port" in config:
            port = config["es_port"]
        else:
            port = 9200

        if config is not None and "es_ssl" in config:
            ssl = config["es_ssl"]
        else:
            ssl = False

        if config is not None and "es_verify_certs" in config:
            verify_certs = config["es_verify_certs"]
        else:
            verify_certs = False

        # Authentication: basic auth tuple when user/password supplied.
        if config is not None and "es_user" in config:
            auth = (config["es_user"], config["es_password"])
        else:
            auth = None

        headers = None
        # Base64 authentication: send credentials as an Authorization header
        # instead of http_auth when es_encoded_base64 is truthy.
        if config is not None and "es_encoded_base64" in config and "es_user" in config and "es_password" in config:
            if config["es_encoded_base64"]:
                encoded_credentials = base64.b64encode(
                    (config["es_user"] + ":" + config["es_password"]).encode("utf-8")
                ).decode("utf-8")
                headers = {
                    'Authorization': 'Basic ' + encoded_credentials
                }
                # Header carries the credentials; drop the tuple auth.
                auth = None

        # Custom headers override any generated ones.
        if config is not None and "es_headers" in config:
            headers = config["es_headers"]

        if config is not None and "es_timeout" in config:
            timeout = config["es_timeout"]
        else:
            timeout = 60

        if config is not None and "es_max_retries" in config:
            max_retries = config["es_max_retries"]
        else:
            max_retries = 10

        if es_urls is not None:
            # Initialize the OpenSearch client by passing a list of URLs
            self.client = OpenSearch(
                hosts=[es_urls],
                http_compress=True,
                use_ssl=ssl,
                verify_certs=verify_certs,
                timeout=timeout,
                max_retries=max_retries,
                retry_on_timeout=True,
                http_auth=auth,
                headers=headers
            )
        else:
            # Initialize the OpenSearch client by passing a host and port
            self.client = OpenSearch(
                hosts=[{'host': host, 'port': port}],
                http_compress=True,
                use_ssl=ssl,
                verify_certs=verify_certs,
                timeout=timeout,
                max_retries=max_retries,
                retry_on_timeout=True,
                http_auth=auth,
                headers=headers
            )

        # Run a simple query to check the connection.
        try:
            info = self.client.info()
            print('OpenSearch cluster info:', info)
        except Exception as e:
            print('Error connecting to OpenSearch cluster:', e)

        # Create the indices if they don't exist
        # self.create_index()

    def create_index(self):
        """Create the three backing indices, tolerating pre-existing ones."""
        for index in [self.document_index, self.ddl_index, self.question_sql_index]:
            try:
                self.client.indices.create(index)
            except Exception as e:
                # Most commonly resource_already_exists_exception; log and continue.
                print("Error creating index: ", e)
                print(f"opensearch index {index} already exists")

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement; returns the generated document id ("...-ddl")."""
        id = str(uuid.uuid4()) + "-ddl"
        ddl_dict = {
            "ddl": ddl
        }
        response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id,
                                     **kwargs)
        return response['_id']

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet; returns the generated id ("...-doc")."""
        id = str(uuid.uuid4()) + "-doc"
        doc_dict = {
            "doc": doc
        }
        response = self.client.index(index=self.document_index, id=id,
                                     body=doc_dict, **kwargs)
        return response['_id']

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair; returns the generated id ("...-sql")."""
        id = str(uuid.uuid4()) + "-sql"
        question_sql_dict = {
            "question": question,
            "sql": sql
        }
        response = self.client.index(index=self.question_sql_index,
                                     body=question_sql_dict, id=id,
                                     **kwargs)
        return response['_id']

    def get_related_ddl(self, question: str, **kwargs) -> List[str]:
        """Full-text match *question* against stored DDL statements."""
        query = {
            "query": {
                "match": {
                    "ddl": question
                }
            }
        }
        response = self.client.search(index=self.ddl_index, body=query,
                                      **kwargs)
        return [hit['_source']['ddl'] for hit in response['hits']['hits']]

    def get_related_documentation(self, question: str, **kwargs) -> List[str]:
        """Full-text match *question* against stored documentation."""
        query = {
            "query": {
                "match": {
                    "doc": question
                }
            }
        }
        response = self.client.search(index=self.document_index,
                                      body=query,
                                      **kwargs)
        return [hit['_source']['doc'] for hit in response['hits']['hits']]

    def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
        """Full-text match against stored questions.

        NOTE(review): despite the List[str] annotation this returns
        (question, sql) tuples — kept for caller compatibility.
        """
        query = {
            "query": {
                "match": {
                    "question": question
                }
            }
        }
        response = self.client.search(index=self.question_sql_index,
                                      body=query,
                                      **kwargs)
        return [(hit['_source']['question'], hit['_source']['sql']) for hit in
                response['hits']['hits']]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return up to 1000 rows per index as one DataFrame.

        WARNING: pulls whole indices with match_all — do not use this
        approach in production for large indices!
        """
        data = []
        response = self.client.search(
            index=self.document_index,
            body={"query": {"match_all": {}}},
            size=1000
        )
        for hit in response['hits']['hits']:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "documentation",
                    "question": "",
                    "content": hit["_source"]['doc'],
                }
            )

        response = self.client.search(
            index=self.question_sql_index,
            body={"query": {"match_all": {}}},
            size=1000
        )
        for hit in response['hits']['hits']:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "sql",
                    "question": hit.get("_source", {}).get("question", ""),
                    "content": hit.get("_source", {}).get("sql", ""),
                }
            )

        response = self.client.search(
            index=self.ddl_index,
            body={"query": {"match_all": {}}},
            size=1000
        )
        for hit in response['hits']['hits']:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "ddl",
                    "question": "",
                    "content": hit["_source"]['ddl'],
                }
            )

        return pd.DataFrame(data)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete one training record, routed by the id suffix.

        Returns True on success, False for unknown suffixes or errors.
        """
        try:
            if id.endswith("-sql"):
                # Consistency fix: pass **kwargs like the other branches do.
                self.client.delete(index=self.question_sql_index, id=id, **kwargs)
                return True
            elif id.endswith("-ddl"):
                self.client.delete(index=self.ddl_index, id=id, **kwargs)
                return True
            elif id.endswith("-doc"):
                self.client.delete(index=self.document_index, id=id, **kwargs)
                return True
            else:
                return False
        except Exception as e:
            # Bug fix: the original message text was accidentally duplicated.
            print("Error deleting training data: ", e)
            return False

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        # OpenSearch full-text search doesn't need embeddings; intentional no-op.
        pass


# OpenSearch_VectorStore.__init__(self, config={'es_urls':
# "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
# "admin", 'es_password': "admin", 'es_verify_certs': True})


# OpenSearch_VectorStore.__init__(self, config={'es_host':
# "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin",
# 'es_password': "admin", 'es_verify_certs': True})
vanna/vllm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .vllm import Vllm
vanna/vllm/vllm.py ADDED
@@ -0,0 +1,76 @@
1
+ import re
2
+
3
+ import requests
4
+
5
+ from ..base import VannaBase
6
+
7
+
8
class Vllm(VannaBase):
    """LLM backend for Vanna that calls a vLLM OpenAI-compatible server."""

    def __init__(self, config=None):
        """Read host and model from *config*.

        Args:
            config (dict | None): must contain "model"; may contain
                "vllm_host" (default "http://localhost:8000").

        Raises:
            ValueError: when no model is configured.
        """
        cfg = config or {}
        self.host = cfg.get("vllm_host", "http://localhost:8000")
        if "model" not in cfg:
            raise ValueError("check the config for vllm")
        self.model = cfg["model"]

    def system_message(self, message: str) -> any:
        """Build a system chat message."""
        return dict(role="system", content=message)

    def user_message(self, message: str) -> any:
        """Build a user chat message."""
        return dict(role="user", content=message)

    def assistant_message(self, message: str) -> any:
        """Build an assistant chat message."""
        return dict(role="assistant", content=message)

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found (backticks stripped), or the
          unmodified input when nothing matches.
        """
        # 'select' (any case) up to ';', '```', or end of string.
        found = re.search(r"select.*?(?:;|```|$)", text, re.IGNORECASE | re.DOTALL)
        return found.group(0).replace("```", "") if found else text

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL via the base implementation, then clean the output."""
        raw = super().generate_sql(question, **kwargs)
        # Undo markdown-style escaping the model may have added.
        cleaned = raw.replace("\\_", "_").replace("\\", "")
        return self.extract_sql_query(cleaned)

    def submit_prompt(self, prompt, **kwargs) -> str:
        """POST the chat *prompt* to vLLM and return the first choice's text."""
        payload = {
            "model": self.model,
            "stream": False,
            "messages": prompt,
        }
        resp = requests.post(f"{self.host}/v1/chat/completions", json=payload)
        parsed = resp.json()
        self.log(resp.text)
        return parsed['choices'][0]['message']['content']