nucliadb 6.3.6.post4063__py3-none-any.whl → 6.3.7.post4068__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/search/api/v1/search.py +6 -39
- nucliadb/search/search/chat/ask.py +19 -26
- nucliadb/search/search/chat/query.py +6 -6
- nucliadb/search/search/find.py +21 -91
- nucliadb/search/search/find_merge.py +18 -9
- nucliadb/search/search/graph_strategy.py +9 -10
- nucliadb/search/search/merge.py +76 -65
- nucliadb/search/search/query.py +2 -455
- nucliadb/search/search/query_parser/fetcher.py +41 -0
- nucliadb/search/search/query_parser/models.py +82 -8
- nucliadb/search/search/query_parser/parsers/ask.py +77 -0
- nucliadb/search/search/query_parser/parsers/common.py +189 -0
- nucliadb/search/search/query_parser/parsers/find.py +175 -13
- nucliadb/search/search/query_parser/parsers/search.py +249 -0
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +176 -0
- nucliadb/search/search/rerankers.py +4 -2
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/METADATA +6 -6
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/RECORD +21 -17
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/WHEEL +0 -0
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/top_level.txt +0 -0
nucliadb/search/search/query.py
CHANGED
@@ -17,46 +17,22 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import asyncio
-import re
-import string
 from datetime import datetime
-from typing import Any,
+from typing import Any, Optional
 
 from nucliadb.common import datamanagers
-from nucliadb.common.models_utils.from_proto import RelationNodeTypeMap
-from nucliadb.search import logger
-from nucliadb.search.predict import SendToPredictError
 from nucliadb.search.search.filters import (
     translate_label,
 )
-from nucliadb.search.search.metrics import (
-    node_features,
-    query_parser_observer,
-)
 from nucliadb.search.search.query_parser.fetcher import Fetcher
-from nucliadb.search.search.rank_fusion import (
-    RankFusionAlgorithm,
-)
-from nucliadb.search.search.rerankers import (
-    Reranker,
-)
 from nucliadb_models.filters import FilterExpression
-from nucliadb_models.
-from nucliadb_models.labels import LABEL_HIDDEN, translate_system_to_alias_label
+from nucliadb_models.labels import LABEL_HIDDEN
 from nucliadb_models.metadata import ResourceProcessingStatus
 from nucliadb_models.search import (
-    KnowledgeGraphEntity,
-    MaxTokens,
-    MinScore,
-    SearchOptions,
     SortField,
-    SortOptions,
     SortOrder,
-    SortOrderMap,
     SuggestOptions,
 )
-from nucliadb_models.security import RequestSecurity
 from nucliadb_protos import nodereader_pb2, utils_pb2
 from nucliadb_protos.noderesources_pb2 import Resource
 
@@ -64,435 +40,6 @@ from .exceptions import InvalidQueryError
 from .query_parser.filter_expression import add_and_expression, parse_expression
 from .query_parser.old_filters import OldFilterParams, parse_old_filters
 
-INDEX_SORTABLE_FIELDS = [
-    SortField.CREATED,
-    SortField.MODIFIED,
-]
-
-DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7
-
-# -* is an invalid query in tantivy and it won't return results but if you add some whitespaces
-# between - and *, it will actually trigger a tantivy bug and panic
-INVALID_QUERY = re.compile(r"- +\*")
-
-
-class QueryParser:
-    """
-    Queries are getting more and more complex and different phases of the query
-    depending on different data.
-
-    This class is an encapsulation of the different phases of the query and allow
-    some stateful interaction with a query and different depenedencies during
-    query parsing.
-    """
-
-    _query_information_task: Optional[asyncio.Task] = None
-
-    def __init__(
-        self,
-        *,
-        kbid: str,
-        features: list[SearchOptions],
-        query: str,
-        top_k: int,
-        min_score: MinScore,
-        old_filters: OldFilterParams,
-        filter_expression: Optional[FilterExpression] = None,
-        query_entities: Optional[list[KnowledgeGraphEntity]] = None,
-        faceted: Optional[list[str]] = None,
-        sort: Optional[SortOptions] = None,
-        user_vector: Optional[list[float]] = None,
-        vectorset: Optional[str] = None,
-        with_duplicates: bool = False,
-        with_status: Optional[ResourceProcessingStatus] = None,
-        with_synonyms: bool = False,
-        autofilter: bool = False,
-        security: Optional[RequestSecurity] = None,
-        generative_model: Optional[str] = None,
-        rephrase: bool = False,
-        rephrase_prompt: Optional[str] = None,
-        max_tokens: Optional[MaxTokens] = None,
-        hidden: Optional[bool] = None,
-        rank_fusion: Optional[RankFusionAlgorithm] = None,
-        reranker: Optional[Reranker] = None,
-    ):
-        self.kbid = kbid
-        self.features = features
-        self.query = query
-        self.query_entities = query_entities
-        self.hidden = hidden
-        self.faceted = faceted or []
-        self.top_k = top_k
-        self.min_score = min_score
-        self.sort = sort
-        self.user_vector = user_vector
-        self.vectorset = vectorset
-        self.with_duplicates = with_duplicates
-        self.with_status = with_status
-        self.with_synonyms = with_synonyms
-        self.autofilter = autofilter
-        self.security = security
-        self.generative_model = generative_model
-        self.rephrase = rephrase
-        self.rephrase_prompt = rephrase_prompt
-        self.query_endpoint_used = False
-        self.max_tokens = max_tokens
-        self.rank_fusion = rank_fusion
-        self.reranker = reranker
-        self.filter_expression = filter_expression
-        self.old_filters = old_filters
-        self.fetcher = Fetcher(
-            kbid=kbid,
-            query=query,
-            user_vector=user_vector,
-            vectorset=vectorset,
-            rephrase=rephrase,
-            rephrase_prompt=rephrase_prompt,
-            generative_model=generative_model,
-        )
-
-    @property
-    def has_vector_search(self) -> bool:
-        return SearchOptions.SEMANTIC in self.features
-
-    @property
-    def has_relations_search(self) -> bool:
-        return SearchOptions.RELATIONS in self.features
-
-    def _get_query_information(self) -> Awaitable[QueryInfo]:
-        if self._query_information_task is None: # pragma: no cover
-            self._query_information_task = asyncio.create_task(self._query_information())
-        return self._query_information_task
-
-    async def _query_information(self) -> QueryInfo:
-        # HACK: while transitioning to the new query parser, use fetcher under
-        # the hood for a smoother migration
-        query_info = await self.fetcher._predict_query_endpoint()
-        if query_info is None:
-            raise SendToPredictError("Error while using predict's query endpoint")
-        return query_info
-
-    async def _schedule_dependency_tasks(self) -> None:
-        """
-        This will schedule concurrent tasks for different data that needs to be pulled
-        for the sake of the query being performed
-        """
-        if len(self.old_filters.label_filters) > 0:
-            asyncio.ensure_future(self.fetcher.get_classification_labels())
-
-        if self.has_vector_search and self.user_vector is None:
-            self.query_endpoint_used = True
-            asyncio.ensure_future(self._get_query_information())
-            # XXX: should we also ensure get_vectorset and get_query_vector?
-            asyncio.ensure_future(self.fetcher.get_matryoshka_dimension())
-
-        if (self.has_relations_search or self.autofilter) and len(self.query) > 0:
-            if not self.query_endpoint_used:
-                # If we only need to detect entities, we don't need the query endpoint
-                asyncio.ensure_future(self.fetcher.get_detected_entities())
-            asyncio.ensure_future(self.fetcher.get_entities_meta_cache())
-            asyncio.ensure_future(self.fetcher.get_deleted_entity_groups())
-        if self.with_synonyms and self.query:
-            asyncio.ensure_future(self.fetcher.get_synonyms())
-
-    @query_parser_observer.wrap({"type": "QueryParser"})
-    async def parse(self) -> tuple[nodereader_pb2.SearchRequest, bool, list[str], Optional[str]]:
-        """
-        :return: (request, incomplete, autofilters)
-        where:
-            - request: protobuf nodereader_pb2.SearchRequest object
-            - incomplete: If the query is incomplete (missing vectors)
-            - autofilters: The autofilters that were applied
-        """
-
-        # Filter some queries that panic tantivy, better than returning the 500
-        if INVALID_QUERY.search(self.query):
-            raise InvalidQueryError("query", "Invalid query syntax")
-
-        request = nodereader_pb2.SearchRequest()
-        request.body = self.query
-        request.with_duplicates = self.with_duplicates
-
-        self.parse_sorting(request)
-
-        await self._schedule_dependency_tasks()
-
-        await self.parse_filters(request)
-        self.parse_document_search(request)
-        self.parse_paragraph_search(request)
-        incomplete, rephrased_query = await self.parse_vector_search(request)
-        autofilters = await self.parse_relation_search(request)
-        await self.parse_synonyms(request)
-        await self.parse_min_score(request, incomplete)
-        await self.adjust_page_size(request, self.rank_fusion, self.reranker)
-        return request, incomplete, autofilters, rephrased_query
-
-    async def parse_filters(self, request: nodereader_pb2.SearchRequest) -> None:
-        request.faceted.labels.extend([translate_label(facet) for facet in self.faceted])
-
-        if self.security is not None and len(self.security.groups) > 0:
-            security_pb = utils_pb2.Security()
-            for group_id in self.security.groups:
-                if group_id not in security_pb.access_groups:
-                    security_pb.access_groups.append(group_id)
-            request.security.CopyFrom(security_pb)
-
-        has_old_filters = False
-        if self.old_filters:
-            field_expr, paragraph_expr = await parse_old_filters(self.old_filters, self.fetcher)
-            if field_expr is not None:
-                request.field_filter.CopyFrom(field_expr)
-                has_old_filters = True
-            if paragraph_expr is not None:
-                request.paragraph_filter.CopyFrom(paragraph_expr)
-                has_old_filters = True
-
-        if self.filter_expression and has_old_filters:
-            raise InvalidQueryError("filter_expression", "Cannot mix old filters with filter_expression")
-
-        if self.filter_expression:
-            if self.filter_expression.field:
-                expr = await parse_expression(self.filter_expression.field, self.kbid)
-                if expr:
-                    request.field_filter.CopyFrom(expr)
-
-            if self.filter_expression.paragraph:
-                expr = await parse_expression(self.filter_expression.paragraph, self.kbid)
-                if expr:
-                    request.paragraph_filter.CopyFrom(expr)
-
-            if self.filter_expression.operator == FilterExpression.Operator.OR:
-                request.filter_operator = nodereader_pb2.FilterOperator.OR
-            else:
-                request.filter_operator = nodereader_pb2.FilterOperator.AND
-
-        if self.hidden is not None:
-            expr = nodereader_pb2.FilterExpression()
-            if self.hidden:
-                expr.facet.facet = LABEL_HIDDEN
-            else:
-                expr.bool_not.facet.facet = LABEL_HIDDEN
-
-            add_and_expression(request.field_filter, expr)
-
-    def parse_sorting(self, request: nodereader_pb2.SearchRequest) -> None:
-        if len(self.query) == 0:
-            if self.sort is None:
-                self.sort = SortOptions(
-                    field=SortField.CREATED,
-                    order=SortOrder.DESC,
-                    limit=None,
-                )
-            elif self.sort.field not in INDEX_SORTABLE_FIELDS:
-                raise InvalidQueryError(
-                    "sort_field",
-                    f"Empty query can only be sorted by '{SortField.CREATED}' or"
-                    f" '{SortField.MODIFIED}' and sort limit won't be applied",
-                )
-        else:
-            if self.sort is None:
-                self.sort = SortOptions(
-                    field=SortField.SCORE,
-                    order=SortOrder.DESC,
-                    limit=None,
-                )
-            elif self.sort.field not in INDEX_SORTABLE_FIELDS and self.sort.limit is None:
-                raise InvalidQueryError(
-                    "sort_field",
-                    f"Sort by '{self.sort.field}' requires setting a sort limit",
-                )
-
-        # We need to ask for all and cut later
-        request.page_number = 0
-        if self.sort and self.sort.limit is not None:
-            # As the index can't sort, we have to do it when merging. To
-            # have consistent results, we must limit them
-            request.result_per_page = self.sort.limit
-        else:
-            request.result_per_page = self.top_k
-
-        sort_field = get_sort_field_proto(self.sort.field) if self.sort else None
-        if sort_field is not None:
-            request.order.sort_by = sort_field
-            request.order.type = SortOrderMap[self.sort.order] # type: ignore
-
-    async def parse_min_score(self, request: nodereader_pb2.SearchRequest, incomplete: bool) -> None:
-        semantic_min_score = DEFAULT_GENERIC_SEMANTIC_THRESHOLD
-        if self.min_score.semantic is not None:
-            semantic_min_score = self.min_score.semantic
-        elif self.has_vector_search and not incomplete:
-            query_information = await self._get_query_information()
-            vectorset = await self.fetcher.get_vectorset()
-            semantic_threshold = query_information.semantic_thresholds.get(vectorset, None)
-            if semantic_threshold is not None:
-                semantic_min_score = semantic_threshold
-            else:
-                logger.warning(
-                    "Semantic threshold not found in query information, using default",
-                    extra={"kbid": self.kbid},
-                )
-        self.min_score.semantic = semantic_min_score
-        request.min_score_semantic = self.min_score.semantic
-        request.min_score_bm25 = self.min_score.bm25
-
-    def parse_document_search(self, request: nodereader_pb2.SearchRequest) -> None:
-        if SearchOptions.FULLTEXT in self.features:
-            request.document = True
-            node_features.inc({"type": "documents"})
-
-    def parse_paragraph_search(self, request: nodereader_pb2.SearchRequest) -> None:
-        if SearchOptions.KEYWORD in self.features:
-            request.paragraph = True
-            node_features.inc({"type": "paragraphs"})
-
-    async def parse_vector_search(
-        self, request: nodereader_pb2.SearchRequest
-    ) -> tuple[bool, Optional[str]]:
-        if not self.has_vector_search:
-            return False, None
-
-        node_features.inc({"type": "vectors"})
-
-        vectorset = await self.fetcher.get_vectorset()
-        query_vector = await self.fetcher.get_query_vector()
-        rephrased_query = await self.fetcher.get_rephrased_query()
-        incomplete = query_vector is None
-
-        request.vectorset = vectorset
-        if query_vector is not None:
-            request.vector.extend(query_vector)
-
-        return incomplete, rephrased_query
-
-    async def parse_relation_search(self, request: nodereader_pb2.SearchRequest) -> list[str]:
-        autofilters = []
-        # BUG: autofiler should autofilter, not enable relation search
-        if self.has_relations_search or self.autofilter:
-            if self.query_entities:
-                detected_entities = []
-                for entity in self.query_entities:
-                    relation_node = utils_pb2.RelationNode()
-                    relation_node.value = entity.name
-                    if entity.type is not None:
-                        relation_node.ntype = RelationNodeTypeMap[entity.type]
-                    if entity.subtype is not None:
-                        relation_node.subtype = entity.subtype
-                    detected_entities.append(relation_node)
-            else:
-                detected_entities = await self.fetcher.get_detected_entities()
-            meta_cache = await self.fetcher.get_entities_meta_cache()
-            detected_entities = expand_entities(meta_cache, detected_entities)
-            if self.has_relations_search:
-                request.relation_subgraph.entry_points.extend(detected_entities)
-                request.relation_subgraph.depth = 1
-                request.relation_subgraph.deleted_groups.extend(
-                    await self.fetcher.get_deleted_entity_groups()
-                )
-                for group_id, deleted_entities in meta_cache.deleted_entities.items():
-                    request.relation_subgraph.deleted_entities.append(
-                        nodereader_pb2.EntitiesSubgraphRequest.DeletedEntities(
-                            node_subtype=group_id, node_values=deleted_entities
-                        )
-                    )
-                node_features.inc({"type": "relations"})
-            if self.autofilter:
-                entity_filters = apply_entities_filter(request, detected_entities)
-                autofilters.extend([translate_system_to_alias_label(e) for e in entity_filters])
-        return autofilters
-
-    async def parse_synonyms(self, request: nodereader_pb2.SearchRequest) -> None:
-        """
-        Replace the terms in the query with an expression that will make it match with the configured synonyms.
-        We're using the Tantivy's query language here: https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html
-
-        Example:
-        - Synonyms: Foo -> Bar, Baz
-        - Query: "What is Foo?"
-        - Advanced Query: "What is (Foo OR Bar OR Baz)?"
-        """
-        if not self.with_synonyms or not self.query:
-            # Nothing to do
-            return
-
-        if self.has_vector_search or self.has_relations_search:
-            raise InvalidQueryError(
-                "synonyms",
-                "Search with custom synonyms is only supported on paragraph and document search",
-            )
-
-        synonyms = await self.fetcher.get_synonyms()
-        if synonyms is None:
-            # No synonyms found
-            return
-
-        # Calculate term variants: 'term' -> '(term OR synonym1 OR synonym2)'
-        variants: dict[str, str] = {}
-        for term, term_synonyms in synonyms.terms.items():
-            if len(term_synonyms.synonyms) > 0:
-                variants[term] = "({})".format(" OR ".join([term] + list(term_synonyms.synonyms)))
-
-        # Split the query into terms
-        query_terms = self.query.split()
-
-        # Remove punctuation from the query terms
-        clean_query_terms = [term.strip(string.punctuation) for term in query_terms]
-
-        # Replace the original terms with the variants if the cleaned term is in the variants
-        term_with_synonyms_found = False
-        for index, clean_term in enumerate(clean_query_terms):
-            if clean_term in variants:
-                term_with_synonyms_found = True
-                query_terms[index] = query_terms[index].replace(clean_term, variants[clean_term])
-
-        if term_with_synonyms_found:
-            request.advanced_query = " ".join(query_terms)
-            request.ClearField("body")
-
-    async def get_visual_llm_enabled(self) -> bool:
-        return (await self._get_query_information()).visual_llm
-
-    async def get_max_tokens_context(self) -> int:
-        model_max = (await self._get_query_information()).max_context
-        if self.max_tokens is not None and self.max_tokens.context is not None:
-            if self.max_tokens.context > model_max:
-                raise InvalidQueryError(
-                    "max_tokens.context",
-                    f"Max context tokens is higher than the model's limit of {model_max}",
-                )
-            return self.max_tokens.context
-        return model_max
-
-    def get_max_tokens_answer(self) -> Optional[int]:
-        if self.max_tokens is not None and self.max_tokens.answer is not None:
-            return self.max_tokens.answer
-        return None
-
-    async def adjust_page_size(
-        self,
-        request: nodereader_pb2.SearchRequest,
-        rank_fusion: Optional[RankFusionAlgorithm],
-        reranker: Optional[Reranker],
-    ):
-        """Adjust requested page size depending on rank fusion and reranking algorithms.
-
-        Some rerankers want more results than the requested by the user so
-        reranking can have more choices.
-
-        """
-        rank_fusion_window = 0
-        if rank_fusion is not None:
-            rank_fusion_window = rank_fusion.window
-
-        reranker_window = 0
-        if reranker is not None:
-            reranker_window = reranker.window or 0
-
-        request.result_per_page = max(
-            request.result_per_page,
-            rank_fusion_window,
-            reranker_window,
-        )
-
 
 async def paragraph_query_to_pb(
     kbid: str,
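Not part of the diff: the synonym expansion that the deleted QueryParser.parse_synonyms performed is easier to follow in isolation. The sketch below restates that removed logic as a standalone function; expand_synonyms and its plain-dict signature are illustrative, not an API of this package.

import string
from typing import Optional

def expand_synonyms(query: str, synonyms: dict[str, list[str]]) -> Optional[str]:
    # 'term' -> '(term OR synonym1 OR synonym2)', as in the removed code
    variants = {
        term: "({})".format(" OR ".join([term] + term_synonyms))
        for term, term_synonyms in synonyms.items()
        if term_synonyms
    }
    query_terms = query.split()
    # strip punctuation so "Foo?" still matches the "Foo" variant
    clean_terms = [term.strip(string.punctuation) for term in query_terms]
    found = False
    for index, clean_term in enumerate(clean_terms):
        if clean_term in variants:
            found = True
            query_terms[index] = query_terms[index].replace(clean_term, variants[clean_term])
    # tantivy advanced query, or None when no term had synonyms
    return " ".join(query_terms) if found else None

assert expand_synonyms("What is Foo?", {"Foo": ["Bar", "Baz"]}) == "What is (Foo OR Bar OR Baz)?"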
nucliadb/search/search/query_parser/fetcher.py
CHANGED
@@ -33,6 +33,9 @@ from nucliadb.search.search.metrics import (
 from nucliadb.search.search.query_parser.exceptions import InvalidQueryError
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.internal.predict import QueryInfo
+from nucliadb_models.search import (
+    MaxTokens,
+)
 from nucliadb_protos import knowledgebox_pb2, utils_pb2
 
 
@@ -205,6 +208,15 @@ class Fetcher:
             return None
         return query_info.rephrased_query
 
+    async def get_semantic_min_score(self) -> Optional[float]:
+        query_info = await self._predict_query_endpoint()
+        if query_info is None:
+            return None
+
+        vectorset = await self.get_vectorset()
+        min_score = query_info.semantic_thresholds.get(vectorset, None)
+        return min_score
+
     # Labels
 
     async def get_classification_labels(self) -> knowledgebox_pb2.Labels:
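Not part of the diff: get_semantic_min_score returns None both when predict's query endpoint fails and when the vectorset has no configured threshold, so callers still need a fallback. A minimal sketch of such a caller, reusing the DEFAULT_GENERIC_SEMANTIC_THRESHOLD constant (0.7) that this release deletes from query.py; semantic_min_score is a hypothetical helper, not an API of this package.

from typing import Optional

from nucliadb.search.search.query_parser.fetcher import Fetcher

DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7  # deleted from query.py in this release

async def semantic_min_score(fetcher: Fetcher, user_min_score: Optional[float]) -> float:
    # a user-provided min score always wins, as in the old parse_min_score
    if user_min_score is not None:
        return user_min_score
    threshold = await fetcher.get_semantic_min_score()
    return threshold if threshold is not None else DEFAULT_GENERIC_SEMANTIC_THRESHOLD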
@@ -268,6 +280,35 @@ class Fetcher:
         self.cache.synonyms = synonyms
         return synonyms
 
+    # Generative
+
+    async def get_visual_llm_enabled(self) -> bool:
+        query_info = await self._predict_query_endpoint()
+        if query_info is None:
+            raise SendToPredictError("Error while using predict's query endpoint")
+
+        return query_info.visual_llm
+
+    async def get_max_context_tokens(self, max_tokens: Optional[MaxTokens]) -> int:
+        query_info = await self._predict_query_endpoint()
+        if query_info is None:
+            raise SendToPredictError("Error while using predict's query endpoint")
+
+        model_max = query_info.max_context
+        if max_tokens is not None and max_tokens.context is not None:
+            if max_tokens.context > model_max:
+                raise InvalidQueryError(
+                    "max_tokens.context",
+                    f"Max context tokens is higher than the model's limit of {model_max}",
+                )
+            return max_tokens.context
+        return model_max
+
+    def get_max_answer_tokens(self, max_tokens: Optional[MaxTokens]) -> Optional[int]:
+        if max_tokens is not None and max_tokens.answer is not None:
+            return max_tokens.answer
+        return None
+
     # Predict API
 
     async def _predict_query_endpoint(self) -> Optional[QueryInfo]:
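Not part of the diff: these three helpers map one-to-one onto the fields of the Generation dataclass added to models.py (below). A sketch of how a caller might wire them together; parse_generation is a hypothetical name, since the real call site (presumably in the new parsers/ask.py) is not shown in this diff.

from typing import Optional

from nucliadb.search.search.query_parser.fetcher import Fetcher
from nucliadb.search.search.query_parser.models import Generation
from nucliadb_models.search import MaxTokens

async def parse_generation(fetcher: Fetcher, max_tokens: Optional[MaxTokens]) -> Generation:
    # each field comes straight from one of the new Fetcher helpers
    return Generation(
        use_visual_llm=await fetcher.get_visual_llm_enabled(),
        max_context_tokens=await fetcher.get_max_context_tokens(max_tokens),
        max_answer_tokens=fetcher.get_max_answer_tokens(max_tokens),
    )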
nucliadb/search/search/query_parser/models.py
CHANGED
@@ -17,7 +17,6 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-
 from dataclasses import dataclass
 from datetime import datetime
 from typing import Literal, Optional, Union
@@ -27,14 +26,70 @@ from pydantic import (
     Field,
 )
 
+from nucliadb.search.search.query_parser.fetcher import Fetcher
 from nucliadb_models import search as search_models
-from nucliadb_protos import nodereader_pb2
+from nucliadb_protos import nodereader_pb2, utils_pb2
 
 ### Retrieval
 
+# query
+
+
+@dataclass
+class _TextQuery:
+    query: str
+    is_synonyms_query: bool
+    min_score: float
+    sort: search_models.SortOrder = search_models.SortOrder.DESC
+    order_by: search_models.SortField = search_models.SortField.SCORE
+
+
+FulltextQuery = _TextQuery
+KeywordQuery = _TextQuery
+
+
+@dataclass
+class SemanticQuery:
+    query: Optional[list[float]]
+    vectorset: str
+    min_score: float
+
+
+@dataclass
+class RelationQuery:
+    detected_entities: list[utils_pb2.RelationNode]
+    # list[subtype]
+    deleted_entity_groups: list[str]
+    # subtype -> list[entity]
+    deleted_entities: dict[str, list[str]]
+
+
+@dataclass
+class Query:
+    fulltext: Optional[FulltextQuery] = None
+    keyword: Optional[KeywordQuery] = None
+    semantic: Optional[SemanticQuery] = None
+    relation: Optional[RelationQuery] = None
+
+
 # filters
 
 
+@dataclass
+class Filters:
+    field_expression: Optional[nodereader_pb2.FilterExpression] = None
+    paragraph_expression: Optional[nodereader_pb2.FilterExpression] = None
+    filter_expression_operator: nodereader_pb2.FilterOperator.ValueType = (
+        nodereader_pb2.FilterOperator.AND
+    )
+
+    autofilter: Optional[list[utils_pb2.RelationNode]] = None
+    facets: list[str] = Field(default_factory=list)
+    hidden: Optional[bool] = None
+    security: Optional[search_models.RequestSecurity] = None
+    with_duplicates: bool = False
+
+
 class DateTimeFilter(BaseModel):
     after: Optional[datetime] = None # aka, start
     before: Optional[datetime] = None # aka, end
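Not part of the diff: with these dataclasses, a retrieval query becomes plain data instead of QueryParser state. An illustrative hybrid keyword plus semantic query; the vectorset name and scores are made-up values, and SemanticQuery.query stays None until the query vector has been fetched.

from nucliadb.search.search.query_parser.models import (
    Filters,
    KeywordQuery,
    Query,
    SemanticQuery,
)

query = Query(
    keyword=KeywordQuery(query="climate change", is_synonyms_query=False, min_score=0.0),
    semantic=SemanticQuery(query=None, vectorset="multilingual-2024", min_score=0.7),
)
# defaults: AND operator, no filter expressions, duplicates excluded
filters = Filters()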
@@ -57,26 +112,45 @@ class ReciprocalRankFusion(RankFusion):
 # reranking
 
 
-class
-
+class NoopReranker(BaseModel):
+    pass
 
-class NoopReranker(Reranker): ...
 
-
-class PredictReranker(Reranker):
+class PredictReranker(BaseModel):
     window: int = Field(le=200)
 
 
-
+Reranker = Union[NoopReranker, PredictReranker]
+
+# retrieval and generation operations
 
 
 @dataclass
 class UnitRetrieval:
+    query: Query
     top_k: int
+    filters: Filters
+    # TODO: rank fusion depends on the response building, not the retrieval
     rank_fusion: RankFusion
+    # TODO: reranking fusion depends on the response building, not the retrieval
     reranker: Reranker
 
 
+@dataclass
+class Generation:
+    use_visual_llm: bool
+    max_context_tokens: int
+    max_answer_tokens: Optional[int]
+
+
+@dataclass
+class ParsedQuery:
+    fetcher: Fetcher
+    retrieval: UnitRetrieval
+    generation: Optional[Generation] = None
+    # TODO: add merge, rank fusion, rerank...
+
+
 ### Catalog
 @dataclass
 class CatalogExpression:
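Not part of the diff: ParsedQuery is the new top-level parser output, bundling the Fetcher, the UnitRetrieval and, for generative endpoints, a Generation. An illustrative assembly that reuses query and filters from the sketch above; fetcher and rank_fusion are assumed to come from earlier parsing steps not shown here.

from nucliadb.search.search.query_parser.models import (
    NoopReranker,
    ParsedQuery,
    UnitRetrieval,
)

retrieval = UnitRetrieval(
    query=query,  # from the models.py sketch above
    top_k=20,
    filters=filters,  # idem
    rank_fusion=rank_fusion,  # assumed: a RankFusion parsed earlier
    reranker=NoopReranker(),
)
# generation stays None for pure retrieval (/find); /ask would attach one
parsed = ParsedQuery(fetcher=fetcher, retrieval=retrieval, generation=None)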
|