vanna 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
- vanna/base/base.py +150 -41
- vanna/flask/__init__.py +50 -9
- vanna/flask/assets.py +16 -16
- vanna/hf/__init__.py +1 -0
- vanna/hf/hf.py +79 -0
- vanna/ollama/ollama.py +94 -69
- vanna/opensearch/__init__.py +1 -0
- vanna/opensearch/opensearch_vector.py +289 -0
- vanna/vllm/__init__.py +1 -0
- vanna/vllm/vllm.py +76 -0
- {vanna-0.4.2.dist-info → vanna-0.5.0.dist-info}/METADATA +17 -1
- {vanna-0.4.2.dist-info → vanna-0.5.0.dist-info}/RECORD +13 -7
- {vanna-0.4.2.dist-info → vanna-0.5.0.dist-info}/WHEEL +0 -0
vanna/hf/__init__.py
ADDED
@@ -0,0 +1 @@
+from .hf import Hf
vanna/hf/hf.py
ADDED
@@ -0,0 +1,79 @@
+import re
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from ..base import VannaBase
+
+
+class Hf(VannaBase):
+    def __init__(self, config=None):
+        model_name = self.config.get(
+            "model_name", None
+        )  # e.g. meta-llama/Meta-Llama-3-8B-Instruct
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype="auto",
+            device_map="auto",
+        )
+
+    def system_message(self, message: str) -> any:
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        return {"role": "assistant", "content": message}
+
+    def extract_sql_query(self, text):
+        """
+        Extracts the first SQL statement after the word 'select', ignoring case,
+        matches until the first semicolon, three backticks, or the end of the string,
+        and removes three backticks if they exist in the extracted string.
+
+        Args:
+        - text (str): The string to search within for an SQL statement.
+
+        Returns:
+        - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
+        """
+        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
+        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
+
+        match = pattern.search(text)
+        if match:
+            # Remove three backticks from the matched string if they exist
+            return match.group(0).replace("```", "")
+        else:
+            return text
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        # Use the super generate_sql
+        sql = super().generate_sql(question, **kwargs)
+
+        # Replace "\_" with "_"
+        sql = sql.replace("\\_", "_")
+
+        sql = sql.replace("\\", "")
+
+        return self.extract_sql_query(sql)
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+
+        input_ids = self.tokenizer.apply_chat_template(
+            prompt, add_generation_prompt=True, return_tensors="pt"
+        ).to(self.model.device)
+
+        outputs = self.model.generate(
+            input_ids,
+            max_new_tokens=512,
+            eos_token_id=self.tokenizer.eos_token_id,
+            do_sample=True,
+            temperature=1,
+            top_p=0.9,
+        )
+        response = outputs[0][input_ids.shape[-1] :]
+        response = self.tokenizer.decode(response, skip_special_tokens=True)
+        self.log(response)
+
+        return response
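The new Hf class implements only the LLM half of Vanna: its `__init__` reads `self.config`, which is set by `VannaBase.__init__`, a call `Hf.__init__` itself never makes. In practice it is meant to be combined with a vector store mixin that does. Below is a minimal usage sketch, not part of the diff; `ChromaDB_VectorStore` (which this package ships) and the model name are illustrative assumptions.

# Hypothetical wiring sketch -- not from this diff. Assumes the chromadb extra
# is installed and that ChromaDB_VectorStore.__init__ sets self.config before
# Hf.__init__ reads "model_name" from it.
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore  # assumed available
from vanna.hf import Hf

class MyVanna(ChromaDB_VectorStore, Hf):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Hf.__init__(self, config=config)

vn = MyVanna(config={"model_name": "meta-llama/Meta-Llama-3-8B-Instruct"})
vn.train(ddl="CREATE TABLE customers (id INT, name TEXT, city TEXT)")
print(vn.generate_sql("How many customers are in each city?"))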
vanna/ollama/ollama.py
CHANGED
@@ -1,76 +1,101 @@
+import json
 import re
 
-import
+from httpx import Timeout
 
 from ..base import VannaBase
+from ..exceptions import DependencyError
 
 
 class Ollama(VannaBase):
-(old lines 9-76: previous implementation removed; deleted content not captured in this view)
+    def __init__(self, config=None):
+
+        try:
+            ollama = __import__("ollama")
+        except ImportError:
+            raise DependencyError(
+                "You need to install required dependencies to execute this method, run command:"
+                " \npip install ollama"
+            )
+
+        if not config:
+            raise ValueError("config must contain at least Ollama model")
+        if 'model' not in config.keys():
+            raise ValueError("config must contain at least Ollama model")
+        self.host = config.get("ollama_host", "http://localhost:11434")
+        self.model = config["model"]
+        if ":" not in self.model:
+            self.model += ":latest"
+
+        self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0))
+        self.keep_alive = config.get('keep_alive', None)
+        self.ollama_options = config.get('options', {})
+        self.num_ctx = self.ollama_options.get('num_ctx', 2048)
+        self.__pull_model_if_ne(self.ollama_client, self.model)
+
+    @staticmethod
+    def __pull_model_if_ne(ollama_client, model):
+        model_response = ollama_client.list()
+        model_lists = [model_element['model'] for model_element in
+                       model_response.get('models', [])]
+        if model not in model_lists:
+            ollama_client.pull(model)
+
+    def system_message(self, message: str) -> any:
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        return {"role": "assistant", "content": message}
+
+    def extract_sql(self, llm_response):
+        """
+        Extracts the first SQL statement after the word 'select', ignoring case,
+        matches until the first semicolon, three backticks, or the end of the string,
+        and removes three backticks if they exist in the extracted string.
+
+        Args:
+        - llm_response (str): The string to search within for an SQL statement.
+
+        Returns:
+        - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
+        """
+        # Remove ollama-generated extra characters
+        llm_response = llm_response.replace("\\_", "_")
+        llm_response = llm_response.replace("\\", "")
+
+        # Regular expression to find a ```sql fenced block and capture until ';', '[' or the closing fence
+        sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
+        # Regular expression to find 'select' or 'with ... as (' (ignoring case) and capture until ';', '[' (this happens in case of mistral) or '```'
+        select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)',
+                                llm_response,
+                                re.IGNORECASE | re.DOTALL)
+        if sql:
+            self.log(
+                f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
+            return sql.group(1).replace("```", "")
+        elif select_with:
+            self.log(
+                f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}")
+            return select_with.group(0)
+        else:
+            return llm_response
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        self.log(
+            f"Ollama parameters:\n"
+            f"model={self.model},\n"
+            f"options={self.ollama_options},\n"
+            f"keep_alive={self.keep_alive}")
+        self.log(f"Prompt Content:\n{json.dumps(prompt)}")
+        response_dict = self.ollama_client.chat(model=self.model,
+                                                messages=prompt,
+                                                stream=False,
+                                                options=self.ollama_options,
+                                                keep_alive=self.keep_alive)
+
+        self.log(f"Ollama Response:\n{str(response_dict)}")
+
+        return response_dict['message']['content']
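The rewritten Ollama class now goes through the `ollama` Python client (with a 240-second httpx timeout, and an automatic `pull` of the model if the server does not already have it) rather than raw HTTP. A usage sketch under the same mixin pattern as the Hf example above; the vector store, model name, and option values are illustrative assumptions.

# Hypothetical wiring sketch -- not from this diff. Requires `pip install ollama`
# and a running Ollama server; the model is pulled automatically if absent.
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore  # assumed available
from vanna.ollama import Ollama

class MyVanna(ChromaDB_VectorStore, Ollama):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Ollama.__init__(self, config=config)

vn = MyVanna(config={
    "model": "llama3",                        # required; ":latest" is appended when no tag is given
    "ollama_host": "http://localhost:11434",  # default shown
    "options": {"num_ctx": 4096},             # forwarded verbatim to ollama_client.chat
    "keep_alive": "5m",                       # optional; defaults to None
})
print(vn.generate_sql("Which products had the highest revenue last quarter?"))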
vanna/opensearch/__init__.py
ADDED
@@ -0,0 +1 @@
+from .opensearch_vector import OpenSearch_VectorStore
vanna/opensearch/opensearch_vector.py
ADDED
@@ -0,0 +1,289 @@
+import base64
+import uuid
+from typing import List
+
+import pandas as pd
+from opensearchpy import OpenSearch
+
+from ..base import VannaBase
+
+
+class OpenSearch_VectorStore(VannaBase):
+    def __init__(self, config=None):
+        VannaBase.__init__(self, config=config)
+        document_index = "vanna_document_index"
+        ddl_index = "vanna_ddl_index"
+        question_sql_index = "vanna_questions_sql_index"
+        if config is not None and "es_document_index" in config:
+            document_index = config["es_document_index"]
+        if config is not None and "es_ddl_index" in config:
+            ddl_index = config["es_ddl_index"]
+        if config is not None and "es_question_sql_index" in config:
+            question_sql_index = config["es_question_sql_index"]
+
+        self.document_index = document_index
+        self.ddl_index = ddl_index
+        self.question_sql_index = question_sql_index
+        print("OpenSearch_VectorStore initialized with document_index: ", document_index, " ddl_index: ", ddl_index, " question_sql_index: ", question_sql_index)
+
+        es_urls = None
+        if config is not None and "es_urls" in config:
+            es_urls = config["es_urls"]
+
+        # Host and port
+        if config is not None and "es_host" in config:
+            host = config["es_host"]
+        else:
+            host = "localhost"
+
+        if config is not None and "es_port" in config:
+            port = config["es_port"]
+        else:
+            port = 9200
+
+        if config is not None and "es_ssl" in config:
+            ssl = config["es_ssl"]
+        else:
+            ssl = False
+
+        if config is not None and "es_verify_certs" in config:
+            verify_certs = config["es_verify_certs"]
+        else:
+            verify_certs = False
+
+        # Authentication
+        if config is not None and "es_user" in config:
+            auth = (config["es_user"], config["es_password"])
+        else:
+            # Default to admin:admin
+            auth = None
+
+        headers = None
+        # base64 authentication
+        if config is not None and "es_encoded_base64" in config and "es_user" in config and "es_password" in config:
+            if config["es_encoded_base64"]:
+                encoded_credentials = base64.b64encode(
+                    (config["es_user"] + ":" + config["es_password"]).encode("utf-8")
+                ).decode("utf-8")
+                headers = {
+                    'Authorization': 'Basic ' + encoded_credentials
+                }
+                # remove auth from config
+                auth = None
+
+        # custom headers
+        if config is not None and "es_headers" in config:
+            headers = config["es_headers"]
+
+        if config is not None and "es_timeout" in config:
+            timeout = config["es_timeout"]
+        else:
+            timeout = 60
+
+        if config is not None and "es_max_retries" in config:
+            max_retries = config["es_max_retries"]
+        else:
+            max_retries = 10
+
+        if es_urls is not None:
+            # Initialize the OpenSearch client by passing a list of URLs
+            self.client = OpenSearch(
+                hosts=[es_urls],
+                http_compress=True,
+                use_ssl=ssl,
+                verify_certs=verify_certs,
+                timeout=timeout,
+                max_retries=max_retries,
+                retry_on_timeout=True,
+                http_auth=auth,
+                headers=headers
+            )
+        else:
+            # Initialize the OpenSearch client by passing a host and port
+            self.client = OpenSearch(
+                hosts=[{'host': host, 'port': port}],
+                http_compress=True,
+                use_ssl=ssl,
+                verify_certs=verify_certs,
+                timeout=timeout,
+                max_retries=max_retries,
+                retry_on_timeout=True,
+                http_auth=auth,
+                headers=headers
+            )
+
+        # Run a simple query to check the connection
+        try:
+            info = self.client.info()
+            print('OpenSearch cluster info:', info)
+        except Exception as e:
+            print('Error connecting to OpenSearch cluster:', e)
+
+        # Create the indices if they don't exist
+        # self.create_index()
+
+    def create_index(self):
+        for index in [self.document_index, self.ddl_index, self.question_sql_index]:
+            try:
+                self.client.indices.create(index)
+            except Exception as e:
+                print("Error creating index: ", e)
+                print(f"opensearch index {index} already exists")
+                pass
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        # Assuming that you have a DDL index in your OpenSearch
+        id = str(uuid.uuid4()) + "-ddl"
+        ddl_dict = {
+            "ddl": ddl
+        }
+        response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id,
+                                     **kwargs)
+        return response['_id']
+
+    def add_documentation(self, doc: str, **kwargs) -> str:
+        # Assuming you have a documentation index in your OpenSearch
+        id = str(uuid.uuid4()) + "-doc"
+        doc_dict = {
+            "doc": doc
+        }
+        response = self.client.index(index=self.document_index, id=id,
+                                     body=doc_dict, **kwargs)
+        return response['_id']
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        # Assuming you have a Questions and SQL index in your OpenSearch
+        id = str(uuid.uuid4()) + "-sql"
+        question_sql_dict = {
+            "question": question,
+            "sql": sql
+        }
+        response = self.client.index(index=self.question_sql_index,
+                                     body=question_sql_dict, id=id,
+                                     **kwargs)
+        return response['_id']
+
+    def get_related_ddl(self, question: str, **kwargs) -> List[str]:
+        # Assume you have some vector search mechanism associated with your data
+        query = {
+            "query": {
+                "match": {
+                    "ddl": question
+                }
+            }
+        }
+        response = self.client.search(index=self.ddl_index, body=query,
+                                      **kwargs)
+        return [hit['_source']['ddl'] for hit in response['hits']['hits']]
+
+    def get_related_documentation(self, question: str, **kwargs) -> List[str]:
+        query = {
+            "query": {
+                "match": {
+                    "doc": question
+                }
+            }
+        }
+        response = self.client.search(index=self.document_index,
+                                      body=query,
+                                      **kwargs)
+        return [hit['_source']['doc'] for hit in response['hits']['hits']]
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
+        query = {
+            "query": {
+                "match": {
+                    "question": question
+                }
+            }
+        }
+        response = self.client.search(index=self.question_sql_index,
+                                      body=query,
+                                      **kwargs)
+        return [(hit['_source']['question'], hit['_source']['sql']) for hit in
+                response['hits']['hits']]
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        # This will be a simple example pulling all data from an index
+        # WARNING: Do not use this approach in production for large indices!
+        data = []
+        response = self.client.search(
+            index=self.document_index,
+            body={"query": {"match_all": {}}},
+            size=1000
+        )
+        # records = [hit['_source'] for hit in response['hits']['hits']]
+        for hit in response['hits']['hits']:
+            data.append(
+                {
+                    "id": hit["_id"],
+                    "training_data_type": "documentation",
+                    "question": "",
+                    "content": hit["_source"]['doc'],
+                }
+            )
+
+        response = self.client.search(
+            index=self.question_sql_index,
+            body={"query": {"match_all": {}}},
+            size=1000
+        )
+        # records = [hit['_source'] for hit in response['hits']['hits']]
+        for hit in response['hits']['hits']:
+            data.append(
+                {
+                    "id": hit["_id"],
+                    "training_data_type": "sql",
+                    "question": hit.get("_source", {}).get("question", ""),
+                    "content": hit.get("_source", {}).get("sql", ""),
+                }
+            )
+
+        response = self.client.search(
+            index=self.ddl_index,
+            body={"query": {"match_all": {}}},
+            size=1000
+        )
+        # records = [hit['_source'] for hit in response['hits']['hits']]
+        for hit in response['hits']['hits']:
+            data.append(
+                {
+                    "id": hit["_id"],
+                    "training_data_type": "ddl",
+                    "question": "",
+                    "content": hit["_source"]['ddl'],
+                }
+            )
+
+        return pd.DataFrame(data)
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        try:
+            if id.endswith("-sql"):
+                self.client.delete(index=self.question_sql_index, id=id)
+                return True
+            elif id.endswith("-ddl"):
+                self.client.delete(index=self.ddl_index, id=id, **kwargs)
+                return True
+            elif id.endswith("-doc"):
+                self.client.delete(index=self.document_index, id=id, **kwargs)
+                return True
+            else:
+                return False
+        except Exception as e:
+            print("Error deleting training data: ", e)
+            return False
+
+    def generate_embedding(self, data: str, **kwargs) -> list[float]:
+        # opensearch doesn't need to generate embeddings
+        pass
+
+
+# OpenSearch_VectorStore.__init__(self, config={'es_urls':
+# "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
+# "admin", 'es_password': "admin", 'es_verify_certs': True})
+
+
+# OpenSearch_VectorStore.__init__(self, config={'es_host':
+# "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin",
+# 'es_password': "admin", 'es_verify_certs': True})
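Despite the `_VectorStore` name, retrieval in this class is plain BM25 `match` search over three indices; `generate_embedding` is a no-op and no k-NN mapping is created. A usage sketch with illustrative assumptions (a local cluster on the defaults, Ollama as the LLM mixin); note that `create_index()` must be called explicitly because the call in `__init__` is commented out.

# Hypothetical wiring sketch -- not from this diff. Assumes a local OpenSearch
# cluster on localhost:9200 and the Ollama mixin from the section above.
from vanna.ollama import Ollama
from vanna.opensearch import OpenSearch_VectorStore

class MyVanna(OpenSearch_VectorStore, Ollama):
    def __init__(self, config=None):
        OpenSearch_VectorStore.__init__(self, config=config)
        Ollama.__init__(self, config=config)

vn = MyVanna(config={"model": "llama3", "es_host": "localhost", "es_port": 9200})
vn.create_index()  # indices are not created automatically

ddl_id = vn.add_ddl("CREATE TABLE sales (id INT, amount DECIMAL(10, 2), sold_at DATE)")
print(vn.get_related_ddl("total sales by month"))  # BM25 match on the "ddl" field
vn.remove_training_data(ddl_id)                    # routed by the "-ddl" id suffix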
vanna/vllm/__init__.py
ADDED
@@ -0,0 +1 @@
+from .vllm import Vllm
vanna/vllm/vllm.py
ADDED
@@ -0,0 +1,76 @@
+import re
+
+import requests
+
+from ..base import VannaBase
+
+
+class Vllm(VannaBase):
+    def __init__(self, config=None):
+        if config is None or "vllm_host" not in config:
+            self.host = "http://localhost:8000"
+        else:
+            self.host = config["vllm_host"]
+
+        if config is None or "model" not in config:
+            raise ValueError("check the config for vllm")
+        else:
+            self.model = config["model"]
+
+    def system_message(self, message: str) -> any:
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        return {"role": "assistant", "content": message}
+
+    def extract_sql_query(self, text):
+        """
+        Extracts the first SQL statement after the word 'select', ignoring case,
+        matches until the first semicolon, three backticks, or the end of the string,
+        and removes three backticks if they exist in the extracted string.
+
+        Args:
+        - text (str): The string to search within for an SQL statement.
+
+        Returns:
+        - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
+        """
+        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
+        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
+
+        match = pattern.search(text)
+        if match:
+            # Remove three backticks from the matched string if they exist
+            return match.group(0).replace("```", "")
+        else:
+            return text
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        # Use the super generate_sql
+        sql = super().generate_sql(question, **kwargs)
+
+        # Replace "\_" with "_"
+        sql = sql.replace("\\_", "_")
+
+        sql = sql.replace("\\", "")
+
+        return self.extract_sql_query(sql)
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        url = f"{self.host}/v1/chat/completions"
+        data = {
+            "model": self.model,
+            "stream": False,
+            "messages": prompt,
+        }
+
+        response = requests.post(url, json=data)
+
+        response_dict = response.json()
+
+        self.log(response.text)
+
+        return response_dict['choices'][0]['message']['content']
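Vllm posts the chat messages to a vLLM server's OpenAI-compatible `/v1/chat/completions` endpoint. A usage sketch in the same mixin style as the sections above; the vector store, served model, and server launch command are assumptions, not part of the diff.

# Hypothetical wiring sketch -- not from this diff. Assumes a vLLM server
# exposing the OpenAI-compatible API was started separately, e.g.:
#   python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore  # assumed available
from vanna.vllm import Vllm

class MyVanna(ChromaDB_VectorStore, Vllm):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Vllm.__init__(self, config=config)

vn = MyVanna(config={
    "model": "mistralai/Mistral-7B-Instruct-v0.2",  # required; must match the served model
    "vllm_host": "http://localhost:8000",           # default shown
})
print(vn.generate_sql("How many orders were placed last month?"))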