nucliadb 6.3.7.post4066__py3-none-any.whl → 6.3.7.post4071__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- nucliadb/search/api/v1/search.py +6 -39
- nucliadb/search/search/chat/ask.py +19 -26
- nucliadb/search/search/chat/query.py +7 -9
- nucliadb/search/search/find.py +22 -97
- nucliadb/search/search/find_merge.py +18 -9
- nucliadb/search/search/graph_strategy.py +9 -10
- nucliadb/search/search/merge.py +76 -65
- nucliadb/search/search/query.py +2 -455
- nucliadb/search/search/query_parser/fetcher.py +41 -0
- nucliadb/search/search/query_parser/models.py +82 -8
- nucliadb/search/search/query_parser/parsers/ask.py +77 -0
- nucliadb/search/search/query_parser/parsers/common.py +189 -0
- nucliadb/search/search/query_parser/parsers/find.py +174 -13
- nucliadb/search/search/query_parser/parsers/search.py +249 -0
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +176 -0
- nucliadb/search/search/rerankers.py +4 -2
- {nucliadb-6.3.7.post4066.dist-info → nucliadb-6.3.7.post4071.dist-info}/METADATA +6 -6
- {nucliadb-6.3.7.post4066.dist-info → nucliadb-6.3.7.post4071.dist-info}/RECORD +21 -17
- {nucliadb-6.3.7.post4066.dist-info → nucliadb-6.3.7.post4071.dist-info}/WHEEL +0 -0
- {nucliadb-6.3.7.post4066.dist-info → nucliadb-6.3.7.post4071.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.3.7.post4066.dist-info → nucliadb-6.3.7.post4071.dist-info}/top_level.txt +0 -0
nucliadb/search/api/v1/search.py
CHANGED
@@ -36,10 +36,9 @@ from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
 from nucliadb.search.search import cache
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_results
-from nucliadb.search.search.
-from nucliadb.search.search.query_parser.
+from nucliadb.search.search.query_parser.parsers.search import parse_search
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.search.utils import (
-    filter_hidden_resources,
     min_score_from_payload,
     min_score_from_query_params,
     should_disable_vector_search,
@@ -270,53 +269,21 @@ async def search(
     if should_disable_vector_search(item):
         item.features.remove(SearchOptions.SEMANTIC)

-
-
-        kbid=kbid,
-        features=item.features,
-        query=item.query,
-        filter_expression=item.filter_expression,
-        faceted=item.faceted,
-        sort=item.sort,
-        top_k=item.top_k,
-        min_score=item.min_score,
-        old_filters=OldFilterParams(
-            label_filters=item.filters,
-            keyword_filters=[],
-            range_creation_start=item.range_creation_start,
-            range_creation_end=item.range_creation_end,
-            range_modification_start=item.range_modification_start,
-            range_modification_end=item.range_modification_end,
-            fields=item.fields,
-        ),
-        user_vector=item.vector,
-        vectorset=item.vectorset,
-        with_duplicates=item.with_duplicates,
-        with_status=with_status,
-        with_synonyms=item.with_synonyms,
-        autofilter=item.autofilter,
-        security=item.security,
-        rephrase=item.rephrase,
-        hidden=await filter_hidden_resources(kbid, item.show_hidden),
-        rephrase_prompt=item.rephrase_prompt,
-    )
-    pb_query, incomplete_results, autofilters, _ = await query_parser.parse()
+    parsed = await parse_search(kbid, item)
+    pb_query, incomplete_results, autofilters, _ = await convert_retrieval_to_proto(parsed)

+    # We need to query all nodes
     results, query_incomplete_results, queried_nodes = await node_query(kbid, Method.SEARCH, pb_query)
-
     incomplete_results = incomplete_results or query_incomplete_results

     # We need to merge
     search_results = await merge_results(
         results,
-
+        parsed.retrieval,
         kbid=kbid,
         show=item.show,
         field_type_filter=item.field_type_filter,
         extracted=item.extracted,
-        sort=query_parser.sort,  # type: ignore
-        requested_relations=pb_query.relation_subgraph,
-        min_score=query_parser.min_score,
         highlight=item.highlight,
     )

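Taken together, the /search endpoint no longer assembles a QueryParser by hand: it parses the request into a parsed query object and converts that into the protobuf retrieval request. A minimal sketch of the new call shape, assuming a kbid and a search request item are already available (only the wrapper function name is illustrative; every call in it appears in the diff above):

    from nucliadb.search.requesters.utils import Method, node_query
    from nucliadb.search.search.merge import merge_results
    from nucliadb.search.search.query_parser.parsers.search import parse_search
    from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto


    async def run_search(kbid: str, item):
        # Parse the public request into the internal retrieval model
        parsed = await parse_search(kbid, item)
        # Convert the parsed retrieval into the index node protobuf request
        pb_query, incomplete, autofilters, _ = await convert_retrieval_to_proto(parsed)
        # Query all index nodes
        results, query_incomplete, queried_nodes = await node_query(kbid, Method.SEARCH, pb_query)
        # Merge, passing the parsed retrieval instead of the old
        # sort/requested_relations/min_score keyword arguments
        return await merge_results(
            results,
            parsed.retrieval,
            kbid=kbid,
            show=item.show,
            field_type_filter=item.field_type_filter,
            extracted=item.extracted,
            highlight=item.highlight,
        )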
nucliadb/search/search/chat/ask.py
CHANGED
@@ -61,8 +61,11 @@ from nucliadb.search.search.exceptions import (
 )
 from nucliadb.search.search.graph_strategy import get_graph_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.
-from nucliadb.search.search.query_parser.
+from nucliadb.search.search.query_parser.fetcher import Fetcher
+from nucliadb.search.search.query_parser.parsers.ask import fetcher_for_ask, parse_ask
+from nucliadb.search.search.rerankers import (
+    get_reranker,
+)
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.search import (
     AnswerAskResponseItem,
@@ -83,7 +86,6 @@ from nucliadb_models.search import (
     JSONAskResponseItem,
     KnowledgeboxFindResults,
     MetadataAskResponseItem,
-    MinScore,
     NucliaDBClientType,
     PrequeriesAskResponseItem,
     PreQueriesStrategy,
@@ -116,7 +118,7 @@ class RetrievalMatch:
 @dataclasses.dataclass
 class RetrievalResults:
     main_query: KnowledgeboxFindResults
-
+    fetcher: Fetcher
     main_query_weight: float
     prequeries: Optional[list[PreQueryResult]] = None
     best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
@@ -543,12 +545,12 @@ async def ask(
            prequeries_results=err.prequeries,
        )

-
+    # parse ask request generation parameters reusing the same fetcher as
+    # retrieval, to avoid multiple round trips to Predict API
+    generation = await parse_ask(kbid, ask_request, fetcher=retrieval_results.fetcher)

     # Now we build the prompt context
     with metrics.time("context_building"):
-        query_parser.max_tokens = ask_request.max_tokens  # type: ignore
-        max_tokens_context = await query_parser.get_max_tokens_context()
         prompt_context_builder = PromptContextBuilder(
             kbid=kbid,
             ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
@@ -557,8 +559,8 @@
             user_image_context=ask_request.extra_context_images,
             strategies=ask_request.rag_strategies,
             image_strategies=ask_request.rag_images_strategies,
-            max_context_characters=tokens_to_chars(
-            visual_llm=
+            max_context_characters=tokens_to_chars(generation.max_context_tokens),
+            visual_llm=generation.use_visual_llm,
         )
         (
             prompt_context,
@@ -580,7 +582,7 @@
         citations=ask_request.citations,
         citation_threshold=ask_request.citation_threshold,
         generative_model=ask_request.generative_model,
-        max_tokens=
+        max_tokens=generation.max_answer_tokens,
         query_context_images=prompt_context_images,
         json_schema=ask_request.answer_json_schema,
         rerank_context=False,
@@ -741,7 +743,7 @@ async def retrieval_in_kb(
     prequeries = parse_prequeries(ask_request)
     graph_strategy = parse_graph_strategy(ask_request)
     with metrics.time("retrieval"):
-        main_results, prequeries_results,
+        main_results, prequeries_results, parsed_query = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
@@ -753,6 +755,7 @@ async def retrieval_in_kb(
        )

     if graph_strategy is not None:
+        reranker = get_reranker(parsed_query.retrieval.reranker)
        graph_results, graph_request = await get_graph_results(
            kbid=kbid,
            query=main_query,
@@ -762,6 +765,7 @@ async def retrieval_in_kb(
            origin=origin,
            graph_strategy=graph_strategy,
            metrics=metrics,
+            text_block_reranker=reranker,
        )

        if prequeries_results is None:
@@ -784,7 +788,7 @@ async def retrieval_in_kb(
     return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
-
+        fetcher=parsed_query.fetcher,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
     )
@@ -805,18 +809,7 @@ async def retrieval_in_resource(
        return RetrievalResults(
            main_query=KnowledgeboxFindResults(resources={}, min_score=None),
            prequeries=None,
-
-                kbid=kbid,
-                features=[],
-                query="",
-                filter_expression=ask_request.filter_expression,
-                old_filters=OldFilterParams(
-                    label_filters=ask_request.filters,
-                    keyword_filters=ask_request.keyword_filters,
-                ),
-                top_k=0,
-                min_score=MinScore(),
-            ),
+            fetcher=fetcher_for_ask(kbid, ask_request),
            main_query_weight=1.0,
        )

@@ -836,7 +829,7 @@ async def retrieval_in_resource(
            add_resource_filter(prequery.request, [resource])

     with metrics.time("retrieval"):
-        main_results, prequeries_results,
+        main_results, prequeries_results, parsed_query = await get_find_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,
@@ -859,7 +852,7 @@ async def retrieval_in_resource(
     return RetrievalResults(
        main_query=main_results,
        prequeries=prequeries_results,
-
+        fetcher=parsed_query.fetcher,
        main_query_weight=main_query_weight,
        best_matches=best_matches,
     )
nucliadb/search/search/chat/query.py
CHANGED
@@ -29,7 +29,7 @@ from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find
 from nucliadb.search.search.merge import merge_relations_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.
+from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.settings import settings
 from nucliadb.search.utilities import get_predict
 from nucliadb_models import filters
@@ -93,7 +93,7 @@ async def get_find_results(
     origin: str,
     metrics: RAGMetrics = RAGMetrics(),
     prequeries_strategy: Optional[PreQueriesStrategy] = None,
-) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]],
+) -> tuple[KnowledgeboxFindResults, Optional[list[PreQueryResult]], ParsedQuery]:
     prequeries_results = None
     prefilter_queries_results = None
     queries_results = None
@@ -108,7 +108,6 @@ async def get_find_results(
         x_ndb_client=ndb_client,
         x_nucliadb_user=user,
         x_forwarded_for=origin,
-        generative_model=item.generative_model,
         metrics=metrics,
     )
     prefilter_matching_resources = {
@@ -210,6 +209,7 @@ def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
     # We don't support pagination, we always get the top_k results.
     find_request.top_k = item.top_k
     find_request.show_hidden = item.show_hidden
+    find_request.generative_model = item.generative_model

     # this executes the model validators, that can tweak some fields
     return FindRequest.model_validate(find_request)
@@ -223,21 +223,20 @@ async def run_main_query(
     user: str,
     origin: str,
     metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults,
+) -> tuple[KnowledgeboxFindResults, ParsedQuery]:
     find_request = find_request_from_ask_request(item, query)

-    find_results, incomplete,
+    find_results, incomplete, parsed_query = await find(
         kbid,
         find_request,
         ndb_client,
         user,
         origin,
-        generative_model=item.generative_model,
         metrics=metrics,
     )
     if incomplete:
         raise IncompleteFindResultsError()
-    return find_results,
+    return find_results, parsed_query


 async def get_relations_results(
@@ -297,7 +296,7 @@ async def get_relations_results_from_entities(
     relations_results: list[RelationSearchResponse] = [result.relation for result in results]
     return await merge_relations_results(
         relations_results,
-        request.relation_subgraph,
+        request.relation_subgraph.entry_points,
         only_with_metadata,
         only_agentic_relations,
         only_entity_to_entity,
@@ -469,7 +468,6 @@ async def run_prequeries(
         x_ndb_client,
         x_nucliadb_user,
         x_forwarded_for,
-        generative_model=generative_model,
         metrics=metrics,
     )
     return prequery, find_results
nucliadb/search/search/find.py
CHANGED
@@ -18,9 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import logging
-from dataclasses import dataclass
 from time import time
-from typing import Optional

 from nucliadb.common.external_index_providers.base import ExternalIndexManager
 from nucliadb.common.external_index_providers.manager import get_external_index_manager
@@ -38,30 +36,22 @@ from nucliadb.search.search.hydrator import (
 from nucliadb.search.search.metrics import (
     RAGMetrics,
 )
-from nucliadb.search.search.
-from nucliadb.search.search.query_parser.old_filters import OldFilterParams
+from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.search.query_parser.parsers import parse_find
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.search.rank_fusion import (
-    RankFusionAlgorithm,
     get_rank_fusion,
 )
 from nucliadb.search.search.rerankers import (
-    Reranker,
     RerankingOptions,
     get_reranker,
 )
-from nucliadb.search.search.utils import (
-    filter_hidden_resources,
-    min_score_from_payload,
-    should_disable_vector_search,
-)
 from nucliadb.search.settings import settings
 from nucliadb_models.search import (
     FindRequest,
     KnowledgeboxFindResults,
     MinScore,
     NucliaDBClientType,
-    SearchOptions,
 )
 from nucliadb_utils.utilities import get_audit

@@ -74,20 +64,18 @@ async def find(
     x_ndb_client: NucliaDBClientType,
     x_nucliadb_user: str,
     x_forwarded_for: str,
-    generative_model: Optional[str] = None,
     metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults, bool,
+) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]:
     external_index_manager = await get_external_index_manager(kbid=kbid)
     if external_index_manager is not None:
         return await _external_index_retrieval(
             kbid,
             item,
             external_index_manager,
-            generative_model,
         )
     else:
         return await _index_node_retrieval(
-            kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for,
+            kbid, item, x_ndb_client, x_nucliadb_user, x_forwarded_for, metrics
         )


@@ -97,17 +85,18 @@ async def _index_node_retrieval(
     x_ndb_client: NucliaDBClientType,
     x_nucliadb_user: str,
     x_forwarded_for: str,
-    generative_model: Optional[str] = None,
     metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults, bool,
+) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]:
     audit = get_audit()
     start_time = time()

-    query_parser, rank_fusion, reranker = await query_parser_from_find_request(
-        kbid, item, generative_model=generative_model
-    )
     with metrics.time("query_parse"):
-
+        parsed = await parse_find(kbid, item)
+        rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion)
+        reranker = get_reranker(parsed.retrieval.reranker)
+        pb_query, incomplete_results, autofilters, rephrased_query = await convert_retrieval_to_proto(
+            parsed
+        )

     with metrics.time("node_query"):
         results, query_incomplete_results, queried_nodes = await node_query(
@@ -119,13 +108,10 @@ async def _index_node_retrieval(
     with metrics.time("results_merge"):
         search_results = await build_find_response(
             results,
+            retrieval=parsed.retrieval,
             kbid=kbid,
             query=pb_query.body,
             rephrased_query=rephrased_query,
-            relation_subgraph_query=pb_query.relation_subgraph,
-            min_score_bm25=pb_query.min_score_bm25,
-            min_score_semantic=pb_query.min_score_semantic,
-            top_k=item.top_k,
             show=item.show,
             extracted=item.extracted,
             field_type_filter=item.field_type_filter,
@@ -182,23 +168,21 @@ async def _index_node_retrieval(
         },
     )

-    return search_results, incomplete_results,
+    return search_results, incomplete_results, parsed


 async def _external_index_retrieval(
     kbid: str,
     item: FindRequest,
     external_index_manager: ExternalIndexManager,
-
-) -> tuple[KnowledgeboxFindResults, bool, QueryParser]:
+) -> tuple[KnowledgeboxFindResults, bool, ParsedQuery]:
     """
     Parse the query, query the external index, and hydrate the results.
     """
     # Parse query
-
-
-    )
-    search_request, incomplete_results, _, rephrased_query = await query_parser.parse()
+    parsed = await parse_find(kbid, item)
+    reranker = get_reranker(parsed.retrieval.reranker)
+    search_request, incomplete_results, _, rephrased_query = await convert_retrieval_to_proto(parsed)

     # Query index
     query_results = await external_index_manager.query(search_request)  # noqa
@@ -218,13 +202,15 @@ async def _external_index_retrieval(
             kbid=kbid,
             query=search_request.body,
         ),
-        top_k=
+        top_k=parsed.retrieval.top_k,
     )
     find_resources = compose_find_resources(text_blocks, resources)

     results_min_score = MinScore(
         bm25=0,
-        semantic=
+        semantic=parsed.retrieval.query.semantic.min_score
+        if parsed.retrieval.query.semantic is not None
+        else 0.0,
     )
     retrieval_results = KnowledgeboxFindResults(
         resources=find_resources,
@@ -242,65 +228,4 @@ async def _external_index_retrieval(
         nodes=None,
     )

-    return retrieval_results, incomplete_results,
-
-
-@dataclass
-class ScoredParagraph:
-    id: str
-    score: float
-
-
-async def query_parser_from_find_request(
-    kbid: str, item: FindRequest, *, generative_model: Optional[str] = None
-) -> tuple[QueryParser, RankFusionAlgorithm, Reranker]:
-    item.min_score = min_score_from_payload(item.min_score)
-
-    if SearchOptions.SEMANTIC in item.features:
-        if should_disable_vector_search(item):
-            item.features.remove(SearchOptions.SEMANTIC)
-
-    hidden = await filter_hidden_resources(kbid, item.show_hidden)
-
-    # XXX this is becoming the new /find query parsing, this should be moved to
-    # a cleaner abstraction
-
-    parsed = await parse_find(kbid, item)
-
-    rank_fusion = get_rank_fusion(parsed.rank_fusion)
-    reranker = get_reranker(parsed.reranker)
-
-    query_parser = QueryParser(
-        kbid=kbid,
-        features=item.features,
-        query=item.query,
-        query_entities=item.query_entities,
-        filter_expression=item.filter_expression,
-        faceted=None,
-        sort=None,
-        top_k=item.top_k,
-        min_score=item.min_score,
-        old_filters=OldFilterParams(
-            label_filters=item.filters,
-            keyword_filters=item.keyword_filters,
-            range_creation_start=item.range_creation_start,
-            range_creation_end=item.range_creation_end,
-            range_modification_start=item.range_modification_start,
-            range_modification_end=item.range_modification_end,
-            fields=item.fields,
-            key_filters=item.resource_filters,
-        ),
-        user_vector=item.vector,
-        vectorset=item.vectorset,
-        with_duplicates=item.with_duplicates,
-        with_synonyms=item.with_synonyms,
-        autofilter=item.autofilter,
-        security=item.security,
-        generative_model=generative_model,
-        rephrase=item.rephrase,
-        rephrase_prompt=item.rephrase_prompt,
-        hidden=hidden,
-        rank_fusion=rank_fusion,
-        reranker=reranker,
-    )
-    return (query_parser, rank_fusion, reranker)
+    return retrieval_results, incomplete_results, parsed
nucliadb/search/search/find_merge.py
CHANGED
@@ -32,6 +32,7 @@ from nucliadb.search.search.hydrator import (
     text_block_to_find_paragraph,
 )
 from nucliadb.search.search.merge import merge_relations_results
+from nucliadb.search.search.query_parser.models import UnitRetrieval
 from nucliadb.search.search.rank_fusion import RankFusionAlgorithm
 from nucliadb.search.search.rerankers import (
     RerankableItem,
@@ -51,7 +52,6 @@ from nucliadb_models.search import (
 )
 from nucliadb_protos.nodereader_pb2 import (
     DocumentScored,
-    EntitiesSubgraphRequest,
     ParagraphResult,
     ParagraphSearchResponse,
     RelationSearchResponse,
@@ -72,13 +72,10 @@ FIND_FETCH_OPS_DISTRIBUTION = metrics.Histogram(
 async def build_find_response(
     search_responses: list[SearchResponse],
     *,
+    retrieval: UnitRetrieval,
     kbid: str,
     query: str,
     rephrased_query: Optional[str],
-    relation_subgraph_query: EntitiesSubgraphRequest,
-    top_k: int,
-    min_score_bm25: float,
-    min_score_semantic: float,
     rank_fusion_algorithm: RankFusionAlgorithm,
     reranker: Reranker,
     show: list[ResourceProperties] = [],
@@ -86,6 +83,15 @@ async def build_find_response(
     field_type_filter: list[FieldTypeName] = [],
     highlight: bool = False,
 ) -> KnowledgeboxFindResults:
+    # XXX: we shouldn't need a min score that we haven't used. Previous
+    # implementations got this value from the proto request (i.e., default to 0)
+    min_score_bm25 = 0.0
+    if retrieval.query.keyword is not None:
+        min_score_bm25 = retrieval.query.keyword.min_score
+    min_score_semantic = 0.0
+    if retrieval.query.semantic is not None:
+        min_score_semantic = retrieval.query.semantic.min_score
+
     # merge
     search_response = merge_shard_responses(search_responses)

@@ -112,7 +118,7 @@ async def build_find_response(
         assert reranker.window is not None, "Reranker definition must enforce this condition"
         text_blocks_page, next_page = cut_page(merged_text_blocks, reranker.window)
     else:
-        text_blocks_page, next_page = cut_page(merged_text_blocks, top_k)
+        text_blocks_page, next_page = cut_page(merged_text_blocks, retrieval.top_k)

     # hydrate and rerank
     resource_hydration_options = ResourceHydrationOptions(
@@ -130,11 +136,14 @@ async def build_find_response(
         text_block_hydration_options=text_block_hydration_options,
         reranker=reranker,
         reranking_options=reranking_options,
-        top_k=top_k,
+        top_k=retrieval.top_k,
     )

     # build relations graph
-
+    entry_points = []
+    if retrieval.query.relation is not None:
+        entry_points = retrieval.query.relation.detected_entities
+    relations = await merge_relations_results([search_response.relation], entry_points)

     # compose response
     find_resources = compose_find_resources(text_blocks, resources)
@@ -150,7 +159,7 @@ async def build_find_response(
         relations=relations,
         total=total_paragraphs,
         page_number=0,  # Bw/c with pagination
-        page_size=top_k,
+        page_size=retrieval.top_k,
         next_page=next_page,
         min_score=MinScore(bm25=_round(min_score_bm25), semantic=_round(min_score_semantic)),
     )
nucliadb/search/search/graph_strategy.py
CHANGED
@@ -16,7 +16,6 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
-
 import heapq
 import json
 from collections import defaultdict
@@ -38,14 +37,16 @@ from nucliadb.search.search.chat.query import (
     find_request_from_ask_request,
     get_relations_results_from_entities,
 )
-from nucliadb.search.search.find import query_parser_from_find_request
 from nucliadb.search.search.find_merge import (
     compose_find_resources,
     hydrate_and_rerank,
 )
 from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.rerankers import
+from nucliadb.search.search.rerankers import (
+    Reranker,
+    RerankingOptions,
+)
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.common import FieldTypeName
 from nucliadb_models.internal.predict import (
@@ -303,6 +304,7 @@ async def get_graph_results(
     user: str,
     origin: str,
     graph_strategy: GraphStrategy,
+    text_block_reranker: Reranker,
     generative_model: Optional[str] = None,
     metrics: RAGMetrics = RAGMetrics(),
     shards: Optional[list[str]] = None,
@@ -419,19 +421,16 @@ async def get_graph_results(
     # Get the text blocks of the paragraphs that contain the top relations
     with metrics.time("graph_strat_build_response"):
         find_request = find_request_from_ask_request(item, query)
-        query_parser, rank_fusion, reranker = await query_parser_from_find_request(
-            kbid, find_request, generative_model=generative_model
-        )
         find_results = await build_graph_response(
             kbid=kbid,
             query=query,
             final_relations=relations,
             scores=scores,
             top_k=graph_strategy.top_k,
-            reranker=
-            show=
-            extracted=
-            field_type_filter=
+            reranker=text_block_reranker,
+            show=item.show,
+            extracted=item.extracted,
+            field_type_filter=item.field_type_filter,
             relation_text_as_paragraphs=graph_strategy.relation_text_as_paragraphs,
         )
     return find_results, find_request
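Finally, the graph strategy no longer re-parses the find request to obtain a reranker; the caller resolves one from the already-parsed query and passes it in as text_block_reranker. A hedged sketch of the caller side (keyword arguments not visible in the hunks above, such as item and ndb_client, are assumed from the unchanged signature; the wrapper name is illustrative):

    from nucliadb.search.search.graph_strategy import get_graph_results
    from nucliadb.search.search.rerankers import get_reranker


    async def graph_retrieval(kbid, main_query, ask_request, ndb_client, user, origin, graph_strategy, parsed_query, metrics):
        # Resolve the reranker once from the parsed retrieval and hand it down,
        # instead of calling the removed query_parser_from_find_request().
        reranker = get_reranker(parsed_query.retrieval.reranker)
        return await get_graph_results(
            kbid=kbid,
            query=main_query,
            item=ask_request,  # assumed parameter name, not shown in the hunks
            ndb_client=ndb_client,  # assumed parameter name, not shown in the hunks
            user=user,
            origin=origin,
            graph_strategy=graph_strategy,
            metrics=metrics,
            text_block_reranker=reranker,
        )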
|