nucliadb 6.3.7.post4091__py3-none-any.whl → 6.3.7.post4116__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- nucliadb/ingest/consumer/consumer.py +3 -4
- nucliadb/search/api/v1/find.py +5 -5
- nucliadb/search/api/v1/search.py +2 -10
- nucliadb/search/search/chat/ask.py +6 -3
- nucliadb/search/search/chat/query.py +21 -17
- nucliadb/search/search/find.py +14 -5
- nucliadb/search/search/find_merge.py +27 -13
- nucliadb/search/search/merge.py +17 -18
- nucliadb/search/search/query_parser/models.py +22 -27
- nucliadb/search/search/query_parser/parsers/common.py +32 -21
- nucliadb/search/search/query_parser/parsers/find.py +31 -8
- nucliadb/search/search/query_parser/parsers/search.py +33 -10
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +207 -115
- nucliadb/search/search/utils.py +2 -42
- {nucliadb-6.3.7.post4091.dist-info → nucliadb-6.3.7.post4116.dist-info}/METADATA +6 -6
- {nucliadb-6.3.7.post4091.dist-info → nucliadb-6.3.7.post4116.dist-info}/RECORD +19 -19
- {nucliadb-6.3.7.post4091.dist-info → nucliadb-6.3.7.post4116.dist-info}/WHEEL +1 -1
- {nucliadb-6.3.7.post4091.dist-info → nucliadb-6.3.7.post4116.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.3.7.post4091.dist-info → nucliadb-6.3.7.post4116.dist-info}/top_level.txt +0 -0
nucliadb/ingest/consumer/consumer.py
CHANGED
@@ -160,6 +160,8 @@ class IngestConsumer:
             logger.warning("Could not delete blob reference", exc_info=True)
 
     async def subscription_worker(self, msg: Msg):
+        context.clear_context()
+
         kbid: Optional[str] = None
         subject = msg.subject
         reply = msg.reply
@@ -182,7 +184,6 @@ class IngestConsumer:
             MessageProgressUpdater(msg, nats_consumer_settings.nats_ack_wait * 0.66),
             self.lock,
         ):
-            logger.info(f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}")
             try:
                 pb = await self.get_broker_message(msg)
                 if pb.source == pb.MessageSource.PROCESSOR:
@@ -194,10 +195,8 @@ class IngestConsumer:
             else:
                 audit_time = ""
 
-            logger.debug(
-                f"Received from {message_source} on {pb.kbid}/{pb.uuid} seq {seqid} partition {self.partition} at {time}"  # noqa
-            )
             context.add_context({"kbid": pb.kbid, "rid": pb.uuid})
+            logger.info(f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}")
             kbid = pb.kbid
             try:
                 source = "writer" if pb.source == pb.MessageSource.WRITER else "processor"
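Note on the hunk above: clearing the logging context at the top of subscription_worker and logging only after context.add_context means every "Message processing" line carries the kbid/rid of the current message and nothing left over from the previous one. A minimal sketch of that clear-then-enrich pattern, using a contextvars-based stand-in for nucliadb's context helper (the names below are illustrative, not nucliadb's API):

from contextvars import ContextVar
from typing import Any

_log_context: ContextVar[dict[str, Any]] = ContextVar("log_context", default={})

def clear_context() -> None:
    # start each message from a clean slate so nothing leaks between messages
    _log_context.set({})

def add_context(values: dict[str, Any]) -> None:
    _log_context.set({**_log_context.get(), **values})

def handle_message(kbid: str, rid: str) -> dict[str, Any]:
    clear_context()
    add_context({"kbid": kbid, "rid": rid})
    return _log_context.get()  # what a logging filter would attach to records

assert handle_message("kb1", "r1") == {"kbid": "kb1", "rid": "r1"}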
nucliadb/search/api/v1/find.py
CHANGED
@@ -40,6 +40,7 @@ from nucliadb_models.configuration import FindConfig
 from nucliadb_models.filters import FilterExpression
 from nucliadb_models.resource import ExtractedDataTypeName, NucliaDBRoles
 from nucliadb_models.search import (
+    FindOptions,
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
@@ -47,7 +48,6 @@ from nucliadb_models.search import (
     Reranker,
     RerankerName,
     ResourceProperties,
-    SearchOptions,
     SearchParamDefaults,
 )
 from nucliadb_models.security import RequestSecurity
@@ -61,7 +61,7 @@ FIND_EXAMPLES = {
         description="Perform a hybrid search that will return text and semantic results matching the query",
         value={
             "query": "How can I be an effective product manager?",
-            "features": [SearchOptions.KEYWORD, SearchOptions.SEMANTIC],
+            "features": [FindOptions.KEYWORD, FindOptions.SEMANTIC],
         },
     )
 }
@@ -110,11 +110,11 @@ async def find_knowledgebox(
     range_modification_end: Optional[DateTime] = fastapi_query(
         SearchParamDefaults.range_modification_end
     ),
-    features: list[SearchOptions] = fastapi_query(
+    features: list[FindOptions] = fastapi_query(
         SearchParamDefaults.search_features,
         default=[
-            SearchOptions.KEYWORD,
-            SearchOptions.SEMANTIC,
+            FindOptions.KEYWORD,
+            FindOptions.SEMANTIC,
         ],
     ),
     debug: bool = fastapi_query(SearchParamDefaults.debug),
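For reference, a find request built against the updated endpoint passes the new enum values; this is a sketch assuming FindRequest keeps pydantic defaults for every field not shown:

from nucliadb_models.search import FindOptions, FindRequest

# mirrors the FIND_EXAMPLES payload above, now spelled with FindOptions
request = FindRequest(
    query="How can I be an effective product manager?",
    features=[FindOptions.KEYWORD, FindOptions.SEMANTIC],
)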
nucliadb/search/api/v1/search.py
CHANGED
@@ -37,11 +37,9 @@ from nucliadb.search.search import cache
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_results
 from nucliadb.search.search.query_parser.parsers.search import parse_search
-from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto
 from nucliadb.search.search.utils import (
-    min_score_from_payload,
     min_score_from_query_params,
-    should_disable_vector_search,
 )
 from nucliadb_models.common import FieldTypeName
 from nucliadb_models.filters import FilterExpression
@@ -263,14 +261,8 @@ async def search(
     audit = get_audit()
     start_time = time()
 
-    item.min_score = min_score_from_payload(item.min_score)
-
-    if SearchOptions.SEMANTIC in item.features:
-        if should_disable_vector_search(item):
-            item.features.remove(SearchOptions.SEMANTIC)
-
     parsed = await parse_search(kbid, item)
-    pb_query, incomplete_results, autofilters, _ = await convert_retrieval_to_proto(parsed)
+    pb_query, incomplete_results, autofilters, _ = await legacy_convert_retrieval_to_proto(parsed)
 
     # We need to query all nodes
     results, query_incomplete_results, queried_nodes = await node_query(kbid, Method.SEARCH, pb_query)
nucliadb/search/search/chat/ask.py
CHANGED
@@ -80,6 +80,7 @@ from nucliadb_models.search import (
     CitationsAskResponseItem,
     DebugAskResponseItem,
     ErrorAskResponseItem,
+    FindOptions,
     FindParagraph,
     FindRequest,
     GraphStrategy,
@@ -97,7 +98,6 @@ from nucliadb_models.search import (
     Relations,
     RelationsAskResponseItem,
     RetrievalAskResponseItem,
-    SearchOptions,
     StatusAskResponseItem,
     SyncAskMetadata,
     SyncAskResponse,
@@ -755,6 +755,9 @@ async def retrieval_in_kb(
     )
 
     if graph_strategy is not None:
+        assert parsed_query.retrieval.reranker is not None, (
+            "find parser must provide a reranking algorithm"
+        )
         reranker = get_reranker(parsed_query.retrieval.reranker)
         graph_results, graph_request = await get_graph_results(
             kbid=kbid,
@@ -952,9 +955,9 @@ def calculate_prequeries_for_json_schema(
     json_schema = ask_request.answer_json_schema or {}
     features = []
     if ChatOptions.SEMANTIC in ask_request.features:
-        features.append(SearchOptions.SEMANTIC)
+        features.append(FindOptions.SEMANTIC)
     if ChatOptions.KEYWORD in ask_request.features:
-        features.append(SearchOptions.KEYWORD)
+        features.append(FindOptions.KEYWORD)
 
     properties = json_schema.get("parameters", {}).get("properties", {})
     if len(properties) == 0:  # pragma: no cover
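The ask endpoint keeps accepting ChatOptions and now translates them to FindOptions before building the find request. A self-contained sketch of that mapping (stand-in enums, not the nucliadb_models definitions):

from enum import Enum

class ChatOptions(str, Enum):  # stand-in for nucliadb_models.search.ChatOptions
    SEMANTIC = "semantic"
    KEYWORD = "keyword"

class FindOptions(str, Enum):  # stand-in for nucliadb_models.search.FindOptions
    SEMANTIC = "semantic"
    KEYWORD = "keyword"

def to_find_features(chat_features: set[ChatOptions]) -> list[FindOptions]:
    features = []
    if ChatOptions.SEMANTIC in chat_features:
        features.append(FindOptions.SEMANTIC)
    if ChatOptions.KEYWORD in chat_features:
        features.append(FindOptions.KEYWORD)
    return features

assert to_find_features({ChatOptions.KEYWORD}) == [FindOptions.KEYWORD]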
nucliadb/search/search/chat/query.py
CHANGED
@@ -29,7 +29,8 @@ from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find
 from nucliadb.search.search.merge import merge_relations_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.query_parser.models import ParsedQuery
+from nucliadb.search.search.query_parser.models import ParsedQuery, Query, RelationQuery, UnitRetrieval
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.settings import settings
 from nucliadb.search.utilities import get_predict
 from nucliadb_models import filters
@@ -37,6 +38,7 @@ from nucliadb_models.search import (
     AskRequest,
     ChatContextMessage,
     ChatOptions,
+    FindOptions,
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
@@ -47,14 +49,11 @@ from nucliadb_models.search import (
     PromptContextOrder,
     Relations,
     RephraseModel,
-    SearchOptions,
     parse_rephrase_prompt,
 )
 from nucliadb_protos import audit_pb2
 from nucliadb_protos.nodereader_pb2 import (
-
-    RelationSearchResponse,
-    SearchRequest,
+    GraphSearchResponse,
     SearchResponse,
 )
 from nucliadb_protos.utils_pb2 import RelationNode
@@ -181,11 +180,11 @@ def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
     find_request.resource_filters = item.resource_filters
     find_request.features = []
     if ChatOptions.SEMANTIC in item.features:
-        find_request.features.append(SearchOptions.SEMANTIC)
+        find_request.features.append(FindOptions.SEMANTIC)
     if ChatOptions.KEYWORD in item.features:
-        find_request.features.append(SearchOptions.KEYWORD)
+        find_request.features.append(FindOptions.KEYWORD)
     if ChatOptions.RELATIONS in item.features:
-        find_request.features.append(SearchOptions.RELATIONS)
+        find_request.features.append(FindOptions.RELATIONS)
     find_request.query = query
     find_request.fields = item.fields
     find_request.filters = item.filters
@@ -274,13 +273,18 @@ async def get_relations_results_from_entities(
     only_entity_to_entity: bool = False,
     deleted_entities: set[str] = set(),
 ) -> Relations:
-
-
-
-
-
-
-
+    entry_points = list(entities)
+    retrieval = UnitRetrieval(
+        query=Query(
+            relation=RelationQuery(
+                entry_points=entry_points,
+                deleted_entities={"": list(deleted_entities)},
+                deleted_entity_groups=[],
+            )
+        ),
+        top_k=50,
+    )
+    request = convert_retrieval_to_proto(retrieval)
 
     results: list[SearchResponse]
     (
@@ -293,10 +297,10 @@ async def get_relations_results_from_entities(
         request,
         timeout=timeout,
    )
-    relations_results: list[RelationSearchResponse] = [result.relation for result in results]
+    relations_results: list[GraphSearchResponse] = [result.graph for result in results]
     return await merge_relations_results(
         relations_results,
-        entities,
+        entry_points,
         only_with_metadata,
         only_agentic_relations,
         only_entity_to_entity,
nucliadb/search/search/find.py
CHANGED
@@ -38,7 +38,7 @@ from nucliadb.search.search.metrics import (
 )
 from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.search.query_parser.parsers import parse_find
-from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto
 from nucliadb.search.search.rank_fusion import (
     get_rank_fusion,
 )
@@ -92,11 +92,17 @@ async def _index_node_retrieval(
 
     with metrics.time("query_parse"):
         parsed = await parse_find(kbid, item)
+        assert parsed.retrieval.rank_fusion is not None and parsed.retrieval.reranker is not None, (
+            "find parser must provide rank fusion and reranker algorithms"
+        )
         rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion)
         reranker = get_reranker(parsed.retrieval.reranker)
-        pb_query, incomplete_results, autofilters, rephrased_query = await convert_retrieval_to_proto(
-            parsed
-        )
+        (
+            pb_query,
+            incomplete_results,
+            autofilters,
+            rephrased_query,
+        ) = await legacy_convert_retrieval_to_proto(parsed)
 
     with metrics.time("node_query"):
         results, query_incomplete_results, queried_nodes = await node_query(
@@ -181,8 +187,11 @@ async def _external_index_retrieval(
     """
     # Parse query
     parsed = await parse_find(kbid, item)
+    assert parsed.retrieval.reranker is not None, "find parser must provide a reranking algorithm"
    reranker = get_reranker(parsed.retrieval.reranker)
-    search_request, incomplete_results, _, rephrased_query = await convert_retrieval_to_proto(parsed)
+    search_request, incomplete_results, _, rephrased_query = await legacy_convert_retrieval_to_proto(
+        parsed
+    )
 
     # Query index
     query_results = await external_index_manager.query(search_request)  # noqa
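The new asserts exist because rank_fusion and reranker become Optional on UnitRetrieval (see the models.py diff below): the find parser always fills them in, but a type checker cannot know that, so the assert both documents the contract and narrows the type. A generic sketch of the pattern:

from typing import Optional

class Retrieval:  # stand-in for UnitRetrieval's optional fields
    def __init__(self, reranker: Optional[str] = None):
        self.reranker = reranker

def get_reranker_name(retrieval: Retrieval) -> str:
    # without this assert, retrieval.reranker is Optional[str] and the
    # return type would not check; with it, the type narrows to str
    assert retrieval.reranker is not None, "find parser must provide a reranking algorithm"
    return retrieval.reranker

assert get_reranker_name(Retrieval(reranker="predict")) == "predict"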
nucliadb/search/search/find_merge.py
CHANGED
@@ -52,9 +52,9 @@ from nucliadb_models.search import (
 )
 from nucliadb_protos.nodereader_pb2 import (
     DocumentScored,
+    GraphSearchResponse,
     ParagraphResult,
     ParagraphSearchResponse,
-    RelationSearchResponse,
     SearchResponse,
     VectorSearchResponse,
 )
@@ -142,8 +142,8 @@ async def build_find_response(
     # build relations graph
     entry_points = []
     if retrieval.query.relation is not None:
-        entry_points = retrieval.query.relation.
-        relations = await merge_relations_results([search_response.relation], entry_points)
+        entry_points = retrieval.query.relation.entry_points
+        relations = await merge_relations_results([search_response.graph], entry_points)
 
     # compose response
     find_resources = compose_find_resources(text_blocks, resources)
@@ -178,16 +178,16 @@ def merge_shard_responses(
     """
     paragraphs = []
     vectors = []
-    relations = []
+    graphs = []
     for response in responses:
         paragraphs.append(response.paragraph)
         vectors.append(response.vector)
-        relations.append(response.relation)
+        graphs.append(response.graph)
 
     merged = SearchResponse(
         paragraph=merge_shards_keyword_responses(paragraphs),
         vector=merge_shards_semantic_responses(vectors),
-        relation=merge_shards_relation_responses(relations),
+        graph=merge_shards_graph_responses(graphs),
     )
     return merged
@@ -230,13 +230,27 @@ def merge_shards_semantic_responses(
     return merged
 
 
-def merge_shards_relation_responses(
-    relation_responses: list[RelationSearchResponse],
-):
-    merged = RelationSearchResponse()
-
-
-    merged.
+def merge_shards_graph_responses(
+    graph_responses: list[GraphSearchResponse],
+):
+    merged = GraphSearchResponse()
+
+    for response in graph_responses:
+        nodes_offset = len(merged.nodes)
+        relations_offset = len(merged.relations)
+
+        # paths contain indexes to nodes and relations, we must offset them
+        # while merging responses to maintain valid data
+        for path in response.graph:
+            merged_path = GraphSearchResponse.Path()
+            merged_path.CopyFrom(path)
+            merged_path.source += nodes_offset
+            merged_path.relation += relations_offset
+            merged_path.destination += nodes_offset
+            merged.graph.append(merged_path)
+
+        merged.nodes.extend(response.nodes)
+        merged.relations.extend(response.relations)
 
     return merged
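The key detail in merge_shards_graph_responses above is the index arithmetic: each shard's paths refer to positions in that shard's own nodes and relations arrays, so after concatenation the indexes must be shifted by the size of what was already merged. A pure-Python sketch of the same invariant, with tuples standing in for the protobuf Path messages:

def merge(responses):
    merged = {"nodes": [], "relations": [], "paths": []}
    for resp in responses:
        nodes_offset = len(merged["nodes"])
        relations_offset = len(merged["relations"])
        for source, relation, destination in resp["paths"]:
            # shift indexes so they still point at the right entries
            merged["paths"].append(
                (source + nodes_offset, relation + relations_offset, destination + nodes_offset)
            )
        merged["nodes"].extend(resp["nodes"])
        merged["relations"].extend(resp["relations"])
    return merged

shard_a = {"nodes": ["A", "B"], "relations": ["knows"], "paths": [(0, 0, 1)]}
shard_b = {"nodes": ["C", "D"], "relations": ["likes"], "paths": [(0, 0, 1)]}
merged = merge([shard_a, shard_b])
# shard_b's path now resolves to C --likes--> D instead of A --knows--> B
assert merged["paths"] == [(0, 0, 1), (2, 1, 3)]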
nucliadb/search/search/merge.py
CHANGED
@@ -65,9 +65,9 @@ from nucliadb_protos.nodereader_pb2 import (
     DocumentResult,
     DocumentScored,
     DocumentSearchResponse,
+    GraphSearchResponse,
     ParagraphResult,
     ParagraphSearchResponse,
-    RelationSearchResponse,
     SearchResponse,
     SuggestResponse,
     VectorSearchResponse,
@@ -438,7 +438,7 @@ async def merge_paragraph_results(
 
 @merge_observer.wrap({"type": "merge_relations"})
 async def merge_relations_results(
-    relations_responses: list[RelationSearchResponse],
+    graph_responses: list[GraphSearchResponse],
     query_entry_points: Iterable[RelationNode],
     only_with_metadata: bool = False,
     only_agentic: bool = False,
@@ -448,7 +448,7 @@ async def merge_relations_results(
     return await loop.run_in_executor(
         None,
         _merge_relations_results,
-        relations_responses,
+        graph_responses,
         query_entry_points,
         only_with_metadata,
         only_agentic,
@@ -457,7 +457,7 @@ async def merge_relations_results(
 
 
 def _merge_relations_results(
-    relations_responses: list[RelationSearchResponse],
+    graph_responses: list[GraphSearchResponse],
     query_entry_points: Iterable[RelationNode],
     only_with_metadata: bool,
     only_agentic: bool,
@@ -480,17 +480,16 @@ def _merge_relations_results(
     for entry_point in query_entry_points:
         relations.entities[entry_point.value] = EntitySubgraph(related_to=[])
 
-    for relations_response in relations_responses:
-        for index_relation in relations_response.subgraph.relations:
-            relation = index_relation.relation
-            origin =
-            destination =
-            relation_type = RelationTypePbMap[relation.
-            relation_label = relation.
-            metadata =
-
-            if index_relation.resource_field_id is not None:
-                resource_id = index_relation.resource_field_id.split("/")[0]
+    for graph_response in graph_responses:
+        for path in graph_response.graph:
+            relation = graph_response.relations[path.relation]
+            origin = graph_response.nodes[path.source]
+            destination = graph_response.nodes[path.destination]
+            relation_type = RelationTypePbMap[relation.relation_type]
+            relation_label = relation.label
+            metadata = path.metadata if path.HasField("metadata") else None
+            if path.resource_field_id is not None:
+                resource_id = path.resource_field_id.split("/")[0]
 
             # If only_with_metadata is True, we check that metadata for the relation is not None
             # If only_agentic is True, we check that metadata for the relation is not None and that it has a data_augmentation_task_id
@@ -547,13 +546,13 @@ async def merge_results(
     paragraphs = []
     documents = []
     vectors = []
-    relations = []
+    graphs = []
 
     for response in search_responses:
         paragraphs.append(response.paragraph)
         documents.append(response.document)
         vectors.append(response.vector)
-        relations.append(response.relation)
+        graphs.append(response.graph)
 
     api_results = KnowledgeboxSearchResults()
@@ -595,7 +594,7 @@ async def merge_results(
 
     if retrieval.query.relation is not None:
         api_results.relations = await merge_relations_results(
-
+            graphs, retrieval.query.relation.entry_points
         )
 
     api_results.resources = await fetch_resources(resources, kbid, show, field_type_filter, extracted)
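The rewritten _merge_relations_results loop no longer receives inline relation objects; it gets paths of integer indexes and resolves them against the response's nodes and relations arrays. A minimal sketch of that resolution step (dicts standing in for the protobuf messages):

nodes = ["climate change", "sea level rise"]
relations = [{"label": "causes"}]
paths = [{"source": 0, "relation": 0, "destination": 1}]

edges = []
for path in paths:
    origin = nodes[path["source"]]
    destination = nodes[path["destination"]]
    label = relations[path["relation"]]["label"]
    edges.append((origin, label, destination))

assert edges == [("climate change", "causes", "sea level rise")]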
nucliadb/search/search/query_parser/models.py
CHANGED
@@ -21,10 +21,7 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Literal, Optional, Union
 
-from pydantic import (
-    BaseModel,
-    Field,
-)
+from pydantic import BaseModel, ConfigDict, Field
 
 from nucliadb.search.search.query_parser.fetcher import Fetcher
 from nucliadb_models import search as search_models
@@ -35,8 +32,7 @@ from nucliadb_protos import nodereader_pb2, utils_pb2
 # query
 
 
-@dataclass
-class _TextQuery:
+class _TextQuery(BaseModel):
     query: str
     is_synonyms_query: bool
     min_score: float
@@ -48,24 +44,23 @@ FulltextQuery = _TextQuery
 KeywordQuery = _TextQuery
 
 
-@dataclass
-class SemanticQuery:
+class SemanticQuery(BaseModel):
     query: Optional[list[float]]
     vectorset: str
     min_score: float
 
 
-@dataclass
-class RelationQuery:
-
+class RelationQuery(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    entry_points: list[utils_pb2.RelationNode]
     # list[subtype]
     deleted_entity_groups: list[str]
     # subtype -> list[entity]
     deleted_entities: dict[str, list[str]]
 
 
-@dataclass
-class Query:
+class Query(BaseModel):
     fulltext: Optional[FulltextQuery] = None
     keyword: Optional[KeywordQuery] = None
     semantic: Optional[SemanticQuery] = None
@@ -75,8 +70,9 @@ class Query:
     # filters
 
 
-@dataclass
-class Filters:
+class Filters(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     field_expression: Optional[nodereader_pb2.FilterExpression] = None
     paragraph_expression: Optional[nodereader_pb2.FilterExpression] = None
     filter_expression_operator: nodereader_pb2.FilterOperator.ValueType = (
@@ -125,30 +121,29 @@ Reranker = Union[NoopReranker, PredictReranker]
 # retrieval and generation operations
 
 
-@dataclass
-class UnitRetrieval:
+class UnitRetrieval(BaseModel):
     query: Query
     top_k: int
-    filters: Filters
-
-
-    # TODO: reranking fusion depends on the response building, not the retrieval
-    reranker: Reranker
+    filters: Filters = Field(default_factory=Filters)
+    rank_fusion: Optional[RankFusion] = None
+    reranker: Optional[Reranker] = None
 
 
-@dataclass
-class Generation:
+# TODO: augmentation things: hydration...
+
+
+class Generation(BaseModel):
     use_visual_llm: bool
     max_context_tokens: int
     max_answer_tokens: Optional[int]
 
 
-@dataclass
-class ParsedQuery:
+class ParsedQuery(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     fetcher: Fetcher
     retrieval: UnitRetrieval
     generation: Optional[Generation] = None
-    # TODO: add merge, rank fusion, rerank...
 
 
 ### Catalog
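These models moving from dataclasses to pydantic is why ConfigDict(arbitrary_types_allowed=True) appears on the classes holding protobuf values: pydantic v2 cannot build a schema for generated protobuf classes and rejects such fields by default. A self-contained sketch (a plain class standing in for utils_pb2.RelationNode):

from pydantic import BaseModel, ConfigDict

class FakeRelationNode:  # stand-in for the protobuf utils_pb2.RelationNode
    def __init__(self, value: str):
        self.value = value

class RelationQuerySketch(BaseModel):
    # without this line, pydantic raises a schema error for FakeRelationNode
    model_config = ConfigDict(arbitrary_types_allowed=True)

    entry_points: list[FakeRelationNode]
    deleted_entity_groups: list[str] = []

query = RelationQuerySketch(entry_points=[FakeRelationNode("Anna")])
assert query.entry_points[0].value == "Anna"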
nucliadb/search/search/query_parser/parsers/common.py
CHANGED
@@ -28,7 +28,6 @@ from nucliadb.search.search.query_parser.models import (
     KeywordQuery,
     SemanticQuery,
 )
-from nucliadb.search.search.utils import should_disable_vector_search
 from nucliadb_models import search as search_models
 
 DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7
@@ -38,28 +37,40 @@ DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7
 INVALID_QUERY = re.compile(r"- +\*")
 
 
-def
+def validate_query_syntax(query: str):
     # Filter some queries that panic tantivy, better than returning the 500
-    if INVALID_QUERY.search(
+    if INVALID_QUERY.search(query):
         raise InvalidQueryError("query", "Invalid query syntax")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+def is_empty_query(request: search_models.BaseSearchRequest) -> bool:
+    return len(request.query) == 0
+
+
+def has_user_vectors(request: search_models.BaseSearchRequest) -> bool:
+    return request.vector is not None and len(request.vector) > 0
+
+
+def is_exact_match_only_query(request: search_models.BaseSearchRequest) -> bool:
+    """
+    '"something"' -> True
+    'foo "something" else' -> False
+    """
+    query = request.query.strip()
+    return len(query) > 0 and query.startswith('"') and query.endswith('"')
+
+
+def should_disable_vector_search(request: search_models.BaseSearchRequest) -> bool:
+    if has_user_vectors(request):
+        return False
+
+    if is_exact_match_only_query(request):
+        return True
+
+    if is_empty_query(request):
+        return True
+
+    return False
 
 
 def parse_top_k(item: search_models.BaseSearchRequest) -> int:
@@ -92,7 +103,7 @@ async def parse_keyword_query(
 
 
 async def parse_semantic_query(
-    item: search_models.
+    item: Union[search_models.SearchRequest, search_models.FindRequest],
     *,
     fetcher: Fetcher,
 ) -> SemanticQuery: