nucliadb-6.3.6.post4063-py3-none-any.whl → nucliadb-6.3.7.post4068-py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- nucliadb/search/api/v1/search.py +6 -39
- nucliadb/search/search/chat/ask.py +19 -26
- nucliadb/search/search/chat/query.py +6 -6
- nucliadb/search/search/find.py +21 -91
- nucliadb/search/search/find_merge.py +18 -9
- nucliadb/search/search/graph_strategy.py +9 -10
- nucliadb/search/search/merge.py +76 -65
- nucliadb/search/search/query.py +2 -455
- nucliadb/search/search/query_parser/fetcher.py +41 -0
- nucliadb/search/search/query_parser/models.py +82 -8
- nucliadb/search/search/query_parser/parsers/ask.py +77 -0
- nucliadb/search/search/query_parser/parsers/common.py +189 -0
- nucliadb/search/search/query_parser/parsers/find.py +175 -13
- nucliadb/search/search/query_parser/parsers/search.py +249 -0
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +176 -0
- nucliadb/search/search/rerankers.py +4 -2
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/METADATA +6 -6
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/RECORD +21 -17
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/WHEEL +0 -0
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.3.6.post4063.dist-info → nucliadb-6.3.7.post4068.dist-info}/top_level.txt +0 -0
nucliadb/search/api/v1/search.py
CHANGED
@@ -36,10 +36,9 @@ from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_quer
 from nucliadb.search.search import cache
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_results
-from nucliadb.search.search.
-from nucliadb.search.search.query_parser.
+from nucliadb.search.search.query_parser.parsers.search import parse_search
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.search.utils import (
-    filter_hidden_resources,
     min_score_from_payload,
     min_score_from_query_params,
     should_disable_vector_search,
@@ -270,53 +269,21 @@ async def search(
         if should_disable_vector_search(item):
             item.features.remove(SearchOptions.SEMANTIC)

-
-
-        kbid=kbid,
-        features=item.features,
-        query=item.query,
-        filter_expression=item.filter_expression,
-        faceted=item.faceted,
-        sort=item.sort,
-        top_k=item.top_k,
-        min_score=item.min_score,
-        old_filters=OldFilterParams(
-            label_filters=item.filters,
-            keyword_filters=[],
-            range_creation_start=item.range_creation_start,
-            range_creation_end=item.range_creation_end,
-            range_modification_start=item.range_modification_start,
-            range_modification_end=item.range_modification_end,
-            fields=item.fields,
-        ),
-        user_vector=item.vector,
-        vectorset=item.vectorset,
-        with_duplicates=item.with_duplicates,
-        with_status=with_status,
-        with_synonyms=item.with_synonyms,
-        autofilter=item.autofilter,
-        security=item.security,
-        rephrase=item.rephrase,
-        hidden=await filter_hidden_resources(kbid, item.show_hidden),
-        rephrase_prompt=item.rephrase_prompt,
-    )
-    pb_query, incomplete_results, autofilters, _ = await query_parser.parse()
+    parsed = await parse_search(kbid, item)
+    pb_query, incomplete_results, autofilters, _ = await convert_retrieval_to_proto(parsed)

+    # We need to query all nodes
     results, query_incomplete_results, queried_nodes = await node_query(kbid, Method.SEARCH, pb_query)
-
     incomplete_results = incomplete_results or query_incomplete_results

     # We need to merge
     search_results = await merge_results(
         results,
-
+        parsed.retrieval,
         kbid=kbid,
         show=item.show,
         field_type_filter=item.field_type_filter,
         extracted=item.extracted,
-        sort=query_parser.sort,  # type: ignore
-        requested_relations=pb_query.relation_subgraph,
-        min_score=query_parser.min_score,
         highlight=item.highlight,
     )

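For orientation, here is a minimal sketch of the /search flow after this change, assembled from the hunks above. It assumes only the call signatures shown in the diff (parse_search, convert_retrieval_to_proto, node_query, merge_results); the wrapper name run_search is hypothetical and this is illustrative, not the exact endpoint code.

```python
from nucliadb.search.requesters.utils import Method, node_query
from nucliadb.search.search.merge import merge_results
from nucliadb.search.search.query_parser.parsers.search import parse_search
from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto


async def run_search(kbid: str, item):
    # One parser call replaces the old hand-built QueryParser(...) block.
    parsed = await parse_search(kbid, item)
    # The parsed retrieval is converted into the index protobuf request.
    pb_query, incomplete, autofilters, _ = await convert_retrieval_to_proto(parsed)
    # Query all index nodes, as before.
    results, query_incomplete, _ = await node_query(kbid, Method.SEARCH, pb_query)
    # merge_results now receives the parsed retrieval instead of separate
    # sort / requested_relations / min_score keyword arguments.
    search_results = await merge_results(
        results,
        parsed.retrieval,
        kbid=kbid,
        show=item.show,
        field_type_filter=item.field_type_filter,
        extracted=item.extracted,
        highlight=item.highlight,
    )
    return search_results, incomplete or query_incomplete, autofilters
```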
nucliadb/search/search/chat/ask.py
CHANGED
@@ -61,8 +61,11 @@ from nucliadb.search.search.exceptions import (
 )
 from nucliadb.search.search.graph_strategy import get_graph_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.
-from nucliadb.search.search.query_parser.
+from nucliadb.search.search.query_parser.fetcher import Fetcher
+from nucliadb.search.search.query_parser.parsers.ask import fetcher_for_ask, parse_ask
+from nucliadb.search.search.rerankers import (
+    get_reranker,
+)
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.search import (
     AnswerAskResponseItem,
@@ -83,7 +86,6 @@ from nucliadb_models.search import (
     JSONAskResponseItem,
     KnowledgeboxFindResults,
     MetadataAskResponseItem,
-    MinScore,
     NucliaDBClientType,
     PrequeriesAskResponseItem,
     PreQueriesStrategy,
@@ -116,7 +118,7 @@ class RetrievalMatch:
 @dataclasses.dataclass
 class RetrievalResults:
     main_query: KnowledgeboxFindResults
-
+    fetcher: Fetcher
     main_query_weight: float
     prequeries: Optional[list[PreQueryResult]] = None
     best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
@@ -543,12 +545,12 @@ async def ask(
             prequeries_results=err.prequeries,
         )

-
+    # parse ask request generation parameters reusing the same fetcher as
+    # retrieval, to avoid multiple round trips to Predict API
+    generation = await parse_ask(kbid, ask_request, fetcher=retrieval_results.fetcher)

     # Now we build the prompt context
     with metrics.time("context_building"):
-        query_parser.max_tokens = ask_request.max_tokens  # type: ignore
-        max_tokens_context = await query_parser.get_max_tokens_context()
         prompt_context_builder = PromptContextBuilder(
             kbid=kbid,
             ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
@@ -557,8 +559,8 @@ async def ask(
             user_image_context=ask_request.extra_context_images,
             strategies=ask_request.rag_strategies,
             image_strategies=ask_request.rag_images_strategies,
-            max_context_characters=tokens_to_chars(
-            visual_llm=
+            max_context_characters=tokens_to_chars(generation.max_context_tokens),
+            visual_llm=generation.use_visual_llm,
         )
         (
             prompt_context,
@@ -580,7 +582,7 @@ async def ask(
         citations=ask_request.citations,
         citation_threshold=ask_request.citation_threshold,
         generative_model=ask_request.generative_model,
-        max_tokens=
+        max_tokens=generation.max_answer_tokens,
         query_context_images=prompt_context_images,
         json_schema=ask_request.answer_json_schema,
         rerank_context=False,
@@ -741,7 +743,7 @@ async def retrieval_in_kb(
     prequeries = parse_prequeries(ask_request)
     graph_strategy = parse_graph_strategy(ask_request)
     with metrics.time("retrieval"):
-        main_results, prequeries_results,
+        main_results, prequeries_results, parsed_query = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
@@ -753,6 +755,7 @@ async def retrieval_in_kb(
        )

    if graph_strategy is not None:
+        reranker = get_reranker(parsed_query.retrieval.reranker)
        graph_results, graph_request = await get_graph_results(
            kbid=kbid,
            query=main_query,
@@ -762,6 +765,7 @@ async def retrieval_in_kb(
            origin=origin,
            graph_strategy=graph_strategy,
            metrics=metrics,
+            text_block_reranker=reranker,
        )

        if prequeries_results is None:
@@ -784,7 +788,7 @@ async def retrieval_in_kb(
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
-
+        fetcher=parsed_query.fetcher,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
    )
@@ -805,18 +809,7 @@ async def retrieval_in_resource(
        return RetrievalResults(
            main_query=KnowledgeboxFindResults(resources={}, min_score=None),
            prequeries=None,
-
-                kbid=kbid,
-                features=[],
-                query="",
-                filter_expression=ask_request.filter_expression,
-                old_filters=OldFilterParams(
-                    label_filters=ask_request.filters,
-                    keyword_filters=ask_request.keyword_filters,
-                ),
-                top_k=0,
-                min_score=MinScore(),
-            ),
+            fetcher=fetcher_for_ask(kbid, ask_request),
            main_query_weight=1.0,
        )

@@ -836,7 +829,7 @@ async def retrieval_in_resource(
        add_resource_filter(prequery.request, [resource])

    with metrics.time("retrieval"):
-        main_results, prequeries_results,
+        main_results, prequeries_results, parsed_query = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
@@ -859,7 +852,7 @@ async def retrieval_in_resource(
    return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
-
+        fetcher=parsed_query.fetcher,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
    )
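The ask.py change can be summarized as: the parsed query's Fetcher is now carried on RetrievalResults and reused when parsing generation parameters, so the Predict API is not queried twice. A hedged sketch, assuming only the attribute names visible in the hunks above (fetcher, max_context_tokens, max_answer_tokens, use_visual_llm); the helper name generation_parameters is hypothetical.

```python
from nucliadb.search.search.query_parser.parsers.ask import parse_ask


async def generation_parameters(kbid: str, ask_request, retrieval_results):
    # Reuse the fetcher created during retrieval (RetrievalResults.fetcher)
    # to avoid extra round trips to the Predict API.
    generation = await parse_ask(kbid, ask_request, fetcher=retrieval_results.fetcher)
    # These fields replace the removed query_parser.max_tokens /
    # get_max_tokens_context() plumbing used when building the prompt context.
    return (
        generation.max_context_tokens,
        generation.max_answer_tokens,
        generation.use_visual_llm,
    )
```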
nucliadb/search/search/chat/query.py
CHANGED
@@ -29,7 +29,7 @@ from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find
 from nucliadb.search.search.merge import merge_relations_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.
+from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.settings import settings
 from nucliadb.search.utilities import get_predict
 from nucliadb_models import filters
@@ -93,7 +93,7 @@ async def get_find_results(
     origin: str,
     metrics: RAGMetrics = RAGMetrics(),
     prequeries_strategy: Optional[PreQueriesStrategy] = None,
-) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]],
+) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]], ParsedQuery]:
     prequeries_results = None
     prefilter_queries_results = None
     queries_results = None
@@ -223,10 +223,10 @@ async def run_main_query(
     user: str,
     origin: str,
     metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults,
+) -> tuple[KnowledgeboxFindResults, ParsedQuery]:
     find_request = find_request_from_ask_request(item, query)

-    find_results, incomplete,
+    find_results, incomplete, parsed_query = await find(
        kbid,
        find_request,
        ndb_client,
@@ -237,7 +237,7 @@ async def run_main_query(
    )
    if incomplete:
        raise IncompleteFindResultsError()
-    return find_results,
+    return find_results, parsed_query


 async def get_relations_results(
@@ -297,7 +297,7 @@ async def get_relations_results_from_entities(
     relations_results: list[RelationSearchResponse] = [result.relation for result in results]
     return await merge_relations_results(
         relations_results,
-        request.relation_subgraph,
+        request.relation_subgraph.entry_points,
         only_with_metadata,
         only_agentic_relations,
         only_entity_to_entity,
nucliadb/search/search/find.py
CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import logging
-from dataclasses import dataclass
 from time import time
 from typing import Optional

@@ -38,30 +37,22 @@ from nucliadb.search.search.hydrator import (
 from nucliadb.search.search.metrics import (
     RAGMetrics,
 )
-from nucliadb.search.search.
-from nucliadb.search.search.query_parser.old_filters import OldFilterParams
+from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.search.query_parser.parsers import parse_find
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.search.rank_fusion import (
-    RankFusionAlgorithm,
     get_rank_fusion,
 )
 from nucliadb.search.search.rerankers import (
-    Reranker,
     RerankingOptions,
     get_reranker,
 )
-from nucliadb.search.search.utils import (
-    filter_hidden_resources,
-    min_score_from_payload,
-    should_disable_vector_search,
-)
 from nucliadb.search.settings import settings
 from nucliadb_models.search import (
     FindRequest,
     KnowledgeboxFindResults,
     MinScore,
     NucliaDBClientType,
-    SearchOptions,
 )
 from nucliadb_utils.utilities import get_audit

@@ -76,7 +67,7 @@ async def find(
     x_forwarded_for: str,
     generative_model: Optional[str] = None,
     metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults, bool,
+) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]:
     external_index_manager = await get_external_index_manager(kbid=kbid)
     if external_index_manager is not None:
         return await _external_index_retrieval(
@@ -99,15 +90,17 @@ async def _index_node_retrieval(
     x_forwarded_for: str,
     generative_model: Optional[str] = None,
     metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults, bool,
+) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]:
     audit = get_audit()
     start_time = time()

-    query_parser, rank_fusion, reranker = await query_parser_from_find_request(
-        kbid, item, generative_model=generative_model
-    )
     with metrics.time("query_parse"):
-
+        parsed = await parse_find(kbid, item, generative_model=generative_model)
+        rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion)
+        reranker = get_reranker(parsed.retrieval.reranker)
+        pb_query, incomplete_results, autofilters, rephrased_query = await convert_retrieval_to_proto(
+            parsed
+        )

     with metrics.time("node_query"):
         results, query_incomplete_results, queried_nodes = await node_query(
@@ -119,13 +112,10 @@ async def _index_node_retrieval(
     with metrics.time("results_merge"):
         search_results = await build_find_response(
             results,
+            retrieval=parsed.retrieval,
             kbid=kbid,
             query=pb_query.body,
             rephrased_query=rephrased_query,
-            relation_subgraph_query=pb_query.relation_subgraph,
-            min_score_bm25=pb_query.min_score_bm25,
-            min_score_semantic=pb_query.min_score_semantic,
-            top_k=item.top_k,
             show=item.show,
             extracted=item.extracted,
             field_type_filter=item.field_type_filter,
@@ -182,7 +172,7 @@ async def _index_node_retrieval(
         },
     )

-    return search_results, incomplete_results,
+    return search_results, incomplete_results, parsed


 async def _external_index_retrieval(
@@ -190,15 +180,14 @@ async def _external_index_retrieval(
     item: FindRequest,
     external_index_manager: ExternalIndexManager,
     generative_model: Optional[str] = None,
-) -> tuple[KnowledgeboxFindResults, bool,
+) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]:
     """
     Parse the query, query the external index, and hydrate the results.
     """
     # Parse query
-
-
-    )
-    search_request, incomplete_results, _, rephrased_query = await query_parser.parse()
+    parsed = await parse_find(kbid, item, generative_model=generative_model)
+    reranker = get_reranker(parsed.retrieval.reranker)
+    search_request, incomplete_results, _, rephrased_query = await convert_retrieval_to_proto(parsed)

     # Query index
     query_results = await external_index_manager.query(search_request)  # noqa
@@ -218,13 +207,15 @@ async def _external_index_retrieval(
             kbid=kbid,
             query=search_request.body,
         ),
-        top_k=
+        top_k=parsed.retrieval.top_k,
     )
     find_resources = compose_find_resources(text_blocks, resources)

     results_min_score = MinScore(
         bm25=0,
-        semantic=
+        semantic=parsed.retrieval.query.semantic.min_score
+        if parsed.retrieval.query.semantic is not None
+        else 0.0,
     )
     retrieval_results = KnowledgeboxFindResults(
         resources=find_resources,
@@ -242,65 +233,4 @@ async def _external_index_retrieval(
         nodes=None,
     )

-    return retrieval_results, incomplete_results,
-
-
-@dataclass
-class ScoredParagraph:
-    id: str
-    score: float
-
-
-async def query_parser_from_find_request(
-    kbid: str, item: FindRequest, *, generative_model: Optional[str] = None
-) -> tuple[QueryParser, RankFusionAlgorithm, Reranker]:
-    item.min_score = min_score_from_payload(item.min_score)
-
-    if SearchOptions.SEMANTIC in item.features:
-        if should_disable_vector_search(item):
-            item.features.remove(SearchOptions.SEMANTIC)
-
-    hidden = await filter_hidden_resources(kbid, item.show_hidden)
-
-    # XXX this is becoming the new /find query parsing, this should be moved to
-    # a cleaner abstraction
-
-    parsed = await parse_find(kbid, item)
-
-    rank_fusion = get_rank_fusion(parsed.rank_fusion)
-    reranker = get_reranker(parsed.reranker)
-
-    query_parser = QueryParser(
-        kbid=kbid,
-        features=item.features,
-        query=item.query,
-        query_entities=item.query_entities,
-        filter_expression=item.filter_expression,
-        faceted=None,
-        sort=None,
-        top_k=item.top_k,
-        min_score=item.min_score,
-        old_filters=OldFilterParams(
-            label_filters=item.filters,
-            keyword_filters=item.keyword_filters,
-            range_creation_start=item.range_creation_start,
-            range_creation_end=item.range_creation_end,
-            range_modification_start=item.range_modification_start,
-            range_modification_end=item.range_modification_end,
-            fields=item.fields,
-            key_filters=item.resource_filters,
-        ),
-        user_vector=item.vector,
-        vectorset=item.vectorset,
-        with_duplicates=item.with_duplicates,
-        with_synonyms=item.with_synonyms,
-        autofilter=item.autofilter,
-        security=item.security,
-        generative_model=generative_model,
-        rephrase=item.rephrase,
-        rephrase_prompt=item.rephrase_prompt,
-        hidden=hidden,
-        rank_fusion=rank_fusion,
-        reranker=reranker,
-    )
-    return (query_parser, rank_fusion, reranker)
+    return retrieval_results, incomplete_results, parsed
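A condensed sketch of the parse step that replaces the removed query_parser_from_find_request helper, assuming only the functions and attributes shown in the hunks above; the wrapper name prepare_find is hypothetical.

```python
from typing import Optional

from nucliadb.search.search.query_parser.parsers import parse_find
from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
from nucliadb.search.search.rank_fusion import get_rank_fusion
from nucliadb.search.search.rerankers import get_reranker


async def prepare_find(kbid: str, item, generative_model: Optional[str] = None):
    # The removed helper handled min-score defaults, hidden-resource filtering
    # and feature pruning by hand; that logic presumably lives behind parse_find now.
    parsed = await parse_find(kbid, item, generative_model=generative_model)
    # Rank fusion and reranker are derived from the parsed retrieval.
    rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion)
    reranker = get_reranker(parsed.retrieval.reranker)
    # The protobuf request sent to the index nodes is derived from it as well.
    pb_query, incomplete, autofilters, rephrased = await convert_retrieval_to_proto(parsed)
    return parsed, rank_fusion, reranker, pb_query, incomplete, rephrased
```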
nucliadb/search/search/find_merge.py
CHANGED
@@ -32,6 +32,7 @@ from nucliadb.search.search.hydrator import (
     text_block_to_find_paragraph,
 )
 from nucliadb.search.search.merge import merge_relations_results
+from nucliadb.search.search.query_parser.models import UnitRetrieval
 from nucliadb.search.search.rank_fusion import RankFusionAlgorithm
 from nucliadb.search.search.rerankers import (
     RerankableItem,
@@ -51,7 +52,6 @@ from nucliadb_models.search import (
 )
 from nucliadb_protos.nodereader_pb2 import (
     DocumentScored,
-    EntitiesSubgraphRequest,
     ParagraphResult,
     ParagraphSearchResponse,
     RelationSearchResponse,
@@ -72,13 +72,10 @@ FIND_FETCH_OPS_DISTRIBUTION = metrics.Histogram(
 async def build_find_response(
     search_responses: list[SearchResponse],
     *,
+    retrieval: UnitRetrieval,
     kbid: str,
     query: str,
     rephrased_query: Optional[str],
-    relation_subgraph_query: EntitiesSubgraphRequest,
-    top_k: int,
-    min_score_bm25: float,
-    min_score_semantic: float,
     rank_fusion_algorithm: RankFusionAlgorithm,
     reranker: Reranker,
     show: list[ResourceProperties] = [],
@@ -86,6 +83,15 @@ async def build_find_response(
     field_type_filter: list[FieldTypeName] = [],
     highlight: bool = False,
 ) -> KnowledgeboxFindResults:
+    # XXX: we shouldn't need a min score that we haven't used. Previous
+    # implementations got this value from the proto request (i.e., default to 0)
+    min_score_bm25 = 0.0
+    if retrieval.query.keyword is not None:
+        min_score_bm25 = retrieval.query.keyword.min_score
+    min_score_semantic = 0.0
+    if retrieval.query.semantic is not None:
+        min_score_semantic = retrieval.query.semantic.min_score
+
     # merge
     search_response = merge_shard_responses(search_responses)

@@ -112,7 +118,7 @@ async def build_find_response(
         assert reranker.window is not None, "Reranker definition must enforce this condition"
         text_blocks_page, next_page = cut_page(merged_text_blocks, reranker.window)
     else:
-        text_blocks_page, next_page = cut_page(merged_text_blocks, top_k)
+        text_blocks_page, next_page = cut_page(merged_text_blocks, retrieval.top_k)

     # hydrate and rerank
     resource_hydration_options = ResourceHydrationOptions(
@@ -130,11 +136,14 @@ async def build_find_response(
         text_block_hydration_options=text_block_hydration_options,
         reranker=reranker,
         reranking_options=reranking_options,
-        top_k=top_k,
+        top_k=retrieval.top_k,
     )

     # build relations graph
-
+    entry_points = []
+    if retrieval.query.relation is not None:
+        entry_points = retrieval.query.relation.detected_entities
+    relations = await merge_relations_results([search_response.relation], entry_points)

     # compose response
     find_resources = compose_find_resources(text_blocks, resources)
@@ -150,7 +159,7 @@ async def build_find_response(
         relations=relations,
         total=total_paragraphs,
         page_number=0, # Bw/c with pagination
-        page_size=top_k,
+        page_size=retrieval.top_k,
         next_page=next_page,
         min_score=MinScore(bm25=_round(min_score_bm25), semantic=_round(min_score_semantic)),
     )
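The min-score handling added to build_find_response above can be read as a small standalone helper. A sketch assuming only the UnitRetrieval fields visible in the diff (query.keyword.min_score, query.semantic.min_score); the function name min_scores_from_retrieval is hypothetical.

```python
def min_scores_from_retrieval(retrieval) -> tuple[float, float]:
    # Default both scores to 0, matching the previous behaviour of reading
    # them from the protobuf request.
    bm25 = 0.0
    if retrieval.query.keyword is not None:
        bm25 = retrieval.query.keyword.min_score
    semantic = 0.0
    if retrieval.query.semantic is not None:
        semantic = retrieval.query.semantic.min_score
    return bm25, semantic
```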
nucliadb/search/search/graph_strategy.py
CHANGED
@@ -16,7 +16,6 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
-
 import heapq
 import json
 from collections import defaultdict
@@ -38,14 +37,16 @@ from nucliadb.search.search.chat.query import (
     find_request_from_ask_request,
     get_relations_results_from_entities,
 )
-from nucliadb.search.search.find import query_parser_from_find_request
 from nucliadb.search.search.find_merge import (
     compose_find_resources,
     hydrate_and_rerank,
 )
 from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.rerankers import
+from nucliadb.search.search.rerankers import (
+    Reranker,
+    RerankingOptions,
+)
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.common import FieldTypeName
 from nucliadb_models.internal.predict import (
@@ -303,6 +304,7 @@ async def get_graph_results(
     user: str,
     origin: str,
     graph_strategy: GraphStrategy,
+    text_block_reranker: Reranker,
     generative_model: Optional[str] = None,
     metrics: RAGMetrics = RAGMetrics(),
     shards: Optional[list[str]] = None,
@@ -419,19 +421,16 @@ async def get_graph_results(
     # Get the text blocks of the paragraphs that contain the top relations
     with metrics.time("graph_strat_build_response"):
         find_request = find_request_from_ask_request(item, query)
-        query_parser, rank_fusion, reranker = await query_parser_from_find_request(
-            kbid, find_request, generative_model=generative_model
-        )
         find_results = await build_graph_response(
             kbid=kbid,
             query=query,
             final_relations=relations,
             scores=scores,
             top_k=graph_strategy.top_k,
-            reranker=
-            show=
-            extracted=
-            field_type_filter=
+            reranker=text_block_reranker,
+            show=item.show,
+            extracted=item.extracted,
+            field_type_filter=item.field_type_filter,
             relation_text_as_paragraphs=graph_strategy.relation_text_as_paragraphs,
         )
         return find_results, find_request