nucliadb-6.3.7.post4081-py3-none-any.whl → nucliadb-6.3.7.post4114-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/common/context/__init__.py +90 -25
- nucliadb/common/context/fastapi.py +4 -2
- nucliadb/ingest/consumer/consumer.py +3 -4
- nucliadb/search/api/v1/find.py +5 -5
- nucliadb/search/api/v1/search.py +2 -10
- nucliadb/search/search/chat/ask.py +6 -3
- nucliadb/search/search/chat/query.py +21 -17
- nucliadb/search/search/find.py +14 -5
- nucliadb/search/search/find_merge.py +27 -13
- nucliadb/search/search/merge.py +17 -18
- nucliadb/search/search/query_parser/models.py +22 -27
- nucliadb/search/search/query_parser/parsers/common.py +32 -21
- nucliadb/search/search/query_parser/parsers/find.py +31 -8
- nucliadb/search/search/query_parser/parsers/search.py +33 -10
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +207 -115
- nucliadb/search/search/utils.py +2 -42
- nucliadb/train/app.py +0 -3
- nucliadb/train/lifecycle.py +16 -11
- {nucliadb-6.3.7.post4081.dist-info → nucliadb-6.3.7.post4114.dist-info}/METADATA +6 -6
- {nucliadb-6.3.7.post4081.dist-info → nucliadb-6.3.7.post4114.dist-info}/RECORD +23 -23
- {nucliadb-6.3.7.post4081.dist-info → nucliadb-6.3.7.post4114.dist-info}/WHEEL +1 -1
- {nucliadb-6.3.7.post4081.dist-info → nucliadb-6.3.7.post4114.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.3.7.post4081.dist-info → nucliadb-6.3.7.post4114.dist-info}/top_level.txt +0 -0
nucliadb/common/context/__init__.py
CHANGED
@@ -18,17 +18,19 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
+from typing import Optional
 
 from nucliadb.common.cluster.manager import KBShardManager
 from nucliadb.common.cluster.settings import in_standalone_mode
 from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
 from nucliadb.common.maindb.driver import Driver
 from nucliadb.common.maindb.utils import setup_driver, teardown_driver
-from nucliadb.common.nidx import start_nidx_utility, stop_nidx_utility
+from nucliadb.common.nidx import NidxUtility, start_nidx_utility, stop_nidx_utility
 from nucliadb_utils.nats import NatsConnectionManager
 from nucliadb_utils.partition import PartitionUtility
 from nucliadb_utils.settings import indexing_settings
 from nucliadb_utils.storages.storage import Storage
+from nucliadb_utils.transaction import TransactionUtility
 from nucliadb_utils.utilities import (
     get_storage,
     start_nats_manager,
@@ -42,16 +44,34 @@ from nucliadb_utils.utilities import (
 
 
 class ApplicationContext:
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        service_name: str = "service",
+        kv_driver: bool = True,
+        blob_storage: bool = True,
+        shard_manager: bool = True,
+        partitioning: bool = True,
+        nats_manager: bool = True,
+        transaction: bool = True,
+        nidx: bool = True,
+    ) -> None:
         self.service_name = service_name
         self._initialized: bool = False
         self._lock = asyncio.Lock()
+        self._kv_driver: Optional[Driver] = None
+        self._blob_storage: Optional[Storage] = None
+        self._shard_manager: Optional[KBShardManager] = None
+        self._partitioning: Optional[PartitionUtility] = None
+        self._nats_manager: Optional[NatsConnectionManager] = None
+        self._transaction: Optional[TransactionUtility] = None
+        self._nidx: Optional[NidxUtility] = None
+        self.enabled_kv_driver = kv_driver
+        self.enabled_blob_storage = blob_storage
+        self.enabled_shard_manager = shard_manager
+        self.enabled_partitioning = partitioning
+        self.enabled_nats_manager = nats_manager
+        self.enabled_transaction = transaction
+        self.enabled_nidx = nidx
 
     async def initialize(self) -> None:
         if self._initialized:
@@ -63,30 +83,75 @@ class ApplicationContext:
         self._initialized = True
 
     async def _initialize(self):
-        self.
-
-        self.
-
-        if
-            self.
+        if self.enabled_kv_driver:
+            self._kv_driver = await setup_driver()
+        if self.enabled_blob_storage:
+            self._blob_storage = await get_storage()
+        if self.enabled_shard_manager:
+            self._shard_manager = await setup_cluster()
+        if self.enabled_partitioning:
+            self._partitioning = start_partitioning_utility()
+        if not in_standalone_mode() and self.enabled_nats_manager:
+            self._nats_manager = await start_nats_manager(
                 self.service_name,
                 indexing_settings.index_jetstream_servers,
                 indexing_settings.index_jetstream_auth,
             )
-
-
+        if self.enabled_transaction:
+            self._transaction = await start_transaction_utility(self.service_name)
+        if self.enabled_nidx:
+            self._nidx = await start_nidx_utility()
+
+    @property
+    def kv_driver(self) -> Driver:
+        assert self._kv_driver is not None, "Driver not initialized"
+        return self._kv_driver
+
+    @property
+    def shard_manager(self) -> KBShardManager:
+        assert self._shard_manager is not None, "Shard manager not initialized"
+        return self._shard_manager
+
+    @property
+    def blob_storage(self) -> Storage:
+        assert self._blob_storage is not None, "Blob storage not initialized"
+        return self._blob_storage
+
+    @property
+    def partitioning(self) -> PartitionUtility:
+        assert self._partitioning is not None, "Partitioning not initialized"
+        return self._partitioning
+
+    @property
+    def nats_manager(self) -> NatsConnectionManager:
+        assert self._nats_manager is not None, "NATS manager not initialized"
+        return self._nats_manager
+
+    @property
+    def transaction(self) -> TransactionUtility:
+        assert self._transaction is not None, "Transaction utility not initialized"
+        return self._transaction
+
+    @property
+    def nidx(self) -> NidxUtility:
+        assert self._nidx is not None, "Nidx utility not initialized"
+        return self._nidx
 
     async def finalize(self) -> None:
         if not self._initialized:
             return
-
-
-
-
+        if self.enabled_nidx:
+            await stop_nidx_utility()
+        if self.enabled_transaction:
+            await stop_transaction_utility()
+        if not in_standalone_mode() and self.enabled_nats_manager:
             await stop_nats_manager()
-
-
-
-
-
+        if self.enabled_partitioning:
+            stop_partitioning_utility()
+        if self.enabled_shard_manager:
+            await teardown_cluster()
+        if self.enabled_blob_storage:
+            await teardown_storage()
+        if self.enabled_kv_driver:
+            await teardown_driver()
         self._initialized = False
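The constructor flags and lazy properties above let callers opt out of utilities they do not need. A minimal usage sketch based on the API introduced in this diff (the service name and disabled flags are illustrative):

import asyncio

from nucliadb.common.context import ApplicationContext

async def main() -> None:
    # Hypothetical worker that only needs the KV driver and blob storage;
    # the remaining utilities are skipped by _initialize().
    context = ApplicationContext(
        service_name="my-worker",
        nats_manager=False,
        transaction=False,
        nidx=False,
    )
    await context.initialize()
    try:
        driver = context.kv_driver  # the property asserts the utility was enabled
    finally:
        await context.finalize()

asyncio.run(main())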
nucliadb/common/context/fastapi.py
CHANGED
@@ -19,6 +19,7 @@
 #
 
 from contextlib import asynccontextmanager
+from typing import Optional
 
 from fastapi import FastAPI
 from starlette.routing import Mount
@@ -27,8 +28,9 @@ from nucliadb.common.context import ApplicationContext
 
 
 @asynccontextmanager
-async def inject_app_context(app: FastAPI):
-    context = ApplicationContext()
+async def inject_app_context(app: FastAPI, context: Optional[ApplicationContext] = None):
+    if context is None:
+        context = ApplicationContext()
 
     app.state.context = context
 
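With the new optional parameter, a pre-configured context can be injected instead of the default one. A sketch using FastAPI's lifespan hook (the flag values are illustrative):

from contextlib import asynccontextmanager

from fastapi import FastAPI

from nucliadb.common.context import ApplicationContext
from nucliadb.common.context.fastapi import inject_app_context

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Pass a custom context; omit it to get a default ApplicationContext
    custom = ApplicationContext(service_name="api", nidx=False)
    async with inject_app_context(app, custom):
        yield

app = FastAPI(lifespan=lifespan)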
nucliadb/ingest/consumer/consumer.py
CHANGED
@@ -160,6 +160,8 @@ class IngestConsumer:
             logger.warning("Could not delete blob reference", exc_info=True)
 
     async def subscription_worker(self, msg: Msg):
+        context.clear_context()
+
         kbid: Optional[str] = None
         subject = msg.subject
         reply = msg.reply
@@ -182,7 +184,6 @@ class IngestConsumer:
            MessageProgressUpdater(msg, nats_consumer_settings.nats_ack_wait * 0.66),
            self.lock,
        ):
-            logger.info(f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}")
            try:
                pb = await self.get_broker_message(msg)
                if pb.source == pb.MessageSource.PROCESSOR:
@@ -194,10 +195,8 @@ class IngestConsumer:
            else:
                audit_time = ""
 
-            logger.debug(
-                f"Received from {message_source} on {pb.kbid}/{pb.uuid} seq {seqid} partition {self.partition} at {time}"  # noqa
-            )
            context.add_context({"kbid": pb.kbid, "rid": pb.uuid})
+            logger.info(f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}")
            kbid = pb.kbid
            try:
                source = "writer" if pb.source == pb.MessageSource.WRITER else "processor"
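Clearing the logging context at the top of subscription_worker ensures one message's kbid/rid never leaks into the log lines of the next message handled by the same consumer. A minimal sketch of how such a helper can be built on contextvars (an illustration, not nucliadb's actual context module):

from contextvars import ContextVar

_log_context: ContextVar[dict] = ContextVar("log_context", default={})

def clear_context() -> None:
    # Reset per-message state before processing starts
    _log_context.set({})

def add_context(data: dict) -> None:
    # Merge identifiers (e.g. {"kbid": ..., "rid": ...}) into the current context
    _log_context.set({**_log_context.get(), **data})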
nucliadb/search/api/v1/find.py
CHANGED
@@ -40,6 +40,7 @@ from nucliadb_models.configuration import FindConfig
 from nucliadb_models.filters import FilterExpression
 from nucliadb_models.resource import ExtractedDataTypeName, NucliaDBRoles
 from nucliadb_models.search import (
+    FindOptions,
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
@@ -47,7 +48,6 @@ from nucliadb_models.search import (
     Reranker,
     RerankerName,
     ResourceProperties,
-    SearchOptions,
     SearchParamDefaults,
 )
 from nucliadb_models.security import RequestSecurity
@@ -61,7 +61,7 @@ FIND_EXAMPLES = {
         description="Perform a hybrid search that will return text and semantic results matching the query",
         value={
             "query": "How can I be an effective product manager?",
-            "features": [SearchOptions.KEYWORD, SearchOptions.SEMANTIC],
+            "features": [FindOptions.KEYWORD, FindOptions.SEMANTIC],
         },
     )
 }
@@ -110,11 +110,11 @@ async def find_knowledgebox(
     range_modification_end: Optional[DateTime] = fastapi_query(
         SearchParamDefaults.range_modification_end
     ),
-    features: list[SearchOptions] = fastapi_query(
+    features: list[FindOptions] = fastapi_query(
         SearchParamDefaults.search_features,
         default=[
-            SearchOptions.KEYWORD,
-            SearchOptions.SEMANTIC,
+            FindOptions.KEYWORD,
+            FindOptions.SEMANTIC,
         ],
    ),
    debug: bool = fastapi_query(SearchParamDefaults.debug),
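After this change the /find endpoint documents and accepts FindOptions values. A hedged example request (the endpoint path and lowercase enum serialization are assumptions based on this diff and the NucliaDB API conventions):

import httpx

payload = {
    "query": "How can I be an effective product manager?",
    # FindOptions.KEYWORD / FindOptions.SEMANTIC, assumed to serialize as lowercase strings
    "features": ["keyword", "semantic"],
}
# <kbid> is a placeholder for a real Knowledge Box id
response = httpx.post("http://localhost:8080/api/v1/kb/<kbid>/find", json=payload)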
nucliadb/search/api/v1/search.py
CHANGED
@@ -37,11 +37,9 @@ from nucliadb.search.search import cache
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_results
 from nucliadb.search.search.query_parser.parsers.search import parse_search
-from nucliadb.search.search.query_parser.parsers.unit_retrieval import
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto
 from nucliadb.search.search.utils import (
-    min_score_from_payload,
     min_score_from_query_params,
-    should_disable_vector_search,
 )
 from nucliadb_models.common import FieldTypeName
 from nucliadb_models.filters import FilterExpression
@@ -263,14 +261,8 @@ async def search(
     audit = get_audit()
     start_time = time()
 
-    item.min_score = min_score_from_payload(item.min_score)
-
-    if SearchOptions.SEMANTIC in item.features:
-        if should_disable_vector_search(item):
-            item.features.remove(SearchOptions.SEMANTIC)
-
     parsed = await parse_search(kbid, item)
-    pb_query, incomplete_results, autofilters, _ = await
+    pb_query, incomplete_results, autofilters, _ = await legacy_convert_retrieval_to_proto(parsed)
 
     # We need to query all nodes
     results, query_incomplete_results, queried_nodes = await node_query(kbid, Method.SEARCH, pb_query)
nucliadb/search/search/chat/ask.py
CHANGED
@@ -80,6 +80,7 @@ from nucliadb_models.search import (
     CitationsAskResponseItem,
     DebugAskResponseItem,
     ErrorAskResponseItem,
+    FindOptions,
     FindParagraph,
     FindRequest,
     GraphStrategy,
@@ -97,7 +98,6 @@ from nucliadb_models.search import (
     Relations,
     RelationsAskResponseItem,
     RetrievalAskResponseItem,
-    SearchOptions,
     StatusAskResponseItem,
     SyncAskMetadata,
     SyncAskResponse,
@@ -755,6 +755,9 @@ async def retrieval_in_kb(
     )
 
     if graph_strategy is not None:
+        assert parsed_query.retrieval.reranker is not None, (
+            "find parser must provide a reranking algorithm"
+        )
         reranker = get_reranker(parsed_query.retrieval.reranker)
         graph_results, graph_request = await get_graph_results(
             kbid=kbid,
@@ -952,9 +955,9 @@ def calculate_prequeries_for_json_schema(
     json_schema = ask_request.answer_json_schema or {}
     features = []
     if ChatOptions.SEMANTIC in ask_request.features:
-        features.append(SearchOptions.SEMANTIC)
+        features.append(FindOptions.SEMANTIC)
     if ChatOptions.KEYWORD in ask_request.features:
-        features.append(SearchOptions.KEYWORD)
+        features.append(FindOptions.KEYWORD)
 
     properties = json_schema.get("parameters", {}).get("properties", {})
     if len(properties) == 0:  # pragma: no cover
nucliadb/search/search/chat/query.py
CHANGED
@@ -29,7 +29,8 @@ from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find
 from nucliadb.search.search.merge import merge_relations_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.query_parser.models import ParsedQuery
+from nucliadb.search.search.query_parser.models import ParsedQuery, Query, RelationQuery, UnitRetrieval
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.settings import settings
 from nucliadb.search.utilities import get_predict
 from nucliadb_models import filters
@@ -37,6 +38,7 @@ from nucliadb_models.search import (
     AskRequest,
     ChatContextMessage,
     ChatOptions,
+    FindOptions,
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
@@ -47,14 +49,11 @@ from nucliadb_models.search import (
     PromptContextOrder,
     Relations,
     RephraseModel,
-    SearchOptions,
     parse_rephrase_prompt,
 )
 from nucliadb_protos import audit_pb2
 from nucliadb_protos.nodereader_pb2 import (
-
-    RelationSearchResponse,
-    SearchRequest,
+    GraphSearchResponse,
     SearchResponse,
 )
 from nucliadb_protos.utils_pb2 import RelationNode
@@ -181,11 +180,11 @@ def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
     find_request.resource_filters = item.resource_filters
     find_request.features = []
     if ChatOptions.SEMANTIC in item.features:
-        find_request.features.append(SearchOptions.SEMANTIC)
+        find_request.features.append(FindOptions.SEMANTIC)
     if ChatOptions.KEYWORD in item.features:
-        find_request.features.append(SearchOptions.KEYWORD)
+        find_request.features.append(FindOptions.KEYWORD)
     if ChatOptions.RELATIONS in item.features:
-        find_request.features.append(SearchOptions.RELATIONS)
+        find_request.features.append(FindOptions.RELATIONS)
     find_request.query = query
     find_request.fields = item.fields
     find_request.filters = item.filters
@@ -274,13 +273,18 @@ async def get_relations_results_from_entities(
     only_entity_to_entity: bool = False,
     deleted_entities: set[str] = set(),
 ) -> Relations:
-
-
-
-
-
-
-
+    entry_points = list(entities)
+    retrieval = UnitRetrieval(
+        query=Query(
+            relation=RelationQuery(
+                entry_points=entry_points,
+                deleted_entities={"": list(deleted_entities)},
+                deleted_entity_groups=[],
+            )
+        ),
+        top_k=50,
+    )
+    request = convert_retrieval_to_proto(retrieval)
 
     results: list[SearchResponse]
     (
@@ -293,10 +297,10 @@ async def get_relations_results_from_entities(
         request,
         timeout=timeout,
     )
-    relations_results: list[
+    relations_results: list[GraphSearchResponse] = [result.graph for result in results]
     return await merge_relations_results(
         relations_results,
-
+        entry_points,
         only_with_metadata,
         only_agentic_relations,
         only_entity_to_entity,
nucliadb/search/search/find.py
CHANGED
@@ -38,7 +38,7 @@ from nucliadb.search.search.metrics import (
 )
 from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.search.query_parser.parsers import parse_find
-from nucliadb.search.search.query_parser.parsers.unit_retrieval import
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto
 from nucliadb.search.search.rank_fusion import (
     get_rank_fusion,
 )
@@ -92,11 +92,17 @@ async def _index_node_retrieval(
 
     with metrics.time("query_parse"):
         parsed = await parse_find(kbid, item)
+        assert parsed.retrieval.rank_fusion is not None and parsed.retrieval.reranker is not None, (
+            "find parser must provide rank fusion and reranker algorithms"
+        )
         rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion)
         reranker = get_reranker(parsed.retrieval.reranker)
-
-
-
+        (
+            pb_query,
+            incomplete_results,
+            autofilters,
+            rephrased_query,
+        ) = await legacy_convert_retrieval_to_proto(parsed)
 
     with metrics.time("node_query"):
         results, query_incomplete_results, queried_nodes = await node_query(
@@ -181,8 +187,11 @@ async def _external_index_retrieval(
     """
     # Parse query
     parsed = await parse_find(kbid, item)
+    assert parsed.retrieval.reranker is not None, "find parser must provide a reranking algorithm"
     reranker = get_reranker(parsed.retrieval.reranker)
-    search_request, incomplete_results, _, rephrased_query = await
+    search_request, incomplete_results, _, rephrased_query = await legacy_convert_retrieval_to_proto(
+        parsed
+    )
 
     # Query index
     query_results = await external_index_manager.query(search_request)  # noqa
nucliadb/search/search/find_merge.py
CHANGED
@@ -52,9 +52,9 @@ from nucliadb_models.search import (
 )
 from nucliadb_protos.nodereader_pb2 import (
     DocumentScored,
+    GraphSearchResponse,
     ParagraphResult,
     ParagraphSearchResponse,
-    RelationSearchResponse,
     SearchResponse,
     VectorSearchResponse,
 )
@@ -142,8 +142,8 @@ async def build_find_response(
     # build relations graph
     entry_points = []
     if retrieval.query.relation is not None:
-        entry_points = retrieval.query.relation.
-    relations = await merge_relations_results([search_response.
+        entry_points = retrieval.query.relation.entry_points
+    relations = await merge_relations_results([search_response.graph], entry_points)
 
     # compose response
     find_resources = compose_find_resources(text_blocks, resources)
@@ -178,16 +178,16 @@ def merge_shard_responses(
     """
     paragraphs = []
     vectors = []
-
+    graphs = []
     for response in responses:
         paragraphs.append(response.paragraph)
         vectors.append(response.vector)
-
+        graphs.append(response.graph)
 
     merged = SearchResponse(
         paragraph=merge_shards_keyword_responses(paragraphs),
         vector=merge_shards_semantic_responses(vectors),
-
+        graph=merge_shards_graph_responses(graphs),
     )
     return merged
 
@@ -230,13 +230,27 @@ def merge_shards_semantic_responses(
     return merged
 
 
-def
-
-)
-    merged =
-
-
-    merged.
+def merge_shards_graph_responses(
+    graph_responses: list[GraphSearchResponse],
+):
+    merged = GraphSearchResponse()
+
+    for response in graph_responses:
+        nodes_offset = len(merged.nodes)
+        relations_offset = len(merged.relations)
+
+        # paths contain indexes to nodes and relations, we must offset them
+        # while merging responses to maintain valid data
+        for path in response.graph:
+            merged_path = GraphSearchResponse.Path()
+            merged_path.CopyFrom(path)
+            merged_path.source += nodes_offset
+            merged_path.relation += relations_offset
+            merged_path.destination += nodes_offset
+            merged.graph.append(merged_path)
+
+        merged.nodes.extend(response.nodes)
+        merged.relations.extend(response.relations)
 
     return merged
 
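merge_shards_graph_responses has to rewrite path indexes because each shard numbers its nodes and relations from zero. A self-contained sketch of the same offset arithmetic, with plain dicts standing in for the GraphSearchResponse proto:

def merge_graphs(responses: list[dict]) -> dict:
    merged = {"nodes": [], "relations": [], "paths": []}
    for resp in responses:
        node_off = len(merged["nodes"])      # indexes shift by what is already merged
        rel_off = len(merged["relations"])
        for path in resp["paths"]:
            merged["paths"].append({
                "source": path["source"] + node_off,
                "relation": path["relation"] + rel_off,
                "destination": path["destination"] + node_off,
            })
        merged["nodes"].extend(resp["nodes"])
        merged["relations"].extend(resp["relations"])
    return merged

# Two single-path shard responses: the second path's indexes shift by 2 nodes and 1 relation
a = {"nodes": ["n0", "n1"], "relations": ["r0"], "paths": [{"source": 0, "relation": 0, "destination": 1}]}
b = {"nodes": ["n2", "n3"], "relations": ["r1"], "paths": [{"source": 0, "relation": 0, "destination": 1}]}
assert merge_graphs([a, b])["paths"][1] == {"source": 2, "relation": 1, "destination": 3}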
nucliadb/search/search/merge.py
CHANGED
@@ -65,9 +65,9 @@ from nucliadb_protos.nodereader_pb2 import (
     DocumentResult,
     DocumentScored,
     DocumentSearchResponse,
+    GraphSearchResponse,
     ParagraphResult,
     ParagraphSearchResponse,
-    RelationSearchResponse,
     SearchResponse,
     SuggestResponse,
     VectorSearchResponse,
@@ -438,7 +438,7 @@ async def merge_paragraph_results(
 
 @merge_observer.wrap({"type": "merge_relations"})
 async def merge_relations_results(
-
+    graph_responses: list[GraphSearchResponse],
     query_entry_points: Iterable[RelationNode],
     only_with_metadata: bool = False,
     only_agentic: bool = False,
@@ -448,7 +448,7 @@ async def merge_relations_results(
     return await loop.run_in_executor(
         None,
         _merge_relations_results,
-
+        graph_responses,
         query_entry_points,
         only_with_metadata,
         only_agentic,
@@ -457,7 +457,7 @@ async def merge_relations_results(
 
 
 def _merge_relations_results(
-
+    graph_responses: list[GraphSearchResponse],
     query_entry_points: Iterable[RelationNode],
     only_with_metadata: bool,
     only_agentic: bool,
@@ -480,17 +480,16 @@ def _merge_relations_results(
     for entry_point in query_entry_points:
         relations.entities[entry_point.value] = EntitySubgraph(related_to=[])
 
-    for
-        for
-            relation =
-            origin =
-            destination =
-            relation_type = RelationTypePbMap[relation.
-            relation_label = relation.
-            metadata =
-
-
-            resource_id = index_relation.resource_field_id.split("/")[0]
+    for graph_response in graph_responses:
+        for path in graph_response.graph:
+            relation = graph_response.relations[path.relation]
+            origin = graph_response.nodes[path.source]
+            destination = graph_response.nodes[path.destination]
+            relation_type = RelationTypePbMap[relation.relation_type]
+            relation_label = relation.label
+            metadata = path.metadata if path.HasField("metadata") else None
+            if path.resource_field_id is not None:
+                resource_id = path.resource_field_id.split("/")[0]
 
            # If only_with_metadata is True, we check that metadata for the relation is not None
            # If only_agentic is True, we check that metadata for the relation is not None and that it has a data_augmentation_task_id
@@ -547,13 +546,13 @@ async def merge_results(
     paragraphs = []
     documents = []
     vectors = []
-
+    graphs = []
 
     for response in search_responses:
         paragraphs.append(response.paragraph)
         documents.append(response.document)
         vectors.append(response.vector)
-
+        graphs.append(response.graph)
 
     api_results = KnowledgeboxSearchResults()
 
@@ -595,7 +594,7 @@ async def merge_results(
 
     if retrieval.query.relation is not None:
         api_results.relations = await merge_relations_results(
-
+            graphs, retrieval.query.relation.entry_points
         )
 
     api_results.resources = await fetch_resources(resources, kbid, show, field_type_filter, extracted)
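The rewritten loop resolves each path against the response's parallel nodes and relations tables. An illustrative miniature of that layout (strings stand in for the RelationNode and relation protos):

nodes = ["Anna", "New York"]   # graph_response.nodes
relations = ["LIVES_IN"]       # graph_response.relations
path = {"source": 0, "relation": 0, "destination": 1}  # one entry of graph_response.graph

origin = nodes[path["source"]]            # "Anna"
label = relations[path["relation"]]       # "LIVES_IN"
destination = nodes[path["destination"]]  # "New York"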