MindsDB-25.1.4.0-py3-none-any.whl → MindsDB-25.1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of MindsDB might be problematic.

Files changed (44)
  1. {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/METADATA +235 -246
  2. {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/RECORD +44 -42
  3. mindsdb/__about__.py +1 -1
  4. mindsdb/api/executor/datahub/datanodes/datanode.py +1 -1
  5. mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +1 -1
  6. mindsdb/api/executor/datahub/datanodes/integration_datanode.py +1 -1
  7. mindsdb/api/executor/datahub/datanodes/project_datanode.py +2 -26
  8. mindsdb/api/http/namespaces/agents.py +3 -1
  9. mindsdb/api/http/namespaces/knowledge_bases.py +4 -1
  10. mindsdb/integrations/handlers/databricks_handler/requirements.txt +1 -1
  11. mindsdb/integrations/handlers/file_handler/requirements.txt +0 -4
  12. mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_handler.py +1 -1
  13. mindsdb/integrations/handlers/ms_one_drive_handler/ms_one_drive_tables.py +8 -0
  14. mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +4 -2
  15. mindsdb/integrations/handlers/ray_serve_handler/ray_serve_handler.py +5 -3
  16. mindsdb/integrations/handlers/snowflake_handler/requirements.txt +1 -1
  17. mindsdb/integrations/handlers/web_handler/requirements.txt +0 -1
  18. mindsdb/integrations/libs/ml_handler_process/learn_process.py +1 -1
  19. mindsdb/integrations/libs/vectordatabase_handler.py +4 -3
  20. mindsdb/integrations/utilities/files/__init__.py +0 -0
  21. mindsdb/integrations/utilities/files/file_reader.py +258 -0
  22. mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py +2 -1
  23. mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py +8 -3
  24. mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py +5 -9
  25. mindsdb/integrations/utilities/rag/pipelines/rag.py +1 -3
  26. mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +97 -89
  27. mindsdb/integrations/utilities/rag/settings.py +29 -14
  28. mindsdb/interfaces/agents/agents_controller.py +15 -3
  29. mindsdb/interfaces/agents/constants.py +1 -0
  30. mindsdb/interfaces/agents/langchain_agent.py +15 -10
  31. mindsdb/interfaces/agents/langfuse_callback_handler.py +4 -0
  32. mindsdb/interfaces/agents/mindsdb_database_agent.py +14 -0
  33. mindsdb/interfaces/database/integrations.py +5 -1
  34. mindsdb/interfaces/database/projects.py +38 -1
  35. mindsdb/interfaces/knowledge_base/controller.py +26 -11
  36. mindsdb/interfaces/knowledge_base/preprocessing/document_loader.py +7 -26
  37. mindsdb/interfaces/skills/custom/text2sql/mindsdb_sql_toolkit.py +18 -10
  38. mindsdb/interfaces/skills/skill_tool.py +12 -6
  39. mindsdb/interfaces/skills/skills_controller.py +7 -3
  40. mindsdb/interfaces/skills/sql_agent.py +81 -18
  41. mindsdb/utilities/langfuse.py +15 -0
  42. {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/LICENSE +0 -0
  43. {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/WHEEL +0 -0
  44. {MindsDB-25.1.4.0.dist-info → MindsDB-25.1.5.1.dist-info}/top_level.txt +0 -0
mindsdb/integrations/utilities/files/file_reader.py

@@ -0,0 +1,258 @@
+ import traceback
+ import json
+ import csv
+ from io import BytesIO, StringIO
+ from pathlib import Path
+ import codecs
+
+ import filetype
+ import pandas as pd
+ from charset_normalizer import from_bytes
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+ from mindsdb.utilities import log
+
+ logger = log.getLogger(__name__)
+
+ DEFAULT_CHUNK_SIZE = 500
+ DEFAULT_CHUNK_OVERLAP = 250
+
+
+ class FileDetectError(Exception):
+     ...
+
+
+ def decode(file_obj: BytesIO) -> StringIO:
+     byte_str = file_obj.read()
+     # Move it to StringIO
+     try:
+         # Handle Microsoft's BOM "special" UTF-8 encoding
+         if byte_str.startswith(codecs.BOM_UTF8):
+             data_str = StringIO(byte_str.decode("utf-8-sig"))
+         else:
+             file_encoding_meta = from_bytes(
+                 byte_str[: 32 * 1024],
+                 steps=32,  # Number of steps/blocks to extract from byte_str
+                 chunk_size=1024,  # Block size of each extraction
+                 explain=False,
+             )
+             best_meta = file_encoding_meta.best()
+             errors = "strict"
+             if best_meta is not None:
+                 encoding = best_meta.encoding
+
+                 try:
+                     data_str = StringIO(byte_str.decode(encoding, errors))
+                 except UnicodeDecodeError:
+                     encoding = "utf-8"
+                     errors = "replace"
+
+                     data_str = StringIO(byte_str.decode(encoding, errors))
+             else:
+                 encoding = "utf-8"
+                 errors = "replace"
+
+                 data_str = StringIO(byte_str.decode(encoding, errors))
+     except Exception as e:
+         logger.error(traceback.format_exc())
+         raise FileDetectError("Could not load into string") from e
+
+     return data_str
+
+
+ class FormatDetector:
+
+     def get(self, name, file_obj: BytesIO = None):
+         format = self.get_format_by_name(name)
+         if format is None and file_obj is not None:
+             format = self.get_format_by_content(file_obj)
+
+         if format is not None:
+             return format
+         raise FileDetectError(f'Unable to detect format: {name}')
+
+     def get_format_by_name(self, filename):
+         extension = Path(filename).suffix.strip(".").lower()
+         if extension == "tsv":
+             extension = "csv"
+         return extension or None
+
+     def get_format_by_content(self, file_obj):
+         if self.is_parquet(file_obj):
+             return "parquet"
+
+         file_type = filetype.guess(file_obj)
+         if file_type is None:
+             return
+
+         if file_type.mime in {
+             "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+             "application/vnd.ms-excel",
+         }:
+             return 'xlsx'
+
+         if file_type.mime == 'application/pdf':
+             return "pdf"
+
+         file_obj = decode(file_obj)
+
+         if self.is_json(file_obj):
+             return "json"
+
+         if self.is_csv(file_obj):
+             return "csv"
+
+     def is_json(self, data_obj: StringIO) -> bool:
+         # See if it's JSON.
+         text = data_obj.read(100).strip()
+         data_obj.seek(0)
+         if len(text) > 0:
+             # It looks like JSON, so try to parse it.
+             if text.startswith("{") or text.startswith("["):
+                 try:
+                     json.loads(data_obj.read())
+                     return True
+                 except Exception:
+                     return False
+                 finally:
+                     data_obj.seek(0)
+         return False
+
+     def is_csv(self, data_obj: StringIO) -> bool:
+         sample = data_obj.readline()  # trying to get dialect from header
+         data_obj.seek(0)
+         try:
+             csv.Sniffer().sniff(sample)
+             return True
+         except Exception:
+             return False
+
+     def is_parquet(self, data: BytesIO) -> bool:
+         # Check that the first and last 4 bytes equal "PAR1".
+         # Refer: https://parquet.apache.org/docs/file-format/
+         parquet_sig = b"PAR1"
+         data.seek(0, 0)
+         start_meta = data.read(4)
+         data.seek(-4, 2)
+         end_meta = data.read()
+         data.seek(0)
+         if start_meta == parquet_sig and end_meta == parquet_sig:
+             return True
+         return False
+
+
+ class FileReader:
+
+     def _get_csv_dialect(self, buffer) -> csv.Dialect:
+         sample = buffer.readline()  # trying to get dialect from header
+         buffer.seek(0)
+         try:
+             if isinstance(sample, bytes):
+                 sample = sample.decode()
+             accepted_csv_delimiters = [",", "\t", ";"]
+             try:
+                 dialect = csv.Sniffer().sniff(
+                     sample, delimiters=accepted_csv_delimiters
+                 )
+                 dialect.doublequote = (
+                     True  # assume that all csvs have " as string escape
+                 )
+             except Exception:
+                 dialect = csv.reader(sample).dialect
+                 if dialect.delimiter not in accepted_csv_delimiters:
+                     raise Exception(
+                         f"CSV delimiter '{dialect.delimiter}' is not supported"
+                     )
+
+         except csv.Error:
+             dialect = None
+         return dialect
+
+     def read(self, format, file_obj: BytesIO, **kwargs) -> pd.DataFrame:
+         func = {
+             'parquet': self.read_parquet,
+             'csv': self.read_csv,
+             'xlsx': self.read_excel,
+             'pdf': self.read_pdf,
+             'json': self.read_json,
+             'txt': self.read_txt,
+         }
+
+         if format not in func:
+             raise FileDetectError(f'Unsupported format: {format}')
+         func = func[format]
+
+         return func(file_obj, **kwargs)
+
+     def read_csv(self, file_obj: BytesIO, **kwargs):
+         file_obj = decode(file_obj)
+         dialect = self._get_csv_dialect(file_obj)
+         # _get_csv_dialect may return None; fall back to a comma in that case.
+         sep = dialect.delimiter if dialect is not None else ","
+
+         return pd.read_csv(file_obj, sep=sep, index_col=False)
+
+     def read_txt(self, file_obj: BytesIO, **kwargs):
+         file_obj = decode(file_obj)
+
+         try:
+             from langchain_core.documents import Document
+         except ImportError:
+             raise ImportError(
+                 "To import TXT documents please install 'langchain-community':\n"
+                 "    pip install langchain-community"
+             )
+         text = file_obj.read()
+
+         file_name = None
+         if hasattr(file_obj, "name"):
+             file_name = file_obj.name
+         metadata = {"source": file_name}
+         documents = [Document(page_content=text, metadata=metadata)]
+
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP
+         )
+
+         docs = text_splitter.split_documents(documents)
+         return pd.DataFrame(
+             [
+                 {"content": doc.page_content, "metadata": doc.metadata}
+                 for doc in docs
+             ]
+         )
+
+     def read_pdf(self, file_obj: BytesIO, **kwargs):
+         import fitz  # pymupdf
+
+         with fitz.open(stream=file_obj) as pdf:  # open pdf
+             text = chr(12).join([page.get_text() for page in pdf])
+
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP
+         )
+
+         split_text = text_splitter.split_text(text)
+
+         return pd.DataFrame(
+             {"content": split_text, "metadata": [{}] * len(split_text)}
+         )
+
+     def read_json(self, file_obj: BytesIO, **kwargs):
+         file_obj = decode(file_obj)
+         file_obj.seek(0)
+         json_doc = json.loads(file_obj.read())
+         return pd.json_normalize(json_doc, max_level=0)
+
+     def read_parquet(self, file_obj: BytesIO, **kwargs):
+         return pd.read_parquet(file_obj)
+
+     def read_excel(self, file_obj: BytesIO, sheet_name=None, **kwargs) -> pd.DataFrame:
+         file_obj.seek(0)
+         with pd.ExcelFile(file_obj) as xls:
+             if sheet_name is None:
+                 # No sheet specified: return the list of sheets.
+                 sheet_list = xls.sheet_names
+                 return pd.DataFrame(sheet_list, columns=["Sheet_Name"])
+             else:
+                 # Specific sheet requested: load that sheet.
+                 return pd.read_excel(xls, sheet_name=sheet_name)
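
For orientation, the two new classes compose like this. A minimal sketch, assuming only what the diff above shows (`FormatDetector.get` falls back from the file extension to content sniffing, and `FileReader.read` dispatches on the detected format); `report.csv` is a made-up input file:

    from io import BytesIO

    from mindsdb.integrations.utilities.files.file_reader import FileReader, FormatDetector

    # Hypothetical input file; any bytes-like source works.
    with open('report.csv', 'rb') as f:
        file_obj = BytesIO(f.read())

    fmt = FormatDetector().get('report.csv', file_obj)  # extension first, content sniffing as fallback
    df = FileReader().read(fmt, file_obj)               # dispatches to read_csv/read_pdf/read_json/...
    print(df.head())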
mindsdb/integrations/utilities/handlers/api_utilities/microsoft/ms_graph_api_utilities.py

@@ -131,7 +131,8 @@ class MSGraphAPIBaseClient:
          response = self._make_request(api_url, params)

          # If the response content is a binary file or a TSV file, return the raw content.
-         if response.headers["Content-Type"] in ("application/octet-stream", "text/tab-separated-values"):
+         if response.headers["Content-Type"] in ("application/octet-stream", "text/plain",
+                                                 "text/tab-separated-values", "application/pdf"):
              return response.content
          # Otherwise, return the JSON content.
          else:
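
The updated branch is easiest to read as a standalone dispatch: raw bytes come back for the listed content types (now including plain text and PDF), parsed JSON otherwise. A toy restatement, with `response` standing in for any `requests`-style response object:

    RAW_CONTENT_TYPES = (
        "application/octet-stream",
        "text/plain",
        "text/tab-separated-values",
        "application/pdf",
    )

    def extract_body(response):
        # Binary-ish payloads are returned untouched; everything else is JSON.
        if response.headers["Content-Type"] in RAW_CONTENT_TYPES:
            return response.content
        return response.json()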
mindsdb/integrations/utilities/handlers/auth_utilities/microsoft/ms_graph_api_auth_utilities.py

@@ -43,9 +43,14 @@ class MSGraphAPIDelegatedPermissionsManager:
          # Set the redirect URI based on the request origin.
          # If the request origin is 127.0.0.1 (localhost), replace it with localhost.
          # This is done because the only HTTP origin allowed in Microsoft Entra ID app registration is localhost.
-         request_origin = request.headers.get('ORIGIN') or (request.scheme + '://' + request.host)
-         if not request_origin:
-             raise AuthException('Request origin could not be determined!')
+         try:
+             request_origin = request.headers.get('ORIGIN') or (request.scheme + '://' + request.host)
+             if not request_origin:
+                 raise AuthException('Request origin could not be determined!')
+         except RuntimeError:
+             # If we are outside of a request context (streaming in an agent).
+             request_origin = ''
+
          request_origin = request_origin.replace('127.0.0.1', 'localhost') if 'http://127.0.0.1' in request_origin else request_origin
          self.redirect_uri = request_origin + '/verify-auth'

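The new try/except exists because Flask's `request` proxy raises `RuntimeError` when touched outside an active request context, which is the situation during agent streaming. A minimal sketch of the failure mode (assuming a Flask app, which is what serves MindsDB's HTTP API):

    from flask import Flask, request

    app = Flask(__name__)

    try:
        # No request is being handled here, so the proxy is unbound.
        origin = request.headers.get('ORIGIN')
    except RuntimeError:
        # Flask raises "Working outside of request context."; the handler
        # now tolerates this and falls back to an empty origin.
        origin = ''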
mindsdb/integrations/utilities/rag/chains/map_reduce_summarizer_chain.py

@@ -23,7 +23,7 @@ logger = log.getLogger(__name__)
  Summary = namedtuple('Summary', ['source_id', 'content'])


- def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> MapReduceDocumentsChain:
+ def create_map_reduce_documents_chain(summarization_config: SummarizationConfig, input: str) -> ReduceDocumentsChain:
      '''Creates a chain that map reduces documents into a single consolidated summary

      Args:
@@ -43,7 +43,7 @@ def create_map_reduce_documents_chain(summarization_config: SummarizationConfig,
      if 'input' in map_prompt.input_variables:
          map_prompt = map_prompt.partial(input=input)
      # Handles summarization of individual chunks.
-     map_chain = LLMChain(llm=summarization_llm, prompt=map_prompt)
+     # map_chain = LLMChain(llm=summarization_llm, prompt=map_prompt)

      reduce_prompt_template = summarization_config.reduce_prompt_template
      reduce_prompt = PromptTemplate.from_template(reduce_prompt_template)
@@ -60,18 +60,12 @@ def create_map_reduce_documents_chain(summarization_config: SummarizationConfig,
      )

      # Combines & iteratively reduces mapped documents.
-     reduce_documents_chain = ReduceDocumentsChain(
+     return ReduceDocumentsChain(
          combine_documents_chain=combine_documents_chain,
          collapse_documents_chain=combine_documents_chain,
          # Max number of tokens to group documents into.
          token_max=summarization_config.max_summarization_tokens
      )
-     return MapReduceDocumentsChain(
-         llm_chain=map_chain,
-         reduce_documents_chain=reduce_documents_chain,
-         document_variable_name='docs',
-         return_intermediate_steps=False
-     )


  class MapReduceSummarizerChain(Chain):
@@ -135,6 +129,8 @@ class MapReduceSummarizerChain(Chain):
          document_chunks = []
          for _, row in all_source_chunks.iterrows():
              metadata = row.get(self.metadata_column_name, {})
+             if row.get('chunk_id', None) is not None:
+                 metadata['chunk_index'] = row.get('chunk_id', 0)
              document_chunks.append(Document(page_content=row[self.content_column_name], metadata=metadata))
          # Sort by chunk index if present in metadata so the full document is in its original order.
          document_chunks.sort(key=lambda doc: doc.metadata.get('chunk_index', 0) if doc.metadata else 0)
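
The two added lines matter because the sort that follows keys on `chunk_index`; copying `chunk_id` into each chunk's metadata is what makes that sort meaningful. A toy illustration of the resulting behavior:

    from langchain_core.documents import Document

    # Made-up chunks arriving out of order, as rows from a vector store might.
    chunks = [
        Document(page_content='world', metadata={'chunk_index': 1}),
        Document(page_content='hello', metadata={'chunk_index': 0}),
    ]
    chunks.sort(key=lambda doc: doc.metadata.get('chunk_index', 0) if doc.metadata else 0)
    print([c.page_content for c in chunks])  # ['hello', 'world']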
mindsdb/integrations/utilities/rag/pipelines/rag.py

@@ -298,10 +298,8 @@ class LangChainRAGPipeline:
              examples=retriever_config.examples,
              embeddings_model=embeddings,
              rewrite_prompt_template=retriever_config.rewrite_prompt_template,
-             retry_prompt_template=retriever_config.query_retry_template,
+             metadata_filters_prompt_template=retriever_config.metadata_filters_prompt_template,
              num_retries=retriever_config.num_retries,
-             sql_prompt_template=retriever_config.sql_prompt_template,
-             query_checker_template=retriever_config.query_checker_template,
              embeddings_table=knowledge_base_table._kb.vector_database_table,
              source_table=retriever_config.source_table,
              distance_function=distance_function,
mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py

@@ -1,15 +1,20 @@
  import json
- from typing import List, Optional
+ import re
+ from pydantic import BaseModel, Field
+ from typing import Any, List, Optional

  from langchain.chains.llm import LLMChain
  from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
  from langchain_core.documents.base import Document
  from langchain_core.embeddings import Embeddings
+ from langchain_core.exceptions import OutputParserException
  from langchain_core.language_models.chat_models import BaseChatModel
+ from langchain_core.output_parsers import PydanticOutputParser
  from langchain_core.prompts import PromptTemplate
  from langchain_core.retrievers import BaseRetriever

  from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
+ from mindsdb.integrations.libs.response import HandlerResponse
  from mindsdb.integrations.libs.vectordatabase_handler import DistanceFunction, VectorStoreHandler
  from mindsdb.integrations.utilities.rag.settings import LLMExample, MetadataSchema, SearchKwargs
  from mindsdb.utilities import log
@@ -17,6 +22,18 @@ from mindsdb.utilities import log
  logger = log.getLogger(__name__)


+ class MetadataFilter(BaseModel):
+     '''Represents an LLM generated metadata filter to apply to a PostgreSQL query.'''
+     attribute: str = Field(description="Database column to apply filter to")
+     comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
+     value: Any = Field(description="Value to use to filter database column")
+
+
+ class MetadataFilters(BaseModel):
+     '''List of LLM generated metadata filters to apply to a PostgreSQL query.'''
+     filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")
+
+
  class SQLRetriever(BaseRetriever):
      '''Retriever that uses a LLM to generate pgvector queries to do similarity search with metadata filters.

@@ -25,10 +42,10 @@ class SQLRetriever(BaseRetriever):
      1. Use a LLM to rewrite the user input to something more suitable for retrieval. For example:
      "Show me documents containing how to finetune a LLM please" --> "how to finetune a LLM"

-     2. Use a LLM to generate a pgvector query with metadata filters based on the user input. Provided
-     metadata schemas & examples are used as additional context to generate the query.
+     2. Use a LLM to generate structured metadata filters based on the user input. Provided
+     metadata schemas & examples are used as additional context.

-     3. Use a LLM to double check the generated pgvector query is correct.
+     3. Generate a prepared PostgreSQL query from the structured metadata filters.

      4. Actually execute the query against our vector database to retrieve documents & return them.
      '''
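
Step 2 hinges on `PydanticOutputParser`: its format instructions are interpolated into the prompt, and the LLM's JSON reply is parsed back into the typed models defined above. A minimal round trip, with the JSON string standing in for an LLM reply:

    from typing import Any, List

    from langchain_core.output_parsers import PydanticOutputParser
    from pydantic import BaseModel, Field

    class MetadataFilter(BaseModel):
        attribute: str = Field(description="Database column to apply filter to")
        comparator: str = Field(description="PostgreSQL comparator to use to filter database column")
        value: Any = Field(description="Value to use to filter database column")

    class MetadataFilters(BaseModel):
        filters: List[MetadataFilter] = Field(description="List of PostgreSQL metadata filters to apply for user query")

    parser = PydanticOutputParser(pydantic_object=MetadataFilters)
    # parser.get_format_instructions() is what fills {format_instructions} in the prompt.
    reply = '{"filters": [{"attribute": "status", "comparator": "=", "value": "active"}]}'
    print(parser.invoke(reply).filters[0].comparator)  # =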
@@ -37,23 +54,22 @@ class SQLRetriever(BaseRetriever):
      metadata_schemas: Optional[List[MetadataSchema]] = None
      examples: Optional[List[LLMExample]] = None

-     embeddings_model: Embeddings
      rewrite_prompt_template: str
-     retry_prompt_template: str
+     metadata_filters_prompt_template: str
+     embeddings_model: Embeddings
      num_retries: int
-     sql_prompt_template: str
-     query_checker_template: str
      embeddings_table: str
      source_table: str
+     source_id_column: str = 'Id'
      distance_function: DistanceFunction
      search_kwargs: SearchKwargs

      llm: BaseChatModel

-     def _prepare_sql_prompt(self) -> PromptTemplate:
+     def _prepare_metadata_prompt(self) -> PromptTemplate:
          base_prompt_template = PromptTemplate(
-             input_variables=['dialect', 'input', 'embeddings_table', 'source_table', 'embeddings', 'distance_function', 'schema', 'examples'],
-             template=self.sql_prompt_template
+             input_variables=['format_instructions', 'schema', 'examples', 'input', 'embeddings'],
+             template=self.metadata_filters_prompt_template
          )
          schema_prompt_str = ''
          if self.metadata_schemas is not None:
@@ -67,7 +83,7 @@ class SQLRetriever(BaseRetriever):
                  if column.values is not None:
                      column_mapping[column.name]['values'] = column.values
              column_mapping_json_str = json.dumps(column_mapping, indent=4)
-             schema_str = f'''{i+2}. {schema.table} - {schema.description}
+             schema_str = f'''{i+1}. {schema.table} - {schema.description}

  Columns:
  ```json
@@ -86,7 +102,7 @@ Output:
  {example.output}

  '''
-         examples_prompt_str += example_str
+             examples_prompt_str += example_str
          return base_prompt_template.partial(
              schema=schema_prompt_str,
              examples=examples_prompt_str
@@ -100,97 +116,89 @@ Output:
          rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
          return rewrite_chain.predict(input=query)

-     def _prepare_pgvector_query(self, query: str, run_manager: CallbackManagerForRetrieverRun) -> str:
-         # Incorporate metadata schemas & examples into prompt.
-         sql_prompt = self._prepare_sql_prompt()
-         sql_chain = LLMChain(llm=self.llm, prompt=sql_prompt)
-         # Generate the initial pgvector query.
-         sql_query = sql_chain.predict(
-             # Only pgvector & similarity search is supported.
-             dialect='postgres',
-             input=query,
-             embeddings_table=self.embeddings_table,
-             source_table=self.source_table,
-             distance_function=self.distance_function.value[0],
-             k=self.search_kwargs.k,
-             callbacks=run_manager.get_child() if run_manager else None
-         )
-         query_checker_prompt = PromptTemplate(
-             input_variables=['dialect', 'query'],
-             template=self.query_checker_template
-         )
-         query_checker_chain = LLMChain(llm=self.llm, prompt=query_checker_prompt)
-         # Check the query & return the final result to be executed.
-         return query_checker_chain.predict(
-             dialect='postgres',
-             query=sql_query
-         )
-
-     def _prepare_retry_query(self, query: str, error: str, run_manager: CallbackManagerForRetrieverRun) -> str:
-         sql_prompt = self._prepare_sql_prompt()
-         # Use provided schema as context for retrying failed queries.
-         schema = sql_prompt.partial_variables.get('schema', '')
-         retry_prompt = PromptTemplate(
-             input_variables=['query', 'dialect', 'error', 'embeddings_table', 'schema'],
-             template=self.retry_prompt_template
-         )
-         retry_chain = LLMChain(llm=self.llm, prompt=retry_prompt)
-         # Generate rewritten query.
-         sql_query = retry_chain.predict(
-             query=query,
-             dialect='postgres',
-             error=error,
-             embeddings_table=self.embeddings_table,
-             schema=schema,
-             callbacks=run_manager.get_child() if run_manager else None
-         )
-         query_checker_prompt = PromptTemplate(
-             input_variables=['dialect', 'query'],
-             template=self.query_checker_template
-         )
-         query_checker_chain = LLMChain(llm=self.llm, prompt=query_checker_prompt)
-         # Check the query & return the final result to be executed.
-         return query_checker_chain.predict(
-             dialect='postgres',
-             query=sql_query
+     def _prepare_pgvector_query(self, metadata_filters: List[MetadataFilter]) -> str:
+         # Base select JOINed with document source table.
+         base_query = f'''SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" '''
+         col_to_schema = {}
+         if not self.metadata_schemas:
+             return ''
+         for schema in self.metadata_schemas:
+             for col in schema.columns:
+                 col_to_schema[col.name] = schema
+         joined_schemas = set()
+         for filter in metadata_filters:
+             # Join schemas before filtering.
+             schema = col_to_schema.get(filter.attribute)
+             if schema is None or schema.table in joined_schemas or schema.table == self.source_table:
+                 continue
+             joined_schemas.add(schema.table)
+             base_query += schema.join + ' '
+         # Actually construct WHERE conditions from metadata filters.
+         if metadata_filters:
+             base_query += 'WHERE '
+         for i, filter in enumerate(metadata_filters):
+             value = filter.value
+             if isinstance(value, str):
+                 value = f"'{value}'"
+             base_query += f'"{filter.attribute}" {filter.comparator} {value}'
+             if i < len(metadata_filters) - 1:
+                 base_query += ' AND '
+         base_query += f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
+         return base_query
+
+     def _generate_metadata_filters(self, query: str) -> List[MetadataFilter]:
+         parser = PydanticOutputParser(pydantic_object=MetadataFilters)
+         metadata_prompt = self._prepare_metadata_prompt()
+         metadata_filters_chain = LLMChain(llm=self.llm, prompt=metadata_prompt)
+         metadata_filters_output = metadata_filters_chain.predict(
+             format_instructions=parser.get_format_instructions(),
+             input=query
          )
+         # If the LLM outputs raw JSON, use it as-is.
+         # If the LLM outputs anything including a json markdown section, use the last one.
+         json_markdown_output = re.findall(r'```json.*```', metadata_filters_output, re.DOTALL)
+         if json_markdown_output:
+             metadata_filters_output = json_markdown_output[-1]
+             # Clean the json tags.
+             metadata_filters_output = metadata_filters_output[7:]
+             metadata_filters_output = metadata_filters_output[:-3]
+         metadata_filters = parser.invoke(metadata_filters_output)
+         return metadata_filters.filters
+
+     def _prepare_and_execute_query(self, query: str, embeddings_str: str) -> HandlerResponse:
+         try:
+             metadata_filters = self._generate_metadata_filters(query)
+             checked_sql_query = self._prepare_pgvector_query(metadata_filters)
+             checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str)
+             return self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
+         except OutputParserException as e:
+             logger.warning(f'LLM failed to generate structured metadata filters: {str(e)}')
+             return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
+         except Exception as e:
+             logger.warning(f'Failed to prepare and execute SQL query from structured metadata: {str(e)}')
+             return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))

      def _get_relevant_documents(
          self, query: str, *, run_manager: CallbackManagerForRetrieverRun
      ) -> List[Document]:
          # Rewrite query to be suitable for retrieval.
          retrieval_query = self._prepare_retrieval_query(query)
-
-         # Generate & check the query to be executed
-         checked_sql_query = self._prepare_pgvector_query(query, run_manager)
-
          # Embed the rewritten retrieval query & include it in the similarity search pgvector query.
          embedded_query = self.embeddings_model.embed_query(retrieval_query)
-         checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=str(embedded_query))
-         # Handle LLM output that has the ```sql delimiter possibly.
-         checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```sql', '')
-         checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```', '')
          # Actually execute the similarity search with metadata filters.
-         document_response = self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
+         document_response = self._prepare_and_execute_query(retrieval_query, str(embedded_query))
          num_retries = 0
          while num_retries < self.num_retries:
+             if document_response.resp_type != RESPONSE_TYPE.ERROR and len(document_response.data_frame) > 0:
+                 # Successfully retrieved documents.
+                 break
              if document_response.resp_type == RESPONSE_TYPE.ERROR:
-                 error_msg = document_response.error_message
-                 # LLMs won't always generate a working SQL query so we should have a fallback after retrying.
-                 logger.info(f'SQL Retriever query {checked_sql_query} failed with error {error_msg}')
-                 checked_sql_query = self._prepare_retry_query(checked_sql_query, error_msg, run_manager)
+                 # LLMs won't always generate structured metadata so we should have a fallback after retrying.
+                 logger.info(f'SQL Retriever query failed with error {document_response.error_message}')
              elif len(document_response.data_frame) == 0:
-                 error_msg = "No documents retrieved from query."
-                 checked_sql_query = self._prepare_retry_query(checked_sql_query, error_msg, run_manager)
-             else:
-                 break
-
-             checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=str(embedded_query))
-             # Handle LLM output that has the ```sql delimiter possibly.
-             checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```sql', '')
-             checked_sql_query_with_embeddings = checked_sql_query_with_embeddings.replace('```', '')
-             document_response = self.vector_store_handler.native_query(checked_sql_query_with_embeddings)
+                 logger.info('No documents retrieved from SQL Retriever query')

+             document_response = self._prepare_and_execute_query(retrieval_query, str(embedded_query))
              num_retries += 1
              if num_retries >= self.num_retries:
                  logger.info('Using fallback retriever in SQL retriever.')
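
With this rewrite the pgvector SQL is assembled deterministically from structured filters rather than taken verbatim from LLM output. A sketch of roughly what `_prepare_pgvector_query` emits for one made-up filter, assuming `embeddings_table='embeddings'`, `source_table='docs'`, the default `source_id_column` of `'Id'`, the `<=>` distance operator and `k=5`:

    # filters = [MetadataFilter(attribute='year', comparator='>=', value=2020)]
    expected = (
        'SELECT * FROM embeddings AS e '
        'INNER JOIN docs AS s ON (e.metadata->>\'original_row_id\')::int = s."Id" '
        'WHERE "year" >= 2020 '
        "ORDER BY e.embeddings <=> '{embeddings}' LIMIT 5;"
    )
    # _prepare_and_execute_query later fills the placeholder via str.format
    # with the embedded query vector before handing it to the vector store.
    print(expected.format(embeddings='[0.12, -0.08, 0.33]'))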