MindsDB 25.2.4.0__py3-none-any.whl → 25.3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MindsDB might be problematic. Click here for more details.
- mindsdb/__about__.py +1 -1
- mindsdb/__main__.py +15 -0
- mindsdb/api/executor/command_executor.py +1 -1
- mindsdb/api/executor/datahub/datanodes/system_tables.py +6 -1
- mindsdb/api/executor/planner/query_planner.py +6 -2
- mindsdb/api/executor/sql_query/steps/prepare_steps.py +2 -1
- mindsdb/api/mongo/classes/query_sql.py +2 -1
- mindsdb/api/mongo/responders/aggregate.py +2 -2
- mindsdb/api/mongo/responders/coll_stats.py +3 -2
- mindsdb/api/mongo/responders/db_stats.py +2 -1
- mindsdb/api/mongo/responders/insert.py +4 -2
- mindsdb/api/mysql/mysql_proxy/classes/fake_mysql_proxy/fake_mysql_proxy.py +2 -1
- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +5 -4
- mindsdb/api/postgres/postgres_proxy/postgres_proxy.py +2 -4
- mindsdb/integrations/handlers/autosklearn_handler/autosklearn_handler.py +1 -1
- mindsdb/integrations/handlers/gmail_handler/connection_args.py +2 -2
- mindsdb/integrations/handlers/gmail_handler/gmail_handler.py +19 -66
- mindsdb/integrations/handlers/gmail_handler/requirements.txt +0 -1
- mindsdb/integrations/handlers/google_calendar_handler/connection_args.py +15 -0
- mindsdb/integrations/handlers/google_calendar_handler/google_calendar_handler.py +31 -41
- mindsdb/integrations/handlers/google_calendar_handler/requirements.txt +0 -2
- mindsdb/integrations/handlers/youtube_handler/youtube_handler.py +2 -38
- mindsdb/integrations/libs/llm/utils.py +2 -1
- mindsdb/integrations/utilities/handlers/auth_utilities/google/google_user_oauth_utilities.py +29 -38
- mindsdb/integrations/utilities/pydantic_utils.py +208 -0
- mindsdb/integrations/utilities/rag/pipelines/rag.py +11 -4
- mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +800 -135
- mindsdb/integrations/utilities/rag/settings.py +390 -152
- mindsdb/integrations/utilities/sql_utils.py +2 -1
- mindsdb/interfaces/agents/agents_controller.py +11 -7
- mindsdb/interfaces/agents/mindsdb_chat_model.py +4 -2
- mindsdb/interfaces/chatbot/chatbot_controller.py +9 -8
- mindsdb/interfaces/database/database.py +2 -1
- mindsdb/interfaces/database/projects.py +28 -2
- mindsdb/interfaces/jobs/jobs_controller.py +4 -1
- mindsdb/interfaces/model/model_controller.py +5 -2
- mindsdb/interfaces/skills/retrieval_tool.py +128 -39
- mindsdb/interfaces/skills/skill_tool.py +7 -7
- mindsdb/interfaces/skills/skills_controller.py +8 -4
- mindsdb/interfaces/storage/db.py +14 -0
- mindsdb/interfaces/storage/json.py +59 -0
- mindsdb/interfaces/storage/model_fs.py +85 -3
- mindsdb/interfaces/triggers/triggers_controller.py +2 -1
- mindsdb/migrations/versions/2022-10-14_43c52d23845a_projects.py +17 -3
- mindsdb/migrations/versions/2025-02-14_4521dafe89ab_added_encrypted_content_to_json_storage.py +29 -0
- mindsdb/migrations/versions/2025-02-19_11347c213b36_added_metadata_to_projects.py +41 -0
- mindsdb/utilities/config.py +5 -1
- mindsdb/utilities/functions.py +11 -0
- {MindsDB-25.2.4.0.dist-info → mindsdb-25.3.1.0.dist-info}/METADATA +221 -223
- {MindsDB-25.2.4.0.dist-info → mindsdb-25.3.1.0.dist-info}/RECORD +53 -51
- {MindsDB-25.2.4.0.dist-info → mindsdb-25.3.1.0.dist-info}/WHEEL +1 -1
- mindsdb/integrations/handlers/gmail_handler/utils.py +0 -45
- {MindsDB-25.2.4.0.dist-info → mindsdb-25.3.1.0.dist-info}/LICENSE +0 -0
- {MindsDB-25.2.4.0.dist-info → mindsdb-25.3.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,9 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import re
|
|
2
|
+
|
|
3
3
|
from pydantic import BaseModel, Field
|
|
4
|
-
from typing import Any,
|
|
4
|
+
from typing import List, Any, Optional, Dict, Tuple, Union, Callable
|
|
5
|
+
import collections
|
|
6
|
+
import math
|
|
5
7
|
|
|
6
8
|
from langchain.chains.llm import LLMChain
|
|
7
9
|
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
|
@@ -10,32 +12,57 @@ from langchain_core.embeddings import Embeddings
|
|
|
10
12
|
from langchain_core.exceptions import OutputParserException
|
|
11
13
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
12
14
|
from langchain_core.output_parsers import PydanticOutputParser
|
|
13
|
-
from langchain_core.prompts import PromptTemplate
|
|
15
|
+
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
|
|
14
16
|
from langchain_core.retrievers import BaseRetriever
|
|
15
17
|
|
|
16
18
|
from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
|
|
17
19
|
from mindsdb.integrations.libs.response import HandlerResponse
|
|
18
|
-
from mindsdb.integrations.libs.vectordatabase_handler import
|
|
19
|
-
|
|
20
|
+
from mindsdb.integrations.libs.vectordatabase_handler import (
|
|
21
|
+
DistanceFunction,
|
|
22
|
+
VectorStoreHandler,
|
|
23
|
+
)
|
|
24
|
+
from mindsdb.integrations.utilities.rag.settings import (
|
|
25
|
+
DatabaseSchema,
|
|
26
|
+
TableSchema,
|
|
27
|
+
ColumnSchema,
|
|
28
|
+
ValueSchema,
|
|
29
|
+
SearchKwargs,
|
|
30
|
+
)
|
|
20
31
|
from mindsdb.utilities import log
|
|
21
32
|
|
|
33
|
+
import numpy as np
|
|
34
|
+
|
|
22
35
|
logger = log.getLogger(__name__)
|
|
23
36
|
|
|
24
37
|
|
|
25
38
|
class MetadataFilter(BaseModel):
|
|
26
|
-
|
|
39
|
+
"""Represents an LLM generated metadata filter to apply to a PostgreSQL query."""
|
|
40
|
+
|
|
27
41
|
attribute: str = Field(description="Database column to apply filter to")
|
|
28
|
-
comparator: str = Field(
|
|
42
|
+
comparator: str = Field(
|
|
43
|
+
description="PostgreSQL comparator to use to filter database column"
|
|
44
|
+
)
|
|
29
45
|
value: Any = Field(description="Value to use to filter database column")
|
|
30
46
|
|
|
31
47
|
|
|
48
|
+
class AblativeMetadataFilter(MetadataFilter):
|
|
49
|
+
"""Adds additional fields to support ablation."""
|
|
50
|
+
|
|
51
|
+
schema_table: str = Field(description="schema name of the table for this filter")
|
|
52
|
+
schema_column: str = Field(description="schema name of the column for this filter")
|
|
53
|
+
schema_value: str = Field(description="schema name of the value for this filter")
|
|
54
|
+
|
|
55
|
+
|
|
32
56
|
class MetadataFilters(BaseModel):
|
|
33
|
-
|
|
34
|
-
|
|
57
|
+
"""List of LLM generated metadata filters to apply to a PostgreSQL query."""
|
|
58
|
+
|
|
59
|
+
filters: List[MetadataFilter] = Field(
|
|
60
|
+
description="List of PostgreSQL metadata filters to apply for user query"
|
|
61
|
+
)
|
|
35
62
|
|
|
36
63
|
|
|
37
64
|
class SQLRetriever(BaseRetriever):
|
|
38
|
-
|
|
65
|
+
"""Retriever that uses a LLM to generate pgvector queries to do similarity search with metadata filters.
|
|
39
66
|
|
|
40
67
|
How it works:
|
|
41
68
|
|
|
@@ -48,134 +75,723 @@ class SQLRetriever(BaseRetriever):
|
|
|
48
75
|
3. Generate a prepared PostgreSQL query from the structured metadata filters.
|
|
49
76
|
|
|
50
77
|
4. Actually execute the query against our vector database to retrieve documents & return them.
|
|
51
|
-
|
|
78
|
+
"""
|
|
79
|
+
|
|
52
80
|
fallback_retriever: BaseRetriever
|
|
53
81
|
vector_store_handler: VectorStoreHandler
|
|
54
|
-
|
|
55
|
-
|
|
82
|
+
# search parameters
|
|
83
|
+
max_filters: int
|
|
84
|
+
filter_threshold: float
|
|
85
|
+
min_k: int
|
|
56
86
|
|
|
57
|
-
|
|
58
|
-
|
|
87
|
+
# Schema description
|
|
88
|
+
database_schema: Optional[DatabaseSchema] = None
|
|
89
|
+
|
|
90
|
+
# Embeddings
|
|
59
91
|
embeddings_model: Embeddings
|
|
92
|
+
search_kwargs: SearchKwargs
|
|
93
|
+
|
|
94
|
+
# prompt templates
|
|
95
|
+
rewrite_prompt_template: str
|
|
96
|
+
|
|
97
|
+
# schema templates
|
|
98
|
+
table_prompt_template: str
|
|
99
|
+
column_prompt_template: str
|
|
100
|
+
value_prompt_template: str
|
|
101
|
+
|
|
102
|
+
# formatting templates
|
|
103
|
+
boolean_system_prompt: str
|
|
104
|
+
generative_system_prompt: str
|
|
105
|
+
|
|
106
|
+
# SQL search config
|
|
60
107
|
num_retries: int
|
|
61
108
|
embeddings_table: str
|
|
62
109
|
source_table: str
|
|
63
|
-
source_id_column: str
|
|
110
|
+
source_id_column: str
|
|
64
111
|
distance_function: DistanceFunction
|
|
65
|
-
search_kwargs: SearchKwargs
|
|
66
112
|
|
|
113
|
+
# Re-rank and metadata generation model.
|
|
67
114
|
llm: BaseChatModel
|
|
68
115
|
|
|
69
|
-
def
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
116
|
+
def _sort_schema_by_priority_key(
|
|
117
|
+
self,
|
|
118
|
+
schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
|
|
119
|
+
):
|
|
120
|
+
return schema_dict_item[1].priority
|
|
121
|
+
|
|
122
|
+
def _sort_schema_by_relevance_key(
|
|
123
|
+
self,
|
|
124
|
+
schema_dict_item: Tuple[str, Union[TableSchema, ColumnSchema, ValueSchema]],
|
|
125
|
+
):
|
|
126
|
+
if schema_dict_item[1].relevance is not None:
|
|
127
|
+
return schema_dict_item[1].relevance
|
|
128
|
+
else:
|
|
129
|
+
return 0
|
|
130
|
+
|
|
131
|
+
def _sort_schema_by_key(
|
|
132
|
+
self,
|
|
133
|
+
schema: Union[DatabaseSchema, TableSchema, ColumnSchema],
|
|
134
|
+
key: Callable,
|
|
135
|
+
update: Dict[str, Any] = None,
|
|
136
|
+
) -> Union[DatabaseSchema, TableSchema, ColumnSchema]:
|
|
137
|
+
"""Takes a schema and converts its dict into an OrderedDict"""
|
|
138
|
+
if isinstance(schema, DatabaseSchema):
|
|
139
|
+
collection_key = "tables"
|
|
140
|
+
elif isinstance(schema, TableSchema):
|
|
141
|
+
collection_key = "columns"
|
|
142
|
+
elif isinstance(schema, ColumnSchema):
|
|
143
|
+
collection_key = "values"
|
|
144
|
+
else:
|
|
145
|
+
raise Exception(
|
|
146
|
+
"schema must be either a DatabaseSchema, TableSchema, or ColumnSchema."
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
if update is not None:
|
|
150
|
+
ordered = collections.OrderedDict(
|
|
151
|
+
sorted(update.items(), key=key, reverse=True)
|
|
152
|
+
)
|
|
153
|
+
else:
|
|
154
|
+
ordered = collections.OrderedDict(
|
|
155
|
+
sorted(getattr(schema, collection_key).items(), key=key, reverse=True)
|
|
156
|
+
)
|
|
157
|
+
schema = schema.model_copy(update={collection_key: ordered})
|
|
158
|
+
|
|
159
|
+
return schema
|
|
160
|
+
|
|
161
|
+
def _sort_database_schema_by_key(
|
|
162
|
+
self, database_schema: DatabaseSchema, key: Callable
|
|
163
|
+
) -> DatabaseSchema:
|
|
164
|
+
"""Re-build schema with OrderedDicts"""
|
|
165
|
+
tables = {}
|
|
166
|
+
# build new tables dict
|
|
167
|
+
for table_key, table_schema in database_schema.tables.items():
|
|
168
|
+
columns = {}
|
|
169
|
+
# build new column dict
|
|
170
|
+
for column_key, column_schema in table_schema.columns.items():
|
|
171
|
+
# sort values directly and update column schema
|
|
172
|
+
columns[column_key] = self._sort_schema_by_key(
|
|
173
|
+
schema=column_schema, key=key
|
|
174
|
+
)
|
|
175
|
+
# update table schema and sort
|
|
176
|
+
tables[table_key] = self._sort_schema_by_key(
|
|
177
|
+
schema=table_schema, key=key, update=columns
|
|
178
|
+
)
|
|
179
|
+
# update table schema and sort
|
|
180
|
+
database_schema = self._sort_schema_by_key(
|
|
181
|
+
schema=database_schema, key=key, update=tables
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
return database_schema
|
|
185
|
+
|
|
186
|
+
def _prepare_value_prompt(
|
|
187
|
+
self,
|
|
188
|
+
value_schema: ValueSchema,
|
|
189
|
+
column_schema: ColumnSchema,
|
|
190
|
+
table_schema: TableSchema,
|
|
191
|
+
boolean_system_prompt: bool = True,
|
|
192
|
+
format_instructions: Optional[str] = None,
|
|
193
|
+
) -> ChatPromptTemplate:
|
|
194
|
+
|
|
195
|
+
if boolean_system_prompt is True:
|
|
196
|
+
system_prompt = self.boolean_system_prompt
|
|
197
|
+
else:
|
|
198
|
+
system_prompt = self.generative_system_prompt
|
|
199
|
+
|
|
200
|
+
prepared_column_prompt = self._prepare_column_prompt(
|
|
201
|
+
column_schema=column_schema, table_schema=table_schema
|
|
202
|
+
)
|
|
203
|
+
column_schema_str = (
|
|
204
|
+
prepared_column_prompt.messages[1]
|
|
205
|
+
.format(
|
|
206
|
+
**prepared_column_prompt.partial_variables,
|
|
207
|
+
query="See query at the lowest level schema.",
|
|
208
|
+
)
|
|
209
|
+
.content
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
base_prompt_template = ChatPromptTemplate.from_messages(
|
|
213
|
+
[("system", system_prompt), ("user", self.value_prompt_template)]
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
value_str = ""
|
|
217
|
+
header_str = ""
|
|
218
|
+
if type(value_schema.value) in [str, int, float, bool]:
|
|
219
|
+
header_str = f"This schema describes a single value in the {column_schema.column} column."
|
|
220
|
+
|
|
221
|
+
value_str = f"""
|
|
222
|
+
-**Value**: {value_schema.value}
|
|
223
|
+
"""
|
|
224
|
+
|
|
225
|
+
elif type(value_schema.value) is dict:
|
|
226
|
+
header_str = f"This schema describes enumerated values in the {column_schema.column} column."
|
|
227
|
+
|
|
228
|
+
value_str = """
|
|
229
|
+
## **Enumerated Values**
|
|
230
|
+
|
|
231
|
+
The values in the column are an enumeration of named values. These are listed below with format **[Column Value]**: [named value].
|
|
232
|
+
"""
|
|
233
|
+
for value, value_name in value_schema.value.items():
|
|
234
|
+
value_str += f"""
|
|
235
|
+
- **{value}:** {value_name}"""
|
|
236
|
+
|
|
237
|
+
elif type(value_schema.value) is list:
|
|
238
|
+
header_str = f"This schema describes some of the values in the {column_schema.column} column."
|
|
239
|
+
|
|
240
|
+
value_str = """
|
|
241
|
+
## **Sample Values**
|
|
242
|
+
|
|
243
|
+
There are too many values in this column to list exhaustively. Below is a sampling of values found in the column:
|
|
244
|
+
"""
|
|
245
|
+
for value in value_schema.value:
|
|
246
|
+
value_str += f"""
|
|
247
|
+
- {value}"""
|
|
248
|
+
|
|
249
|
+
if getattr(value_schema, "comparator", None) is not None:
|
|
250
|
+
comparator_str = """
|
|
251
|
+
|
|
252
|
+
## **Comparators**
|
|
253
|
+
|
|
254
|
+
Below is a list of comparison operators for constructing filters for this value schema:
|
|
255
|
+
"""
|
|
256
|
+
if type(value_schema.comparator) is str:
|
|
257
|
+
comparator_str += f"""- {value_schema.comparator}
|
|
258
|
+
"""
|
|
259
|
+
else:
|
|
260
|
+
for comp in value_schema.comparator:
|
|
261
|
+
comparator_str += f"""- {comp}
|
|
262
|
+
"""
|
|
263
|
+
else:
|
|
264
|
+
comparator_str = ""
|
|
265
|
+
|
|
266
|
+
if getattr(value_schema, "example_questions", None) is not None:
|
|
267
|
+
example_str = """## **Example Questions**
|
|
268
|
+
"""
|
|
269
|
+
for i, example in enumerate(value_schema.example_questions):
|
|
270
|
+
example_str += f"""{i}. **Query:** {example.input} **Answer:** {example.output}
|
|
271
|
+
"""
|
|
272
|
+
else:
|
|
273
|
+
example_str = ""
|
|
274
|
+
|
|
275
|
+
return base_prompt_template.partial(
|
|
276
|
+
format_instructions=format_instructions,
|
|
277
|
+
header=header_str,
|
|
278
|
+
column_schema=column_schema_str,
|
|
279
|
+
value=value_str,
|
|
280
|
+
comparator=comparator_str,
|
|
281
|
+
type=value_schema.type,
|
|
282
|
+
description=value_schema.description,
|
|
283
|
+
usage=value_schema.usage,
|
|
284
|
+
examples=example_str,
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
def _prepare_column_prompt(
|
|
288
|
+
self,
|
|
289
|
+
column_schema: ColumnSchema,
|
|
290
|
+
table_schema: TableSchema,
|
|
291
|
+
boolean_system_prompt: bool = True,
|
|
292
|
+
) -> ChatPromptTemplate:
|
|
293
|
+
|
|
294
|
+
if boolean_system_prompt is True:
|
|
295
|
+
system_prompt = self.boolean_system_prompt
|
|
296
|
+
else:
|
|
297
|
+
system_prompt = self.generative_system_prompt
|
|
298
|
+
|
|
299
|
+
prepared_table_prompt = self._prepare_table_prompt(
|
|
300
|
+
table_schema=table_schema, boolean_system_prompt=boolean_system_prompt
|
|
301
|
+
)
|
|
302
|
+
table_schema_str = (
|
|
303
|
+
prepared_table_prompt.messages[1]
|
|
304
|
+
.format(
|
|
305
|
+
**prepared_table_prompt.partial_variables,
|
|
306
|
+
query="See query at the lowest level schema",
|
|
307
|
+
)
|
|
308
|
+
.content
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
base_prompt_template = ChatPromptTemplate.from_messages(
|
|
312
|
+
[("system", system_prompt), ("user", self.column_prompt_template)]
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
header_str = (
|
|
316
|
+
f"This schema describes a column in the {table_schema.table} table."
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
value_str = """
|
|
320
|
+
## **Content**
|
|
321
|
+
|
|
322
|
+
Below is a description of the contents in this column in list format:
|
|
323
|
+
"""
|
|
324
|
+
for value_schema in column_schema.values.values():
|
|
325
|
+
value_str += f"""
|
|
326
|
+
- {value_schema.description}
|
|
327
|
+
"""
|
|
328
|
+
value_str += """
|
|
329
|
+
**Important:** The above descriptions are not the actual values stored in this column. See the Value schema for actual values.
|
|
330
|
+
"""
|
|
331
|
+
|
|
332
|
+
if getattr(column_schema, "examples", None) is not None:
|
|
333
|
+
example_str = """## **Example Questions**
|
|
334
|
+
"""
|
|
335
|
+
for example in column_schema.examples:
|
|
336
|
+
example_str += f"""- {example}
|
|
337
|
+
"""
|
|
338
|
+
else:
|
|
339
|
+
example_str = ""
|
|
340
|
+
|
|
341
|
+
return base_prompt_template.partial(
|
|
342
|
+
table_schema=table_schema_str,
|
|
343
|
+
header=header_str,
|
|
344
|
+
column=column_schema.column,
|
|
345
|
+
type=column_schema.type,
|
|
346
|
+
description=column_schema.description,
|
|
347
|
+
usage=column_schema.usage,
|
|
348
|
+
values=value_str,
|
|
349
|
+
examples=example_str,
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
def _prepare_table_prompt(
|
|
353
|
+
self, table_schema: TableSchema, boolean_system_prompt: bool = True
|
|
354
|
+
) -> ChatPromptTemplate:
|
|
355
|
+
if boolean_system_prompt is True:
|
|
356
|
+
system_prompt = self.boolean_system_prompt
|
|
357
|
+
else:
|
|
358
|
+
system_prompt = self.generative_system_prompt
|
|
359
|
+
|
|
360
|
+
base_prompt_template = ChatPromptTemplate.from_messages(
|
|
361
|
+
[("system", system_prompt), ("user", self.table_prompt_template)]
|
|
362
|
+
)
|
|
363
|
+
|
|
364
|
+
header_str = "This schema describes a table in the database."
|
|
365
|
+
|
|
366
|
+
columns_str = ""
|
|
367
|
+
for column_key, column_schema in table_schema.columns.items():
|
|
368
|
+
columns_str += f"""
|
|
369
|
+
- **{column_schema.column}:** {column_schema.description}
|
|
370
|
+
"""
|
|
371
|
+
|
|
372
|
+
if getattr(table_schema, "examples", None) is not None:
|
|
373
|
+
example_str = """## **Example Questions**
|
|
374
|
+
"""
|
|
375
|
+
for example in table_schema.examples:
|
|
376
|
+
example_str += f"""- {example}
|
|
377
|
+
"""
|
|
378
|
+
else:
|
|
379
|
+
example_str = ""
|
|
380
|
+
|
|
106
381
|
return base_prompt_template.partial(
|
|
107
|
-
|
|
108
|
-
|
|
382
|
+
header=header_str,
|
|
383
|
+
table=table_schema.table,
|
|
384
|
+
description=table_schema.description,
|
|
385
|
+
usage=table_schema.usage,
|
|
386
|
+
columns=columns_str,
|
|
387
|
+
examples=example_str,
|
|
109
388
|
)
|
|
110
389
|
|
|
390
|
+
def _rank_schema(self, prompt: ChatPromptTemplate, query: str) -> float:
|
|
391
|
+
rank_chain = LLMChain(
|
|
392
|
+
llm=self.llm.bind(logprobs=True), prompt=prompt, return_final_only=False
|
|
393
|
+
)
|
|
394
|
+
output = rank_chain({"query": query}) # returns metadata
|
|
395
|
+
|
|
396
|
+
# parse through metadata tokens until encountering either yes, or no.
|
|
397
|
+
score = None # a None score indicates the model output could not be parsed.
|
|
398
|
+
for content in output["full_generation"][0].message.response_metadata[
|
|
399
|
+
"logprobs"
|
|
400
|
+
]["content"]:
|
|
401
|
+
# Convert answer to score using the model's confidence
|
|
402
|
+
if content["token"].lower().strip() == "yes":
|
|
403
|
+
score = (
|
|
404
|
+
1 + math.exp(content["logprob"])
|
|
405
|
+
) / 2 # If yes, use the model's confidence
|
|
406
|
+
break
|
|
407
|
+
elif content["token"].lower().strip() == "no":
|
|
408
|
+
score = (
|
|
409
|
+
1 - math.exp(content["logprob"])
|
|
410
|
+
) / 2 # If no, invert the confidence
|
|
411
|
+
break
|
|
412
|
+
|
|
413
|
+
if score is None:
|
|
414
|
+
score = 0.0
|
|
415
|
+
|
|
416
|
+
return score
|
|
417
|
+
|
|
418
|
+
def _breadth_first_search(self, query: str, greedy: bool = False) -> Tuple:
|
|
419
|
+
"""Search breadth wise through Tables, then Columns, then Values.Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy."""
|
|
420
|
+
|
|
421
|
+
# sort based on priority
|
|
422
|
+
ordered_database_schema = self._sort_database_schema_by_key(
|
|
423
|
+
database_schema=self.database_schema, key=self._sort_schema_by_priority_key
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
# Rank Tables ########################################################
|
|
427
|
+
greedy_count = 0
|
|
428
|
+
tables = {}
|
|
429
|
+
# rank tables by relevance
|
|
430
|
+
for table_key, table_schema in ordered_database_schema.tables.items():
|
|
431
|
+
prompt: ChatPromptTemplate = self._prepare_table_prompt(
|
|
432
|
+
table_schema=table_schema, boolean_system_prompt=True
|
|
433
|
+
)
|
|
434
|
+
table_schema.relevance = self._rank_schema(prompt=prompt, query=query)
|
|
435
|
+
|
|
436
|
+
# only keep greedy tables
|
|
437
|
+
tables[table_key] = table_schema
|
|
438
|
+
|
|
439
|
+
if greedy:
|
|
440
|
+
if table_schema.relevance >= ordered_database_schema.filter_threshold:
|
|
441
|
+
greedy_count += 1
|
|
442
|
+
if greedy_count >= ordered_database_schema.max_filters:
|
|
443
|
+
break
|
|
444
|
+
|
|
445
|
+
# sort tables
|
|
446
|
+
ordered_database_schema = self._sort_schema_by_key(
|
|
447
|
+
schema=ordered_database_schema,
|
|
448
|
+
key=self._sort_schema_by_relevance_key,
|
|
449
|
+
update=tables,
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
# Rank Columns #######################################################
|
|
453
|
+
# iterate through tables to rank columns
|
|
454
|
+
tables = {}
|
|
455
|
+
table_count = 0 # take only the top n number of tables specified by the databases max filters
|
|
456
|
+
for table_key, table_schema in ordered_database_schema.tables.items():
|
|
457
|
+
# only drop into tables above the filter threshold
|
|
458
|
+
if table_schema.relevance >= ordered_database_schema.filter_threshold:
|
|
459
|
+
greedy_count = 0
|
|
460
|
+
# rank columns by relevance
|
|
461
|
+
columns = {}
|
|
462
|
+
for column_key, column_schema in table_schema.columns.items():
|
|
463
|
+
prompt: ChatPromptTemplate = self._prepare_column_prompt(
|
|
464
|
+
column_schema=column_schema,
|
|
465
|
+
table_schema=table_schema,
|
|
466
|
+
boolean_system_prompt=True,
|
|
467
|
+
)
|
|
468
|
+
column_schema.relevance = self._rank_schema(
|
|
469
|
+
prompt=prompt, query=query
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
columns[column_key] = column_schema
|
|
473
|
+
|
|
474
|
+
if greedy:
|
|
475
|
+
if column_schema.relevance >= table_schema.filter_threshold:
|
|
476
|
+
greedy_count += 1
|
|
477
|
+
if greedy_count >= table_schema.max_filters:
|
|
478
|
+
break
|
|
479
|
+
|
|
480
|
+
# sort columns and keep only columns that made the cut.
|
|
481
|
+
tables[table_key] = self._sort_schema_by_key(
|
|
482
|
+
table_schema, key=self._sort_schema_by_relevance_key, update=columns
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
table_count += 1
|
|
486
|
+
if table_count >= ordered_database_schema.max_filters:
|
|
487
|
+
break
|
|
488
|
+
|
|
489
|
+
# sort tables and keep only tables that made the cut.
|
|
490
|
+
ordered_database_schema = self._sort_schema_by_key(
|
|
491
|
+
ordered_database_schema,
|
|
492
|
+
key=self._sort_schema_by_relevance_key,
|
|
493
|
+
update=tables,
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
# Rank Values ########################################################
|
|
497
|
+
# iterate through tables to rank values
|
|
498
|
+
tables = {}
|
|
499
|
+
for table_key, table_schema in ordered_database_schema.tables.items():
|
|
500
|
+
columns = {}
|
|
501
|
+
column_count = 0
|
|
502
|
+
# iterate through columns to rank values
|
|
503
|
+
for column_key, column_schema in table_schema.columns.items():
|
|
504
|
+
if column_schema.relevance >= table_schema.filter_threshold:
|
|
505
|
+
greedy_count = 0
|
|
506
|
+
values = {}
|
|
507
|
+
# rank values by relevance
|
|
508
|
+
for value_key, value_schema in column_schema.values.items():
|
|
509
|
+
prompt: ChatPromptTemplate = self._prepare_value_prompt(
|
|
510
|
+
value_schema=value_schema,
|
|
511
|
+
column_schema=column_schema,
|
|
512
|
+
table_schema=table_schema,
|
|
513
|
+
boolean_system_prompt=True,
|
|
514
|
+
)
|
|
515
|
+
value_schema.relevance = self._rank_schema(
|
|
516
|
+
prompt=prompt, query=query
|
|
517
|
+
)
|
|
518
|
+
|
|
519
|
+
values[value_key] = value_schema
|
|
520
|
+
|
|
521
|
+
if greedy:
|
|
522
|
+
if value_schema.relevance >= column_schema.filter_threshold:
|
|
523
|
+
greedy_count += 1
|
|
524
|
+
if greedy_count >= column_schema.max_filters:
|
|
525
|
+
break
|
|
526
|
+
|
|
527
|
+
# sort values and keep only values that make the cut
|
|
528
|
+
columns[column_key] = self._sort_schema_by_key(
|
|
529
|
+
column_schema,
|
|
530
|
+
key=self._sort_schema_by_relevance_key,
|
|
531
|
+
update=values,
|
|
532
|
+
)
|
|
533
|
+
|
|
534
|
+
column_count += 1
|
|
535
|
+
if column_count >= table_schema.max_filters:
|
|
536
|
+
break
|
|
537
|
+
|
|
538
|
+
# sort columns and keep only columns that made the cut
|
|
539
|
+
tables[table_key] = self._sort_schema_by_key(
|
|
540
|
+
table_schema, key=self._sort_schema_by_relevance_key, update=columns
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
# sort tables and keep only tables that made the cut.
|
|
544
|
+
ordered_database_schema = self._sort_schema_by_key(
|
|
545
|
+
ordered_database_schema,
|
|
546
|
+
key=self._sort_schema_by_relevance_key,
|
|
547
|
+
update=tables,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
# discard low ranked values ###################################################################################
|
|
551
|
+
tables = {}
|
|
552
|
+
for table_key, table_schema in ordered_database_schema.tables.items():
|
|
553
|
+
columns = {}
|
|
554
|
+
# iterate through columns to rank values
|
|
555
|
+
for column_key, column_schema in table_schema.columns.items():
|
|
556
|
+
value_count = 0
|
|
557
|
+
values = {}
|
|
558
|
+
# rank values by relevance
|
|
559
|
+
for value_key, value_schema in column_schema.values.items():
|
|
560
|
+
if value_schema.relevance >= column_schema.filter_threshold:
|
|
561
|
+
values[value_key] = value_schema
|
|
562
|
+
|
|
563
|
+
value_count += 1
|
|
564
|
+
if value_count >= column_schema.max_filters:
|
|
565
|
+
break
|
|
566
|
+
|
|
567
|
+
# sort values and keep only values that make the cut
|
|
568
|
+
columns[column_key] = self._sort_schema_by_key(
|
|
569
|
+
column_schema,
|
|
570
|
+
key=self._sort_schema_by_relevance_key,
|
|
571
|
+
update=values,
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
# sort columns and keep only columns that made the cut
|
|
575
|
+
tables[table_key] = self._sort_schema_by_key(
|
|
576
|
+
table_schema, key=self._sort_schema_by_relevance_key, update=columns
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
# sort tables and keep only tables that made the cut.
|
|
580
|
+
ordered_database_schema = self._sort_schema_by_key(
|
|
581
|
+
ordered_database_schema,
|
|
582
|
+
key=self._sort_schema_by_relevance_key,
|
|
583
|
+
update=tables,
|
|
584
|
+
)
|
|
585
|
+
|
|
586
|
+
ranked_database_schema = ordered_database_schema
|
|
587
|
+
|
|
588
|
+
# Build Ablation #####################################################
|
|
589
|
+
|
|
590
|
+
ablation_value_dict = {}
|
|
591
|
+
# assemble a relevance dictionary
|
|
592
|
+
for table_key, table_schema in ordered_database_schema.tables.items():
|
|
593
|
+
for column_key, column_schema in table_schema.columns.items():
|
|
594
|
+
for value_key, value_schema in column_schema.values.items():
|
|
595
|
+
ablation_value_dict[(table_key, column_key, value_key)] = (
|
|
596
|
+
value_schema.relevance
|
|
597
|
+
)
|
|
598
|
+
|
|
599
|
+
ablation_value_dict = collections.OrderedDict(
|
|
600
|
+
sorted(ablation_value_dict.items(), key=lambda x: x[1])
|
|
601
|
+
)
|
|
602
|
+
|
|
603
|
+
relevance_scores = list(ablation_value_dict.values())
|
|
604
|
+
if len(relevance_scores) > 0:
|
|
605
|
+
ablation_quantiles = np.quantile(
|
|
606
|
+
relevance_scores, np.linspace(0, 1, self.num_retries + 2)[1:-1]
|
|
607
|
+
)
|
|
608
|
+
else:
|
|
609
|
+
ablation_quantiles = None
|
|
610
|
+
|
|
611
|
+
return ranked_database_schema, ablation_value_dict, ablation_quantiles
|
|
612
|
+
|
|
613
|
+
def _dynamic_ablation(
|
|
614
|
+
self,
|
|
615
|
+
metadata_filters: List[AblativeMetadataFilter],
|
|
616
|
+
ablation_value_dict,
|
|
617
|
+
ablation_quantiles,
|
|
618
|
+
retry: int,
|
|
619
|
+
):
|
|
620
|
+
"""Ablate metadata filters in aggregate by quantiles until the required minimum number of documents are returned."""
|
|
621
|
+
|
|
622
|
+
ablated_dict = {}
|
|
623
|
+
for key, value in ablation_value_dict.items():
|
|
624
|
+
if value >= ablation_quantiles[retry]:
|
|
625
|
+
ablated_dict[key] = value
|
|
626
|
+
|
|
627
|
+
# discard low ranked filters ##################################################################################
|
|
628
|
+
ablated_filters = []
|
|
629
|
+
for filter in metadata_filters:
|
|
630
|
+
for key in ablated_dict.keys():
|
|
631
|
+
if (
|
|
632
|
+
filter.schema_table in key
|
|
633
|
+
and filter.schema_column in key
|
|
634
|
+
and filter.schema_value in key
|
|
635
|
+
):
|
|
636
|
+
ablated_filters.append(filter)
|
|
637
|
+
|
|
638
|
+
return ablated_filters
|
|
639
|
+
|
|
640
|
+
def depth_first_search(self, greedy=True):
|
|
641
|
+
"""Search depth wise through Tables, then Columns, then Values. Uses a greedy strategy to maximize quota if greedy=True, otherwise a dynamic strategy."""
|
|
642
|
+
pass
|
|
643
|
+
|
|
644
|
+
def depth_first_ablation(self):
|
|
645
|
+
"""Ablate metadata filters in reverse depth first search until the required minimum number of documents are returned."""
|
|
646
|
+
pass
|
|
647
|
+
|
|
111
648
|
def _prepare_retrieval_query(self, query: str) -> str:
|
|
112
649
|
rewrite_prompt = PromptTemplate(
|
|
113
|
-
input_variables=[
|
|
114
|
-
template=self.rewrite_prompt_template
|
|
650
|
+
input_variables=["input"], template=self.rewrite_prompt_template
|
|
115
651
|
)
|
|
116
652
|
rewrite_chain = LLMChain(llm=self.llm, prompt=rewrite_prompt)
|
|
117
653
|
return rewrite_chain.predict(input=query)
|
|
118
654
|
|
|
119
|
-
def _prepare_pgvector_query(
|
|
655
|
+
def _prepare_pgvector_query(
|
|
656
|
+
self,
|
|
657
|
+
ranked_database_schema: DatabaseSchema,
|
|
658
|
+
metadata_filters: List[AblativeMetadataFilter],
|
|
659
|
+
retry: int = 0,
|
|
660
|
+
) -> str:
|
|
120
661
|
# Base select JOINed with document source table.
|
|
121
|
-
base_query = f
|
|
122
|
-
|
|
123
|
-
if not
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
for
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
662
|
+
base_query = f"""SELECT * FROM {self.embeddings_table} AS e INNER JOIN {self.source_table} AS s ON (e.metadata->>'original_row_id')::int = s."{self.source_id_column}" """
|
|
663
|
+
|
|
664
|
+
# return an empty string if schema has not been ranked
|
|
665
|
+
if not ranked_database_schema:
|
|
666
|
+
return ""
|
|
667
|
+
|
|
668
|
+
# Add Table JOIN statements
|
|
669
|
+
join_clauses = set()
|
|
670
|
+
for metadata_filter in metadata_filters:
|
|
671
|
+
join_clause = ranked_database_schema.tables[
|
|
672
|
+
metadata_filter.schema_table
|
|
673
|
+
].join
|
|
674
|
+
if join_clause in join_clauses:
|
|
133
675
|
continue
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
676
|
+
else:
|
|
677
|
+
join_clauses.add(join_clause)
|
|
678
|
+
base_query += join_clause + " "
|
|
679
|
+
|
|
680
|
+
# Add WHERE conditions from metadata filters
|
|
137
681
|
if metadata_filters:
|
|
138
|
-
base_query +=
|
|
682
|
+
base_query += "WHERE "
|
|
139
683
|
for i, filter in enumerate(metadata_filters):
|
|
140
684
|
value = filter.value
|
|
141
685
|
if isinstance(value, str):
|
|
142
686
|
value = f"'{value}'"
|
|
143
687
|
base_query += f'"{filter.attribute}" {filter.comparator} {value}'
|
|
144
688
|
if i < len(metadata_filters) - 1:
|
|
145
|
-
base_query +=
|
|
689
|
+
base_query += " AND "
|
|
690
|
+
|
|
146
691
|
base_query += f" ORDER BY e.embeddings {self.distance_function.value[0]} '{{embeddings}}' LIMIT {self.search_kwargs.k};"
|
|
147
692
|
return base_query
|
|
148
693
|
|
|
149
|
-
def
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
694
|
+
def _generate_filter(
    self, prompt: ChatPromptTemplate, query: str
) -> dict:
    """Run ``prompt`` through an LLM chain for ``query`` and return the raw chain output.

    Note: the return value is whatever calling the ``LLMChain`` produces (a
    mapping of the chain's inputs and outputs). It is not a parsed
    ``MetadataFilter``; callers must parse the generated text themselves.
    """
    gen_filter_chain = LLMChain(llm=self.llm, prompt=prompt)
    return gen_filter_chain({"query": query})
|
|
700
|
+
|
|
701
|
+
def _generate_metadata_filters(
    self, query: str, ranked_database_schema
) -> Union[List[AblativeMetadataFilter], HandlerResponse]:
    """Build one metadata filter per value of every sufficiently relevant column.

    Walks ``ranked_database_schema.tables`` -> ``columns`` -> ``values`` and
    skips columns whose ``relevance`` is not at least their table's
    ``filter_threshold``. List/dict values cannot be compared directly, so
    the LLM is prompted to generate a structured ``MetadataFilter`` for them.
    Every other value becomes a filter on ``column_schema.column`` using the
    value's own comparator.

    Args:
        query: the user's query, passed to the LLM for list/dict values.
        ranked_database_schema: ranked schema to generate filters from.

    Returns:
        A list of ``AblativeMetadataFilter`` tagged with the table/column/value
        keys they came from. Returns an error ``HandlerResponse`` as soon as
        any LLM output fails to parse.
    """
    parser = PydanticOutputParser(pydantic_object=MetadataFilter)
    # Loop-invariant: the instructions depend only on the parser.
    format_instructions = parser.get_format_instructions()

    metadata_filter_list = []
    # iterate through tables to rank values
    for table_key, table_schema in ranked_database_schema.tables.items():
        # iterate through columns to rank values
        for column_key, column_schema in table_schema.columns.items():
            # `not (>=)` rather than `<`, so an unorderable (e.g. NaN)
            # relevance is skipped exactly as before.
            if not (column_schema.relevance >= table_schema.filter_threshold):
                continue
            # generate filters
            for value_key, value_schema in column_schema.values.items():
                # must use generation if field is a dictionary of tuples or a list
                if isinstance(value_schema.value, (list, dict)):
                    try:
                        metadata_prompt: ChatPromptTemplate = (
                            self._prepare_value_prompt(
                                format_instructions=format_instructions,
                                value_schema=value_schema,
                                column_schema=column_schema,
                                table_schema=table_schema,
                                boolean_system_prompt=False,
                            )
                        )

                        metadata_filters_chain = LLMChain(
                            llm=self.llm, prompt=metadata_prompt
                        )
                        metadata_filter_output = metadata_filters_chain.predict(
                            query=query,
                        )

                        # If the LLM outputs raw JSON, use it as-is.
                        # If it outputs ```json fenced sections, use the
                        # contents of the last one. The non-greedy match
                        # keeps each fenced block separate; a greedy `.*`
                        # would span from the first fence to the last.
                        json_blocks = re.findall(
                            r"```json(.*?)```", metadata_filter_output, re.DOTALL
                        )
                        if json_blocks:
                            metadata_filter_output = json_blocks[-1]

                        metadata_filter = parser.invoke(metadata_filter_output)
                        model_dump = metadata_filter.model_dump()
                        model_dump.update(
                            {
                                "schema_table": table_key,
                                "schema_column": column_key,
                                "schema_value": value_key,
                            }
                        )
                        metadata_filter = AblativeMetadataFilter(**model_dump)
                    except OutputParserException as e:
                        logger.warning(
                            f"LLM failed to generate structured metadata filters: {str(e)}"
                        )
                        return HandlerResponse(
                            RESPONSE_TYPE.ERROR, error_message=str(e)
                        )
                else:
                    metadata_filter = AblativeMetadataFilter(
                        attribute=column_schema.column,
                        comparator=value_schema.comparator,
                        value=value_schema.value,
                        schema_table=table_key,
                        schema_column=column_key,
                        schema_value=value_key,
                    )
                metadata_filter_list.append(metadata_filter)

    return metadata_filter_list
|
|
774
|
+
|
|
775
|
+
def _prepare_and_execute_query(
|
|
776
|
+
self,
|
|
777
|
+
ranked_database_schema: DatabaseSchema,
|
|
778
|
+
metadata_filters: List[AblativeMetadataFilter],
|
|
779
|
+
embeddings_str: str,
|
|
780
|
+
) -> HandlerResponse:
|
|
169
781
|
try:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
return
|
|
782
|
+
checked_sql_query = self._prepare_pgvector_query(
|
|
783
|
+
ranked_database_schema, metadata_filters
|
|
784
|
+
)
|
|
785
|
+
checked_sql_query_with_embeddings = checked_sql_query.format(
|
|
786
|
+
embeddings=embeddings_str
|
|
787
|
+
)
|
|
788
|
+
return self.vector_store_handler.native_query(
|
|
789
|
+
checked_sql_query_with_embeddings
|
|
790
|
+
)
|
|
177
791
|
except Exception as e:
|
|
178
|
-
logger.warning(
|
|
792
|
+
logger.warning(
|
|
793
|
+
f"Failed to prepare and execute SQL query from structured metadata: {str(e)}"
|
|
794
|
+
)
|
|
179
795
|
return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e))
|
|
180
796
|
|
|
181
797
|
def _get_relevant_documents(
|
|
@@ -183,36 +799,85 @@ Output:
|
|
|
183
799
|
) -> List[Document]:
|
|
184
800
|
# Rewrite query to be suitable for retrieval.
|
|
185
801
|
retrieval_query = self._prepare_retrieval_query(query)
|
|
802
|
+
|
|
186
803
|
# Embed the rewritten retrieval query & include it in the similarity search pgvector query.
|
|
187
804
|
embedded_query = self.embeddings_model.embed_query(retrieval_query)
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
document_response = self._prepare_and_execute_query(
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
805
|
+
|
|
806
|
+
# Search for relevant filters
|
|
807
|
+
ranked_database_schema, ablation_value_dict, ablation_quantiles = (
|
|
808
|
+
self._breadth_first_search(query=query)
|
|
809
|
+
)
|
|
810
|
+
|
|
811
|
+
# Generate metadata filters
|
|
812
|
+
metadata_filters = self._generate_metadata_filters(
|
|
813
|
+
query=query, ranked_database_schema=ranked_database_schema
|
|
814
|
+
)
|
|
815
|
+
|
|
816
|
+
if type(metadata_filters) is list:
|
|
817
|
+
# Initial Execution of the similarity search with metadata filters.
|
|
818
|
+
document_response = self._prepare_and_execute_query(
|
|
819
|
+
ranked_database_schema=ranked_database_schema,
|
|
820
|
+
metadata_filters=metadata_filters,
|
|
821
|
+
embeddings_str=str(embedded_query),
|
|
822
|
+
)
|
|
823
|
+
num_retries = 0
|
|
824
|
+
while num_retries < self.num_retries:
|
|
825
|
+
if (
|
|
826
|
+
document_response.resp_type != RESPONSE_TYPE.ERROR
|
|
827
|
+
and len(document_response.data_frame) >= self.min_k
|
|
828
|
+
):
|
|
829
|
+
# Successfully retrieved k documents to send to re-ranker.
|
|
830
|
+
break
|
|
831
|
+
elif document_response.resp_type == RESPONSE_TYPE.ERROR:
|
|
832
|
+
# LLMs won't always generate structured metadata so we should have a fallback after retrying.
|
|
833
|
+
logger.info(
|
|
834
|
+
f"SQL Retriever query failed with error {document_response.error_message}"
|
|
835
|
+
)
|
|
836
|
+
else:
|
|
837
|
+
logger.info(
|
|
838
|
+
f"SQL Retriever did not retrieve {self.min_k} documents: {len(document_response.data_frame)} documents retrieved."
|
|
839
|
+
)
|
|
840
|
+
|
|
841
|
+
ablated_metadata_filters = self._dynamic_ablation(
|
|
842
|
+
metadata_filters=metadata_filters,
|
|
843
|
+
ablation_value_dict=ablation_value_dict,
|
|
844
|
+
ablation_quantiles=ablation_quantiles,
|
|
845
|
+
retry=num_retries,
|
|
846
|
+
)
|
|
847
|
+
|
|
848
|
+
document_response = self._prepare_and_execute_query(
|
|
849
|
+
ranked_database_schema=ranked_database_schema,
|
|
850
|
+
metadata_filters=ablated_metadata_filters,
|
|
851
|
+
embeddings_str=str(embedded_query),
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
num_retries += 1
|
|
855
|
+
|
|
856
|
+
retrieved_documents = []
|
|
857
|
+
if document_response.resp_type != RESPONSE_TYPE.ERROR:
|
|
858
|
+
document_df = document_response.data_frame
|
|
859
|
+
for _, document_row in document_df.iterrows():
|
|
860
|
+
retrieved_documents.append(
|
|
861
|
+
Document(
|
|
862
|
+
document_row.get("content", ""),
|
|
863
|
+
metadata=document_row.get("metadata", {}),
|
|
864
|
+
)
|
|
865
|
+
)
|
|
866
|
+
if retrieved_documents:
|
|
867
|
+
return retrieved_documents
|
|
868
|
+
|
|
869
|
+
# If the SQL query constructed did not return any documents, fallback.
|
|
870
|
+
logger.info(
|
|
871
|
+
"No documents returned from SQL retriever, using fallback retriever."
|
|
872
|
+
)
|
|
873
|
+
return self.fallback_retriever._get_relevant_documents(
|
|
874
|
+
retrieval_query, run_manager=run_manager
|
|
875
|
+
)
|
|
876
|
+
else:
|
|
877
|
+
# If no metadata fields could be generated fallback.
|
|
878
|
+
logger.info(
|
|
879
|
+
"No metadata fields were successfully generated, using fallback retriever."
|
|
880
|
+
)
|
|
881
|
+
return self.fallback_retriever._get_relevant_documents(
|
|
882
|
+
retrieval_query, run_manager=run_manager
|
|
883
|
+
)
|