nucliadb 6.3.6.post4063__py3-none-any.whl → 6.3.7.post4068__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,46 +17,22 @@
  # You should have received a copy of the GNU Affero General Public License
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
  #
- import asyncio
- import re
- import string
  from datetime import datetime
- from typing import Any, Awaitable, Optional
+ from typing import Any, Optional

  from nucliadb.common import datamanagers
- from nucliadb.common.models_utils.from_proto import RelationNodeTypeMap
- from nucliadb.search import logger
- from nucliadb.search.predict import SendToPredictError
  from nucliadb.search.search.filters import (
      translate_label,
  )
- from nucliadb.search.search.metrics import (
-     node_features,
-     query_parser_observer,
- )
  from nucliadb.search.search.query_parser.fetcher import Fetcher
- from nucliadb.search.search.rank_fusion import (
-     RankFusionAlgorithm,
- )
- from nucliadb.search.search.rerankers import (
-     Reranker,
- )
  from nucliadb_models.filters import FilterExpression
- from nucliadb_models.internal.predict import QueryInfo
- from nucliadb_models.labels import LABEL_HIDDEN, translate_system_to_alias_label
+ from nucliadb_models.labels import LABEL_HIDDEN
  from nucliadb_models.metadata import ResourceProcessingStatus
  from nucliadb_models.search import (
-     KnowledgeGraphEntity,
-     MaxTokens,
-     MinScore,
-     SearchOptions,
      SortField,
-     SortOptions,
      SortOrder,
-     SortOrderMap,
      SuggestOptions,
  )
- from nucliadb_models.security import RequestSecurity
  from nucliadb_protos import nodereader_pb2, utils_pb2
  from nucliadb_protos.noderesources_pb2 import Resource

@@ -64,435 +40,6 @@ from .exceptions import InvalidQueryError
  from .query_parser.filter_expression import add_and_expression, parse_expression
  from .query_parser.old_filters import OldFilterParams, parse_old_filters

- INDEX_SORTABLE_FIELDS = [
-     SortField.CREATED,
-     SortField.MODIFIED,
- ]
-
- DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7
-
- # -* is an invalid query in tantivy and it won't return results, but if you add some whitespace
- # between - and *, it will actually trigger a tantivy bug and panic
- INVALID_QUERY = re.compile(r"- +\*")
-
-
- class QueryParser:
-     """
-     Queries are getting more and more complex, and different phases of the query
-     depend on different data.
-
-     This class is an encapsulation of the different phases of the query and allows
-     some stateful interaction with a query and its different dependencies during
-     query parsing.
-     """
-
-     _query_information_task: Optional[asyncio.Task] = None
-
-     def __init__(
-         self,
-         *,
-         kbid: str,
-         features: list[SearchOptions],
-         query: str,
-         top_k: int,
-         min_score: MinScore,
-         old_filters: OldFilterParams,
-         filter_expression: Optional[FilterExpression] = None,
-         query_entities: Optional[list[KnowledgeGraphEntity]] = None,
-         faceted: Optional[list[str]] = None,
-         sort: Optional[SortOptions] = None,
-         user_vector: Optional[list[float]] = None,
-         vectorset: Optional[str] = None,
-         with_duplicates: bool = False,
-         with_status: Optional[ResourceProcessingStatus] = None,
-         with_synonyms: bool = False,
-         autofilter: bool = False,
-         security: Optional[RequestSecurity] = None,
-         generative_model: Optional[str] = None,
-         rephrase: bool = False,
-         rephrase_prompt: Optional[str] = None,
-         max_tokens: Optional[MaxTokens] = None,
-         hidden: Optional[bool] = None,
-         rank_fusion: Optional[RankFusionAlgorithm] = None,
-         reranker: Optional[Reranker] = None,
-     ):
-         self.kbid = kbid
-         self.features = features
-         self.query = query
-         self.query_entities = query_entities
-         self.hidden = hidden
-         self.faceted = faceted or []
-         self.top_k = top_k
-         self.min_score = min_score
-         self.sort = sort
-         self.user_vector = user_vector
-         self.vectorset = vectorset
-         self.with_duplicates = with_duplicates
-         self.with_status = with_status
-         self.with_synonyms = with_synonyms
-         self.autofilter = autofilter
-         self.security = security
-         self.generative_model = generative_model
-         self.rephrase = rephrase
-         self.rephrase_prompt = rephrase_prompt
-         self.query_endpoint_used = False
-         self.max_tokens = max_tokens
-         self.rank_fusion = rank_fusion
-         self.reranker = reranker
-         self.filter_expression = filter_expression
-         self.old_filters = old_filters
-         self.fetcher = Fetcher(
-             kbid=kbid,
-             query=query,
-             user_vector=user_vector,
-             vectorset=vectorset,
-             rephrase=rephrase,
-             rephrase_prompt=rephrase_prompt,
-             generative_model=generative_model,
-         )
-
-     @property
-     def has_vector_search(self) -> bool:
-         return SearchOptions.SEMANTIC in self.features
-
-     @property
-     def has_relations_search(self) -> bool:
-         return SearchOptions.RELATIONS in self.features
-
-     def _get_query_information(self) -> Awaitable[QueryInfo]:
-         if self._query_information_task is None:  # pragma: no cover
-             self._query_information_task = asyncio.create_task(self._query_information())
-         return self._query_information_task
-
-     async def _query_information(self) -> QueryInfo:
-         # HACK: while transitioning to the new query parser, use fetcher under
-         # the hood for a smoother migration
-         query_info = await self.fetcher._predict_query_endpoint()
-         if query_info is None:
-             raise SendToPredictError("Error while using predict's query endpoint")
-         return query_info
-
-     async def _schedule_dependency_tasks(self) -> None:
-         """
-         This will schedule concurrent tasks for different data that needs to be pulled
-         for the sake of the query being performed
-         """
-         if len(self.old_filters.label_filters) > 0:
-             asyncio.ensure_future(self.fetcher.get_classification_labels())
-
-         if self.has_vector_search and self.user_vector is None:
-             self.query_endpoint_used = True
-             asyncio.ensure_future(self._get_query_information())
-             # XXX: should we also ensure get_vectorset and get_query_vector?
-             asyncio.ensure_future(self.fetcher.get_matryoshka_dimension())
-
-         if (self.has_relations_search or self.autofilter) and len(self.query) > 0:
-             if not self.query_endpoint_used:
-                 # If we only need to detect entities, we don't need the query endpoint
-                 asyncio.ensure_future(self.fetcher.get_detected_entities())
-             asyncio.ensure_future(self.fetcher.get_entities_meta_cache())
-             asyncio.ensure_future(self.fetcher.get_deleted_entity_groups())
-         if self.with_synonyms and self.query:
-             asyncio.ensure_future(self.fetcher.get_synonyms())
-
-     @query_parser_observer.wrap({"type": "QueryParser"})
-     async def parse(self) -> tuple[nodereader_pb2.SearchRequest, bool, list[str], Optional[str]]:
-         """
-         :return: (request, incomplete, autofilters, rephrased_query)
-         where:
-             - request: protobuf nodereader_pb2.SearchRequest object
-             - incomplete: If the query is incomplete (missing vectors)
-             - autofilters: The autofilters that were applied
-             - rephrased_query: The rephrased query used for vectors, if any
-         """
-
-         # Filter some queries that panic tantivy, better than returning the 500
-         if INVALID_QUERY.search(self.query):
-             raise InvalidQueryError("query", "Invalid query syntax")
-
-         request = nodereader_pb2.SearchRequest()
-         request.body = self.query
-         request.with_duplicates = self.with_duplicates
-
-         self.parse_sorting(request)
-
-         await self._schedule_dependency_tasks()
-
-         await self.parse_filters(request)
-         self.parse_document_search(request)
-         self.parse_paragraph_search(request)
-         incomplete, rephrased_query = await self.parse_vector_search(request)
-         autofilters = await self.parse_relation_search(request)
-         await self.parse_synonyms(request)
-         await self.parse_min_score(request, incomplete)
-         await self.adjust_page_size(request, self.rank_fusion, self.reranker)
-         return request, incomplete, autofilters, rephrased_query
-
-     async def parse_filters(self, request: nodereader_pb2.SearchRequest) -> None:
-         request.faceted.labels.extend([translate_label(facet) for facet in self.faceted])
-
-         if self.security is not None and len(self.security.groups) > 0:
-             security_pb = utils_pb2.Security()
-             for group_id in self.security.groups:
-                 if group_id not in security_pb.access_groups:
-                     security_pb.access_groups.append(group_id)
-             request.security.CopyFrom(security_pb)
-
-         has_old_filters = False
-         if self.old_filters:
-             field_expr, paragraph_expr = await parse_old_filters(self.old_filters, self.fetcher)
-             if field_expr is not None:
-                 request.field_filter.CopyFrom(field_expr)
-                 has_old_filters = True
-             if paragraph_expr is not None:
-                 request.paragraph_filter.CopyFrom(paragraph_expr)
-                 has_old_filters = True
-
-         if self.filter_expression and has_old_filters:
-             raise InvalidQueryError("filter_expression", "Cannot mix old filters with filter_expression")
-
-         if self.filter_expression:
-             if self.filter_expression.field:
-                 expr = await parse_expression(self.filter_expression.field, self.kbid)
-                 if expr:
-                     request.field_filter.CopyFrom(expr)
-
-             if self.filter_expression.paragraph:
-                 expr = await parse_expression(self.filter_expression.paragraph, self.kbid)
-                 if expr:
-                     request.paragraph_filter.CopyFrom(expr)
-
-             if self.filter_expression.operator == FilterExpression.Operator.OR:
-                 request.filter_operator = nodereader_pb2.FilterOperator.OR
-             else:
-                 request.filter_operator = nodereader_pb2.FilterOperator.AND
-
-         if self.hidden is not None:
-             expr = nodereader_pb2.FilterExpression()
-             if self.hidden:
-                 expr.facet.facet = LABEL_HIDDEN
-             else:
-                 expr.bool_not.facet.facet = LABEL_HIDDEN
-
-             add_and_expression(request.field_filter, expr)
-
-     def parse_sorting(self, request: nodereader_pb2.SearchRequest) -> None:
-         if len(self.query) == 0:
-             if self.sort is None:
-                 self.sort = SortOptions(
-                     field=SortField.CREATED,
-                     order=SortOrder.DESC,
-                     limit=None,
-                 )
-             elif self.sort.field not in INDEX_SORTABLE_FIELDS:
-                 raise InvalidQueryError(
-                     "sort_field",
-                     f"Empty query can only be sorted by '{SortField.CREATED}' or"
-                     f" '{SortField.MODIFIED}' and sort limit won't be applied",
-                 )
-         else:
-             if self.sort is None:
-                 self.sort = SortOptions(
-                     field=SortField.SCORE,
-                     order=SortOrder.DESC,
-                     limit=None,
-                 )
-             elif self.sort.field not in INDEX_SORTABLE_FIELDS and self.sort.limit is None:
-                 raise InvalidQueryError(
-                     "sort_field",
-                     f"Sort by '{self.sort.field}' requires setting a sort limit",
-                 )
-
-         # We need to ask for all and cut later
-         request.page_number = 0
-         if self.sort and self.sort.limit is not None:
-             # As the index can't sort, we have to do it when merging. To
-             # have consistent results, we must limit them
-             request.result_per_page = self.sort.limit
-         else:
-             request.result_per_page = self.top_k
-
-         sort_field = get_sort_field_proto(self.sort.field) if self.sort else None
-         if sort_field is not None:
-             request.order.sort_by = sort_field
-             request.order.type = SortOrderMap[self.sort.order]  # type: ignore
-
-     async def parse_min_score(self, request: nodereader_pb2.SearchRequest, incomplete: bool) -> None:
-         semantic_min_score = DEFAULT_GENERIC_SEMANTIC_THRESHOLD
-         if self.min_score.semantic is not None:
-             semantic_min_score = self.min_score.semantic
-         elif self.has_vector_search and not incomplete:
-             query_information = await self._get_query_information()
-             vectorset = await self.fetcher.get_vectorset()
-             semantic_threshold = query_information.semantic_thresholds.get(vectorset, None)
-             if semantic_threshold is not None:
-                 semantic_min_score = semantic_threshold
-             else:
-                 logger.warning(
-                     "Semantic threshold not found in query information, using default",
-                     extra={"kbid": self.kbid},
-                 )
-         self.min_score.semantic = semantic_min_score
-         request.min_score_semantic = self.min_score.semantic
-         request.min_score_bm25 = self.min_score.bm25
-
-     def parse_document_search(self, request: nodereader_pb2.SearchRequest) -> None:
-         if SearchOptions.FULLTEXT in self.features:
-             request.document = True
-             node_features.inc({"type": "documents"})
-
-     def parse_paragraph_search(self, request: nodereader_pb2.SearchRequest) -> None:
-         if SearchOptions.KEYWORD in self.features:
-             request.paragraph = True
-             node_features.inc({"type": "paragraphs"})
-
-     async def parse_vector_search(
-         self, request: nodereader_pb2.SearchRequest
-     ) -> tuple[bool, Optional[str]]:
-         if not self.has_vector_search:
-             return False, None
-
-         node_features.inc({"type": "vectors"})
-
-         vectorset = await self.fetcher.get_vectorset()
-         query_vector = await self.fetcher.get_query_vector()
-         rephrased_query = await self.fetcher.get_rephrased_query()
-         incomplete = query_vector is None
-
-         request.vectorset = vectorset
-         if query_vector is not None:
-             request.vector.extend(query_vector)
-
-         return incomplete, rephrased_query
-
-     async def parse_relation_search(self, request: nodereader_pb2.SearchRequest) -> list[str]:
-         autofilters = []
-         # BUG: autofilter should autofilter, not enable relation search
-         if self.has_relations_search or self.autofilter:
-             if self.query_entities:
-                 detected_entities = []
-                 for entity in self.query_entities:
-                     relation_node = utils_pb2.RelationNode()
-                     relation_node.value = entity.name
-                     if entity.type is not None:
-                         relation_node.ntype = RelationNodeTypeMap[entity.type]
-                     if entity.subtype is not None:
-                         relation_node.subtype = entity.subtype
-                     detected_entities.append(relation_node)
-             else:
-                 detected_entities = await self.fetcher.get_detected_entities()
-             meta_cache = await self.fetcher.get_entities_meta_cache()
-             detected_entities = expand_entities(meta_cache, detected_entities)
-             if self.has_relations_search:
-                 request.relation_subgraph.entry_points.extend(detected_entities)
-                 request.relation_subgraph.depth = 1
-                 request.relation_subgraph.deleted_groups.extend(
-                     await self.fetcher.get_deleted_entity_groups()
-                 )
-                 for group_id, deleted_entities in meta_cache.deleted_entities.items():
-                     request.relation_subgraph.deleted_entities.append(
-                         nodereader_pb2.EntitiesSubgraphRequest.DeletedEntities(
-                             node_subtype=group_id, node_values=deleted_entities
-                         )
-                     )
-                 node_features.inc({"type": "relations"})
-             if self.autofilter:
-                 entity_filters = apply_entities_filter(request, detected_entities)
-                 autofilters.extend([translate_system_to_alias_label(e) for e in entity_filters])
-         return autofilters
-
-     async def parse_synonyms(self, request: nodereader_pb2.SearchRequest) -> None:
-         """
-         Replace the terms in the query with an expression that will make it match with the configured synonyms.
-         We're using Tantivy's query language here: https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html
-
-         Example:
-         - Synonyms: Foo -> Bar, Baz
-         - Query: "What is Foo?"
-         - Advanced Query: "What is (Foo OR Bar OR Baz)?"
-         """
-         if not self.with_synonyms or not self.query:
-             # Nothing to do
-             return
-
-         if self.has_vector_search or self.has_relations_search:
-             raise InvalidQueryError(
-                 "synonyms",
-                 "Search with custom synonyms is only supported on paragraph and document search",
-             )
-
-         synonyms = await self.fetcher.get_synonyms()
-         if synonyms is None:
-             # No synonyms found
-             return
-
-         # Calculate term variants: 'term' -> '(term OR synonym1 OR synonym2)'
-         variants: dict[str, str] = {}
-         for term, term_synonyms in synonyms.terms.items():
-             if len(term_synonyms.synonyms) > 0:
-                 variants[term] = "({})".format(" OR ".join([term] + list(term_synonyms.synonyms)))
-
-         # Split the query into terms
-         query_terms = self.query.split()
-
-         # Remove punctuation from the query terms
-         clean_query_terms = [term.strip(string.punctuation) for term in query_terms]
-
-         # Replace the original terms with the variants if the cleaned term is in the variants
-         term_with_synonyms_found = False
-         for index, clean_term in enumerate(clean_query_terms):
-             if clean_term in variants:
-                 term_with_synonyms_found = True
-                 query_terms[index] = query_terms[index].replace(clean_term, variants[clean_term])
-
-         if term_with_synonyms_found:
-             request.advanced_query = " ".join(query_terms)
-             request.ClearField("body")
-
-     async def get_visual_llm_enabled(self) -> bool:
-         return (await self._get_query_information()).visual_llm
-
-     async def get_max_tokens_context(self) -> int:
-         model_max = (await self._get_query_information()).max_context
-         if self.max_tokens is not None and self.max_tokens.context is not None:
-             if self.max_tokens.context > model_max:
-                 raise InvalidQueryError(
-                     "max_tokens.context",
-                     f"Max context tokens is higher than the model's limit of {model_max}",
-                 )
-             return self.max_tokens.context
-         return model_max
-
-     def get_max_tokens_answer(self) -> Optional[int]:
-         if self.max_tokens is not None and self.max_tokens.answer is not None:
-             return self.max_tokens.answer
-         return None
-
-     async def adjust_page_size(
-         self,
-         request: nodereader_pb2.SearchRequest,
-         rank_fusion: Optional[RankFusionAlgorithm],
-         reranker: Optional[Reranker],
-     ):
-         """Adjust the requested page size depending on the rank fusion and reranking algorithms.
-
-         Some rerankers want more results than requested by the user, so
-         reranking has more choices.
-
-         """
-         rank_fusion_window = 0
-         if rank_fusion is not None:
-             rank_fusion_window = rank_fusion.window
-
-         reranker_window = 0
-         if reranker is not None:
-             reranker_window = reranker.window or 0
-
-         request.result_per_page = max(
-             request.result_per_page,
-             rank_fusion_window,
-             reranker_window,
-         )
-

  async def paragraph_query_to_pb(
      kbid: str,
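Everything above is one large deletion: the monolithic QueryParser class is removed wholesale. The hunks that follow (this view omits file headers, but the `class Fetcher:` hunk contexts and the retrieval model dataclasses identify them) contain the pieces that replace it. A rough orientation map, inferred from this diff rather than stated anywhere by the package:

    # Inferred mapping from removed QueryParser members to their new homes;
    # illustrative orientation only, not an official migration guide.
    #
    #   QueryParser.parse_min_score (threshold lookup)  -> Fetcher.get_semantic_min_score()
    #   QueryParser.get_visual_llm_enabled              -> Fetcher.get_visual_llm_enabled()
    #   QueryParser.get_max_tokens_context              -> Fetcher.get_max_context_tokens(max_tokens)
    #   QueryParser.get_max_tokens_answer               -> Fetcher.get_max_answer_tokens(max_tokens)
    #   QueryParser request-building state              -> Query / Filters / UnitRetrieval / ParsedQuery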
@@ -33,6 +33,9 @@ from nucliadb.search.search.metrics import (
  from nucliadb.search.search.query_parser.exceptions import InvalidQueryError
  from nucliadb.search.utilities import get_predict
  from nucliadb_models.internal.predict import QueryInfo
+ from nucliadb_models.search import (
+     MaxTokens,
+ )
  from nucliadb_protos import knowledgebox_pb2, utils_pb2


@@ -205,6 +208,15 @@ class Fetcher:
              return None
          return query_info.rephrased_query

+     async def get_semantic_min_score(self) -> Optional[float]:
+         query_info = await self._predict_query_endpoint()
+         if query_info is None:
+             return None
+
+         vectorset = await self.get_vectorset()
+         min_score = query_info.semantic_thresholds.get(vectorset, None)
+         return min_score
+
      # Labels

      async def get_classification_labels(self) -> knowledgebox_pb2.Labels:
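Unlike the removed QueryParser.parse_min_score, get_semantic_min_score returns None both when predict is unreachable and when no threshold exists for the vectorset; the DEFAULT_GENERIC_SEMANTIC_THRESHOLD (0.7) fallback deleted above no longer applies here. A minimal sketch of the fallback chain a caller would now implement itself (the constant and function are illustrative, not package code):

    DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7  # value of the deleted constant

    async def resolve_semantic_min_score(fetcher: Fetcher, user_min_score: Optional[float]) -> float:
        if user_min_score is not None:
            return user_min_score  # an explicit user-provided threshold wins
        threshold = await fetcher.get_semantic_min_score()
        if threshold is not None:
            return threshold  # per-vectorset threshold from predict's query endpoint
        return DEFAULT_GENERIC_SEMANTIC_THRESHOLD  # generic fallback, as the old parser did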
@@ -268,6 +280,35 @@ class Fetcher:
          self.cache.synonyms = synonyms
          return synonyms

+     # Generative
+
+     async def get_visual_llm_enabled(self) -> bool:
+         query_info = await self._predict_query_endpoint()
+         if query_info is None:
+             raise SendToPredictError("Error while using predict's query endpoint")
+
+         return query_info.visual_llm
+
+     async def get_max_context_tokens(self, max_tokens: Optional[MaxTokens]) -> int:
+         query_info = await self._predict_query_endpoint()
+         if query_info is None:
+             raise SendToPredictError("Error while using predict's query endpoint")
+
+         model_max = query_info.max_context
+         if max_tokens is not None and max_tokens.context is not None:
+             if max_tokens.context > model_max:
+                 raise InvalidQueryError(
+                     "max_tokens.context",
+                     f"Max context tokens is higher than the model's limit of {model_max}",
+                 )
+             return max_tokens.context
+         return model_max
+
+     def get_max_answer_tokens(self, max_tokens: Optional[MaxTokens]) -> Optional[int]:
+         if max_tokens is not None and max_tokens.answer is not None:
+             return max_tokens.answer
+         return None
+
      # Predict API

      async def _predict_query_endpoint(self) -> Optional[QueryInfo]:
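These generative helpers reproduce the logic of the removed QueryParser.get_visual_llm_enabled, get_max_tokens_context and get_max_tokens_answer, but take MaxTokens as an argument instead of reading parser state, and raise SendToPredictError themselves when the query endpoint yields nothing. A hypothetical call site (only the Fetcher methods are from this diff):

    async def resolve_generation_limits(
        fetcher: Fetcher, max_tokens: Optional[MaxTokens]
    ) -> tuple[bool, int, Optional[int]]:
        # Both awaits consult predict's query endpoint and raise
        # SendToPredictError if it is unavailable.
        use_visual_llm = await fetcher.get_visual_llm_enabled()
        max_context = await fetcher.get_max_context_tokens(max_tokens)
        # Purely local computation, hence not async.
        max_answer = fetcher.get_max_answer_tokens(max_tokens)
        return use_visual_llm, max_context, max_answer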
@@ -17,7 +17,6 @@
  # You should have received a copy of the GNU Affero General Public License
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
  #
-
  from dataclasses import dataclass
  from datetime import datetime
  from typing import Literal, Optional, Union
@@ -27,14 +26,70 @@ from pydantic import (
      Field,
  )

+ from nucliadb.search.search.query_parser.fetcher import Fetcher
  from nucliadb_models import search as search_models
- from nucliadb_protos import nodereader_pb2
+ from nucliadb_protos import nodereader_pb2, utils_pb2

  ### Retrieval

+ # query
+
+
+ @dataclass
+ class _TextQuery:
+     query: str
+     is_synonyms_query: bool
+     min_score: float
+     sort: search_models.SortOrder = search_models.SortOrder.DESC
+     order_by: search_models.SortField = search_models.SortField.SCORE
+
+
+ FulltextQuery = _TextQuery
+ KeywordQuery = _TextQuery
+
+
+ @dataclass
+ class SemanticQuery:
+     query: Optional[list[float]]
+     vectorset: str
+     min_score: float
+
+
+ @dataclass
+ class RelationQuery:
+     detected_entities: list[utils_pb2.RelationNode]
+     # list[subtype]
+     deleted_entity_groups: list[str]
+     # subtype -> list[entity]
+     deleted_entities: dict[str, list[str]]
+
+
+ @dataclass
+ class Query:
+     fulltext: Optional[FulltextQuery] = None
+     keyword: Optional[KeywordQuery] = None
+     semantic: Optional[SemanticQuery] = None
+     relation: Optional[RelationQuery] = None
+
+
  # filters


+ @dataclass
+ class Filters:
+     field_expression: Optional[nodereader_pb2.FilterExpression] = None
+     paragraph_expression: Optional[nodereader_pb2.FilterExpression] = None
+     filter_expression_operator: nodereader_pb2.FilterOperator.ValueType = (
+         nodereader_pb2.FilterOperator.AND
+     )
+
+     autofilter: Optional[list[utils_pb2.RelationNode]] = None
+     facets: list[str] = Field(default_factory=list)
+     hidden: Optional[bool] = None
+     security: Optional[search_models.RequestSecurity] = None
+     with_duplicates: bool = False
+
+
  class DateTimeFilter(BaseModel):
      after: Optional[datetime] = None  # aka, start
      before: Optional[datetime] = None  # aka, end
@@ -57,26 +112,45 @@ class ReciprocalRankFusion(RankFusion):
  # reranking


- class Reranker(BaseModel): ...
-
+ class NoopReranker(BaseModel):
+     pass

- class NoopReranker(Reranker): ...

-
- class PredictReranker(Reranker):
+ class PredictReranker(BaseModel):
      window: int = Field(le=200)


- # retrieval operation
+ Reranker = Union[NoopReranker, PredictReranker]
+
+ # retrieval and generation operations


  @dataclass
  class UnitRetrieval:
+     query: Query
      top_k: int
+     filters: Filters
+     # TODO: rank fusion depends on the response building, not the retrieval
      rank_fusion: RankFusion
+     # TODO: reranking depends on the response building, not the retrieval
      reranker: Reranker


+ @dataclass
+ class Generation:
+     use_visual_llm: bool
+     max_context_tokens: int
+     max_answer_tokens: Optional[int]
+
+
+ @dataclass
+ class ParsedQuery:
+     fetcher: Fetcher
+     retrieval: UnitRetrieval
+     generation: Optional[Generation] = None
+     # TODO: add merge, rank fusion, rerank...
+
+

  ### Catalog
  class CatalogExpression:
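Taken together, the new dataclasses replace QueryParser's flat attribute bag with a typed tree rooted at ParsedQuery. A hand-assembled instance showing the intended shape, assuming an async context, a `fetcher` built as in the hunks above, and a `max_tokens` value; all field values are illustrative and the rank fusion construction is elided since its fields sit outside this diff:

    rank_fusion = ...  # a RankFusion instance; its fields are not part of this diff

    retrieval = UnitRetrieval(
        query=Query(
            keyword=KeywordQuery(query="foo", is_synonyms_query=False, min_score=0.4),
            semantic=SemanticQuery(query=None, vectorset="my-vectorset", min_score=0.7),
        ),
        top_k=20,
        filters=Filters(),  # defaults: AND operator, no expressions, no security
        rank_fusion=rank_fusion,
        reranker=NoopReranker(),
    )

    generation = Generation(
        use_visual_llm=await fetcher.get_visual_llm_enabled(),
        max_context_tokens=await fetcher.get_max_context_tokens(max_tokens),
        max_answer_tokens=fetcher.get_max_answer_tokens(max_tokens),
    )

    parsed = ParsedQuery(fetcher=fetcher, retrieval=retrieval, generation=generation)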