MindsDB 25.1.4.0__py3-none-any.whl → 25.1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic.
- {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/METADATA +235 -246
- {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/RECORD +44 -42
- mindsdb/__about__.py +1 -1
- mindsdb/api/executor/datahub/datanodes/datanode.py +1 -1
- mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
- mindsdb/api/executor/datahub/datanodes/integration_datanode.py +1 -1
- mindsdb/api/executor/datahub/datanodes/project_datanode.py +2 -26
- mindsdb/api/http/namespaces/agents.py +3 -1
- mindsdb/api/http/namespaces/knowledge_bases.py +4 -1
- mindsdb/integrations/handlers/databricks_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/file_handler/requirements.txt +0 -4
- mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_handler.py +1 -1
- mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_tables.py +8 -0
- mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +4 -2
- mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +5 -3
- mindsdb/integrations/handlers/snowflake_handler/requirements.txt +1 -1
- mindsdb/integrations/handlers/web_handler/requirements.txt +0 -1
- mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -1
- mindsdb/integrations/libs/vectordatabase_handler.py +4 -3
- mindsdb/integrations/utilities/files/__init__.py +0 -0
- mindsdb/integrations/utilities/files/file_reader.py +258 -0
- mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py +2 -1
- mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py +8 -3
- mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py +5 -9
- mindsdb/integrations/utilities/rag/pipelines/rag.py +1 -3
- mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +97 -89
- mindsdb/integrations/utilities/rag/settings.py +29 -14
- mindsdb/interfaces/agents/agents_controller.py +15 -3
- mindsdb/interfaces/agents/constants.py +1 -0
- mindsdb/interfaces/agents/langchain_agent.py +15 -10
- mindsdb/interfaces/agents/langfuse_callback_handler.py +4 -0
- mindsdb/interfaces/agents/mindsdb_database_agent.py +14 -0
- mindsdb/interfaces/database/integrations.py +5 -1
- mindsdb/interfaces/database/projects.py +38 -1
- mindsdb/interfaces/knowledge_base/controller.py +26 -11
- mindsdb/interfaces/knowledge_base/preprocessing/document_loader.py +7 -26
- mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +18 -10
- mindsdb/interfaces/skills/skill_tool.py +12 -6
- mindsdb/interfaces/skills/skills_controller.py +7 -3
- mindsdb/interfaces/skills/sql_agent.py +81 -18
- mindsdb/utilities/langfuse.py +15 -0
- {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/LICENSE +0 -0
- {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/WHEEL +0 -0
- {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/top_level.txt +0 -0

mindsdb/integrations/utilities/files/file_reader.py
ADDED

@@ -0,0 +1,258 @@
+import traceback
+import json
+import csv
+from io import BytesIO, StringIO
+from pathlib import Path
+import codecs
+
+import filetype
+import pandas as pd
+from charset_normalizer import from_bytes
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+from mindsdb.utilities import log
+
+logger = log.getLogger(__name__)
+
+DEFAULT_CHUNK_SIZE = 500
+DEFAULT_CHUNK_OVERLAP = 250
+
+
+class FileDetectError(Exception):
+    ...
+
+
+def decode(file_obj: BytesIO) -> StringIO:
+    byte_str = file_obj.read()
+    # Move it to StringIO
+    try:
+        # Handle Microsoft's BOM "special" UTF-8 encoding
+        if byte_str.startswith(codecs.BOM_UTF8):
+            data_str = StringIO(byte_str.decode("utf-8-sig"))
+        else:
+            file_encoding_meta = from_bytes(
+                byte_str[: 32 * 1024],
+                steps=32,  # Number of steps/block to extract from my_byte_str
+                chunk_size=1024,  # Set block size of each extraction)
+                explain=False,
+            )
+            best_meta = file_encoding_meta.best()
+            errors = "strict"
+            if best_meta is not None:
+                encoding = file_encoding_meta.best().encoding
+
+                try:
+                    data_str = StringIO(byte_str.decode(encoding, errors))
+                except UnicodeDecodeError:
+                    encoding = "utf-8"
+                    errors = "replace"
+
+                    data_str = StringIO(byte_str.decode(encoding, errors))
+            else:
+                encoding = "utf-8"
+                errors = "replace"
+
+                data_str = StringIO(byte_str.decode(encoding, errors))
+    except Exception as e:
+        logger.error(traceback.format_exc())
+        raise FileDetectError("Could not load into string") from e
+
+    return data_str
+
+
+class FormatDetector:
+
+    def get(self, name, file_obj: BytesIO = None):
+        format = self.get_format_by_name(name)
+        if format is None and file_obj is not None:
+            format = self.get_format_by_content(file_obj)
+
+        if format is not None:
+            return format
+        raise FileDetectError(f'Unable to detect format: {name}')
+
+    def get_format_by_name(self, filename):
+        extension = Path(filename).suffix.strip(".").lower()
+        if extension == "tsv":
+            extension = "csv"
+        return extension or None
+
+    def get_format_by_content(self, file_obj):
+        if self.is_parquet(file_obj):
+            return "parquet"
+
+        file_type = filetype.guess(file_obj)
+        if file_type is None:
+            return
+
+        if file_type.mime in {
+            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+            "application/vnd.ms-excel",
+        }:
+            return 'xlsx'
+
+        if file_type.mime == 'application/pdf':
+            return "pdf"
+
+        file_obj = decode(file_obj)
+
+        if self.is_json(file_obj):
+            return "json"
+
+        if self.is_csv(file_obj):
+            return "csv"
+
+    def is_json(self, data_obj: StringIO) -> bool:
+        # see if its JSON
+        text = data_obj.read(100).strip()
+        data_obj.seek(0)
+        if len(text) > 0:
+            # it looks like a json, then try to parse it
+            if text.startswith("{") or text.startswith("["):
+                try:
+                    json.loads(data_obj.read())
+                    return True
+                except Exception:
+                    return False
+                finally:
+                    data_obj.seek(0)
+        return False
+
+    def is_csv(self, data_obj: StringIO) -> bool:
+        sample = data_obj.readline()  # trying to get dialect from header
+        data_obj.seek(0)
+        try:
+            csv.Sniffer().sniff(sample)
+
+        except Exception:
+            return False
+
+    def is_parquet(self, data: BytesIO) -> bool:
+        # Check first and last 4 bytes equal to PAR1.
+        # Refer: https://parquet.apache.org/docs/file-format/
+        parquet_sig = b"PAR1"
+        data.seek(0, 0)
+        start_meta = data.read(4)
+        data.seek(-4, 2)
+        end_meta = data.read()
+        data.seek(0)
+        if start_meta == parquet_sig and end_meta == parquet_sig:
+            return True
+        return False
+
+
+class FileReader:
+
+    def _get_csv_dialect(self, buffer) -> csv.Dialect:
+        sample = buffer.readline()  # trying to get dialect from header
+        buffer.seek(0)
+        try:
+            if isinstance(sample, bytes):
+                sample = sample.decode()
+            accepted_csv_delimiters = [",", "\t", ";"]
+            try:
+                dialect = csv.Sniffer().sniff(
+                    sample, delimiters=accepted_csv_delimiters
+                )
+                dialect.doublequote = (
+                    True  # assume that all csvs have " as string escape
+                )
+            except Exception:
+                dialect = csv.reader(sample).dialect
+                if dialect.delimiter not in accepted_csv_delimiters:
+                    raise Exception(
+                        f"CSV delimeter '{dialect.delimiter}' is not supported"
+                    )
+
+        except csv.Error:
+            dialect = None
+        return dialect
+
+    def read(self, format, file_obj: BytesIO, **kwargs) -> pd.DataFrame:
+        func = {
+            'parquet': self.read_parquet,
+            'csv': self.read_csv,
+            'xlsx': self.read_excel,
+            'pdf': self.read_pdf,
+            'json': self.read_json,
+            'txt': self.read_txt,
+        }
+
+        if format not in func:
+            raise FileDetectError(f'Unsupported format: {format}')
+        func = func[format]
+
+        return func(file_obj, **kwargs)
+
+    def read_csv(self, file_obj: BytesIO, **kwargs):
+        file_obj = decode(file_obj)
+        dialect = self._get_csv_dialect(file_obj)
+
+        return pd.read_csv(file_obj, sep=dialect.delimiter, index_col=False)
+
+    def read_txt(self, file_obj: BytesIO, **kwargs):
+        file_obj = decode(file_obj)
+
+        try:
+            from langchain_core.documents import Document
+        except ImportError:
+            raise ImportError(
+                "To import TXT document please install 'langchain-community':\n"
+                "    pip install langchain-community"
+            )
+        text = file_obj.read()
+
+        file_name = None
+        if hasattr(file_obj, "name"):
+            file_name = file_obj.name
+        metadata = {"source": file_name}
+        documents = [Document(page_content=text, metadata=metadata)]
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP
+        )
+
+        docs = text_splitter.split_documents(documents)
+        return pd.DataFrame(
+            [
+                {"content": doc.page_content, "metadata": doc.metadata}
+                for doc in docs
+            ]
+        )
+
+    def read_pdf(self, file_obj: BytesIO, **kwargs):
+        import fitz  # pymupdf
+
+        with fitz.open(stream=file_obj) as pdf:  # open pdf
+            text = chr(12).join([page.get_text() for page in pdf])
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP
+        )
+
+        split_text = text_splitter.split_text(text)
+
+        return pd.DataFrame(
+            {"content": split_text, "metadata": [{}] * len(split_text)}
+        )
+
+    def read_json(self, file_obj: BytesIO, **kwargs):
+        file_obj = decode(file_obj)
+        file_obj.seek(0)
+        json_doc = json.loads(file_obj.read())
+        return pd.json_normalize(json_doc, max_level=0)
+
+    def read_parquet(self, file_obj: BytesIO, **kwargs):
+        return pd.read_parquet(file_obj)
+
+    def read_excel(self, file_obj: BytesIO, sheet_name=None, **kwargs) -> pd.DataFrame:

+        file_obj.seek(0)
+        with pd.ExcelFile(file_obj) as xls:
+            if sheet_name is None:
+                # No sheet specified: Return list of sheets
+                sheet_list = xls.sheet_names
+                return pd.DataFrame(sheet_list, columns=["Sheet_Name"])
+            else:
+                # Specific sheet requested: Load that sheet
+                return pd.read_excel(xls, sheet_name=sheet_name)
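
For orientation, a minimal usage sketch of the two classes added above, detecting a file's format and reading it into a DataFrame. The file name is a hypothetical example, not something from this release; the classes are used exactly as defined in the hunk:

from io import BytesIO
from pathlib import Path

from mindsdb.integrations.utilities.files.file_reader import FileReader, FormatDetector

# Hypothetical local file; CSV, TSV, JSON, XLSX, PDF, TXT and parquet follow the same path.
path = Path("example.csv")
file_obj = BytesIO(path.read_bytes())

# Format is taken from the extension first, with content sniffing as a fallback.
fmt = FormatDetector().get(path.name, file_obj)

# read() dispatches to the matching read_* method and returns a pandas DataFrame.
df = FileReader().read(fmt, file_obj)
print(fmt, df.shape)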

mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py
CHANGED

@@ -131,7 +131,8 @@ class MSGraphAPIBaseClient:
         response = self._make_request(api_url, params)

         # If the response content is a binary file or a TSV file, return the raw content.
-        if response.headers["Content-Type"] in ("application/octet-stream", "text/tab-separated-values"):
+        if response.headers["Content-Type"] in ("application/octet-stream", "text/plain",
+                                                "text/tab-separated-values", "application/pdf"):
             return response.content
         # Otherwise, return the JSON content.
         else:
mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py
CHANGED

@@ -43,9 +43,14 @@ class MSGraphAPIDelegatedPermissionsManager:
         # Set the redirect URI based on the request origin.
         # If the request origin is 127.0.0.1 (localhost), replace it with localhost.
         # This is done because the only HTTP origin allowed in Microsoft Entra ID app registration is localhost.
-
-
-
+        try:
+            request_origin = request.headers.get('ORIGIN') or (request.scheme + '://' + request.host)
+            if not request_origin:
+                raise AuthException('Request origin could not be determined!')
+        except RuntimeError:
+            # if it is outside of request context (streaming in agent)
+            request_origin = ''
+
         request_origin = request_origin.replace('127.0.0.1', 'localhost') if 'http://127.0.0.1' in request_origin else request_origin
         self.redirect_uri = request_origin + '/verify-auth'


mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py
CHANGED

@@ -23,7 +23,7 @@ logger = log.getLogger(__name__)
 Summary = namedtuple('Summary', ['source_id', 'content'])


-def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> MapReduceDocumentsChain:
+def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
     '''Creats a chain that map reduces documents into a single consolidated summary

     Args:

@@ -43,7 +43,7 @@ def create_map_reduce_documents_chain(summarization_config: SummarizationConfig,
     if 'input' in map_prompt.input_variables:
         map_prompt = map_prompt.partial(input=input)
     # Handles summarization of individual chunks.
-    map_chain = LLMChain(llm=summarization_llm, prompt=map_prompt)
+    # map_chain = LLMChain(llm=summarization_llm, prompt=map_prompt)

     reduce_prompt_template = summarization_config.reduce_prompt_template
     reduce_prompt = PromptTemplate.from_template(reduce_prompt_template)

@@ -60,18 +60,12 @@ def create_map_reduce_documents_chain(summarization_config: SummarizationConfig,
     )

     # Combines & iteratively reduces mapped documents.
-    reduce_documents_chain = ReduceDocumentsChain(
+    return ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=combine_documents_chain,
         # Max number of tokens to group documents into.
         token_max=summarization_config.max_summarization_tokens
     )
-    return MapReduceDocumentsChain(
-        llm_chain=map_chain,
-        reduce_documents_chain=reduce_documents_chain,
-        document_variable_name='docs',
-        return_intermediate_steps=False
-    )


 class MapReduceSummarizerChain(Chain):

@@ -135,6 +129,8 @@ class MapReduceSummarizerChain(Chain):
         document_chunks = []
         for _, row in all_source_chunks.iterrows():
             metadata = row.get(self.metadata_column_name, {})
+            if row.get('chunk_id', None) is not None:
+                metadata['chunk_index'] = row.get('chunk_id', 0)
             document_chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))
         # Sort by chunk index if present in metadata so the full document is in its original order.
         document_chunks.sort(key=lambda doc: doc.metadata.get('chunk_index', 0) if doc.metadata else 0)

mindsdb/integrations/utilities/rag/pipelines/rag.py
CHANGED

@@ -298,10 +298,8 @@
             examples=retriever_config.examples,
             embeddings_model=embeddings,
             rewrite_prompt_template=retriever_config.rewrite_prompt_template,
-
+            metadata_filters_prompt_template=retriever_config.metadata_filters_prompt_template,
             num_retries=retriever_config.num_retries,
-            sql_prompt_template=retriever_config.sql_prompt_template,
-            query_checker_template=retriever_config.query_checker_template,
             embeddings_table=knowledge_base_table._kb.vector_database_table,
             source_table=retriever_config.source_table,
             distance_function=distance_function,

mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py
CHANGED

@@ -1,15 +1,20 @@
 import json
-
+import re
+from pydantic import BaseModel, Field
+from typing import Any, List, Optional

 from langchain.chains.llm import LLMChain
 from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain_core.documents.base import Document
 from langchain_core.embeddings import Embeddings
+from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.prompts import PromptTemplate
 from langchain_core.retrievers import BaseRetriever

 from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
+from mindsdb.integrations.libs.response import HandlerResponse
 from mindsdb.integrations.libs.vectordatabase_handler import DistanceFunction, VectorStoreHandler
 from mindsdb.integrations.utilities.rag.settings import LLMExample, MetadataSchema, SearchKwargs
 from mindsdb.utilities import log

@@ -17,6 +22,18 @@ from mindsdb.utilities import log
 logger = log.getLogger(__name__)


+class MetadataFilter(BaseModel):
+    '''Represents an LLM generated metadata filter to apply to a PostgreSQL query.'''
+    attribute: str = Field(description="Database column to apply filter to")
+    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
+    value: Any = Field(description="Value to use to filter database column")
+
+
+class MetadataFilters(BaseModel):
+    '''List of LLM generated metadata filters to apply to a PostgreSQL query.'''
+    filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
+
+
 class SQLRetriever(BaseRetriever):
     '''Retriever that uses a LLM to generate pgvector queries to do similarity search with metadata filters.

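
As a hedged illustration of how these new models are meant to be consumed (not part of the diff itself), the sketch below redeclares them locally and shows PydanticOutputParser turning a hypothetical LLM reply into MetadataFilter objects; it assumes pydantic and langchain-core are installed:

from typing import Any, List

from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field


class MetadataFilter(BaseModel):
    attribute: str = Field(description="Database column to apply filter to")
    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
    value: Any = Field(description="Value to use to filter database column")


class MetadataFilters(BaseModel):
    filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")


parser = PydanticOutputParser(pydantic_object=MetadataFilters)

# These format instructions are what the retriever injects into its metadata-filters prompt.
print(parser.get_format_instructions())

# Hypothetical LLM reply: raw JSON matching the MetadataFilters schema.
llm_output = '{"filters": [{"attribute": "year", "comparator": ">=", "value": 2023}]}'
filters = parser.parse(llm_output).filters
print(filters[0].attribute, filters[0].comparator, filters[0].value)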

@@ -25,10 +42,10 @@ class SQLRetriever(BaseRetriever):
     1. Use a LLM to rewrite the user input to something more suitable for retrieval. For example:
     "Show me documents containing how to finetune a LLM please" --> "how to finetune a LLM"

-    2. Use a LLM to generate
-    metadata schemas & examples are used as additional context
+    2. Use a LLM to generate structured metadata filters based on the user input. Provided
+    metadata schemas & examples are used as additional context.

-    3.
+    3. Generate a prepared PostgreSQL query from the structured metadata filters.

     4. Actually execute the query against our vector database to retrieve documents & return them.
     '''
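
The docstring above outlines the new flow: rewrite the question, ask the LLM for structured metadata filters, turn those filters into a prepared pgvector query, then execute it. As a rough, standalone sketch of the query-building step only (table, column and operator names below are placeholders, not values taken from this release), this mirrors the string-building done in _prepare_pgvector_query further down:

from typing import Any, List

from pydantic import BaseModel, Field


class MetadataFilter(BaseModel):
    attribute: str = Field(description="Database column to apply filter to")
    comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
    value: Any = Field(description="Value to use to filter database column")


# Placeholder configuration; in the retriever these come from its fields.
embeddings_table = "embeddings"
source_table = "documents"
source_id_column = "Id"
distance_operator = "<->"  # one of pgvector's distance operators
k = 5

filters: List[MetadataFilter] = [MetadataFilter(attribute="year", comparator=">=", value=2023)]

# Base SELECT joined with the document source table.
query = (
    f"SELECT * FROM {embeddings_table} AS e "
    f"INNER JOIN {source_table} AS s "
    f"ON (e.metadata->>'original_row_id')::int = s.\"{source_id_column}\" "
)

# WHERE conditions built from the structured metadata filters.
conditions = []
for f in filters:
    value = f"'{f.value}'" if isinstance(f.value, str) else f.value
    conditions.append(f'"{f.attribute}" {f.comparator} {value}')
query += "WHERE " + " AND ".join(conditions)

# Similarity ordering; '{embeddings}' is substituted later with the embedded query vector.
query += f" ORDER BY e.embeddings {distance_operator} '{{embeddings}}' LIMIT {k};"
print(query)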

@@ -37,23 +54,22 @@
     metadata_schemas: Optional[List[MetadataSchema]] = None
     examples: Optional[List[LLMExample]] = None

-    embeddings_model: Embeddings
     rewrite_prompt_template: str
-
+    metadata_filters_prompt_template: str
+    embeddings_model: Embeddings
     num_retries: int
-    sql_prompt_template: str
-    query_checker_template: str
     embeddings_table: str
     source_table: str
+    source_id_column: str = 'Id'
     distance_function: DistanceFunction
     search_kwargs: SearchKwargs

     llm: BaseChatModel

-    def
+    def _prepare_metadata_prompt(self) -> PromptTemplate:
         base_prompt_template = PromptTemplate(
-            input_variables=['
-            template=self.
+            input_variables=['format_instructions', 'schema', 'examples', 'input', 'embeddings'],
+            template=self.metadata_filters_prompt_template
         )
         schema_prompt_str = ''
         if self.metadata_schemas is not None:

@@ -67,7 +83,7 @@ class SQLRetriever(BaseRetriever):
                 if column.values is not None:
                     column_mapping[column.name]['values'] = column.values
             column_mapping_json_str = json.dumps(column_mapping, indent=4)
-            schema_str = f'''{i+
+            schema_str = f'''{i+1}. {schema.table} - {schema.description}

 Columns:
 ```json

@@ -86,7 +102,7 @@ Output:
 {example.output}

 '''
-
+            examples_prompt_str += example_str
         return base_prompt_template.partial(
             schema=schema_prompt_str,
             examples=examples_prompt_str

@@ -100,97 +116,89 @@ Output:
         rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
         return rewrite_chain.predict(input=query)

-    def _prepare_pgvector_query(self,
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-            query=query,
-            dialect='postgres',
-            error=error,
-            embeddings_table=self.embeddings_table,
-            schema=schema,
-            callbacks=run_manager.get_child() if run_manager else None
-        )
-        query_checker_prompt = PromptTemplate(
-            input_variables=['dialect', 'query'],
-            template=self.query_checker_template
-        )
-        query_checker_chain = LLMChain(llm=self.llm, prompt=query_checker_prompt)
-        # Check the query & return the final result to be executed.
-        return query_checker_chain.predict(
-            dialect='postgres',
-            query=sql_query
+    def _prepare_pgvector_query(self, metadata_filters: List[MetadataFilter]) -> str:
+        # Base select JOINed with document source table.
+        base_query = f'''SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" '''
+        col_to_schema = {}
+        if not self.metadata_schemas:
+            return ''
+        for schema in self.metadata_schemas:
+            for col in schema.columns:
+                col_to_schema[col.name] = schema
+        joined_schemas = set()
+        for filter in metadata_filters:
+            # Join schemas before filtering.
+            schema = col_to_schema.get(filter.attribute)
+            if schema is None or schema.table in joined_schemas or schema.table == self.source_table:
+                continue
+            joined_schemas.add(schema.table)
+            base_query += schema.join + ' '
+        # Actually construct WHERE conditions from metadata filters.
+        if metadata_filters:
+            base_query += 'WHERE '
+        for i, filter in enumerate(metadata_filters):
+            value = filter.value
+            if isinstance(value, str):
+                value = f"'{value}'"
+            base_query += f'"{filter.attribute}" {filter.comparator} {value}'
+            if i < len(metadata_filters) - 1:
+                base_query += ' AND '
+        base_query += f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
+        return base_query
+
+    def _generate_metadata_filters(self, query: str) -> List[MetadataFilter]:
+        parser = PydanticOutputParser(pydantic_object=MetadataFilters)
+        metadata_prompt = self._prepare_metadata_prompt()
+        metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
+        metadata_filters_output = metadata_filters_chain.predict(
+            format_instructions=parser.get_format_instructions(),
+            input=query
         )
+        # If the LLM outputs raw JSON, use it as-is.
+        # If the LLM outputs anything including a json markdown section, use the last one.
+        json_markdown_output = re.findall(r'```json.*```', metadata_filters_output, re.DOTALL)
+        if json_markdown_output:
+            metadata_filters_output = json_markdown_output[-1]
+            # Clean the json tags.
+            metadata_filters_output = metadata_filters_output[7:]
+            metadata_filters_output = metadata_filters_output[:-3]
+        metadata_filters = parser.invoke(metadata_filters_output)
+        return metadata_filters.filters
+
+    def _prepare_and_execute_query(self, query: str, embeddings_str: str) -> HandlerResponse:
+        try:
+            metadata_filters = self._generate_metadata_filters(query)
+            checked_sql_query = self._prepare_pgvector_query(metadata_filters)
+            checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
+            return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
+        except OutputParserException as e:
+            logger.warning(f'LLM failed to generate structured metadata filters: {str(e)}')
+            return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
+        except Exception as e:
+            logger.warning(f'Failed to prepare and execute SQL query from structured metadata: {str(e)}')
+            return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))

     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
         # Rewrite query to be suitable for retrieval.
         retrieval_query = self._prepare_retrieval_query(query)
-
-        # Generate & check the query to be executed
-        checked_sql_query = self._prepare_pgvector_query(query, run_manager)
-
         # Embed the rewritten retrieval query & include it in the similarity search pgvector query.
         embedded_query = self.embeddings_model.embed_query(retrieval_query)
-        checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=str(embedded_query))
-        # Handle LLM output that has the ```sql delimiter possibly.
-        checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```sql', '')
-        checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```', '')
         # Actually execute the similarity search with metadata filters.
-        document_response = self.
+        document_response = self._prepare_and_execute_query(retrieval_query, str(embedded_query))
         num_retries = 0
         while num_retries < self.num_retries:
+            if document_response.resp_type != RESPONSE_TYPE.ERROR and len(document_response.data_frame) > 0:
+                # Successfully retrieved documents.
+                break
             if document_response.resp_type == RESPONSE_TYPE.ERROR:
-
-
-                logger.info(f'SQL Retriever query {checked_sql_query} failed with error {error_msg}')
-                checked_sql_query = self._prepare_retry_query(checked_sql_query, error_msg, run_manager)
+                # LLMs won't always generate structured metadata so we should have a fallback after retrying.
+                logger.info(f'SQL Retriever query failed with error {document_response.error_message}')
             elif len(document_response.data_frame) == 0:
-
-                checked_sql_query = self._prepare_retry_query(checked_sql_query, error_msg, run_manager)
-            else:
-                break
-
-            checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=str(embedded_query))
-            # Handle LLM output that has the ```sql delimiter possibly.
-            checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```sql', '')
-            checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```', '')
-            document_response = self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
+                logger.info('No documents retrieved from SQL Retriever query')

+            document_response = self._prepare_and_execute_query(retrieval_query, str(embedded_query))
             num_retries += 1
             if num_retries >= self.num_retries:
                 logger.info('Using fallback retriever in SQL retriever.')
|