nucliadb 6.3.7.post4081__py3-none-any.whl → 6.3.7.post4114__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -18,17 +18,19 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
+from typing import Optional
 
 from nucliadb.common.cluster.manager import KBShardManager
 from nucliadb.common.cluster.settings import in_standalone_mode
 from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
 from nucliadb.common.maindb.driver import Driver
 from nucliadb.common.maindb.utils import setup_driver, teardown_driver
-from nucliadb.common.nidx import start_nidx_utility, stop_nidx_utility
+from nucliadb.common.nidx import NidxUtility, start_nidx_utility, stop_nidx_utility
 from nucliadb_utils.nats import NatsConnectionManager
 from nucliadb_utils.partition import PartitionUtility
 from nucliadb_utils.settings import indexing_settings
 from nucliadb_utils.storages.storage import Storage
+from nucliadb_utils.transaction import TransactionUtility
 from nucliadb_utils.utilities import (
     get_storage,
     start_nats_manager,
@@ -42,16 +44,34 @@ from nucliadb_utils.utilities import (
 
 
 class ApplicationContext:
-    kv_driver: Driver
-    shard_manager: KBShardManager
-    blob_storage: Storage
-    partitioning: PartitionUtility
-    nats_manager: NatsConnectionManager
-
-    def __init__(self, service_name: str = "service") -> None:
+    def __init__(
+        self,
+        service_name: str = "service",
+        kv_driver: bool = True,
+        blob_storage: bool = True,
+        shard_manager: bool = True,
+        partitioning: bool = True,
+        nats_manager: bool = True,
+        transaction: bool = True,
+        nidx: bool = True,
+    ) -> None:
         self.service_name = service_name
         self._initialized: bool = False
         self._lock = asyncio.Lock()
+        self._kv_driver: Optional[Driver] = None
+        self._blob_storage: Optional[Storage] = None
+        self._shard_manager: Optional[KBShardManager] = None
+        self._partitioning: Optional[PartitionUtility] = None
+        self._nats_manager: Optional[NatsConnectionManager] = None
+        self._transaction: Optional[TransactionUtility] = None
+        self._nidx: Optional[NidxUtility] = None
+        self.enabled_kv_driver = kv_driver
+        self.enabled_blob_storage = blob_storage
+        self.enabled_shard_manager = shard_manager
+        self.enabled_partitioning = partitioning
+        self.enabled_nats_manager = nats_manager
+        self.enabled_transaction = transaction
+        self.enabled_nidx = nidx
 
     async def initialize(self) -> None:
         if self._initialized:
@@ -63,30 +83,75 @@ class ApplicationContext:
             self._initialized = True
 
     async def _initialize(self):
-        self.kv_driver = await setup_driver()
-        self.blob_storage = await get_storage()
-        self.shard_manager = await setup_cluster()
-        self.partitioning = start_partitioning_utility()
-        if not in_standalone_mode():
-            self.nats_manager = await start_nats_manager(
+        if self.enabled_kv_driver:
+            self._kv_driver = await setup_driver()
+        if self.enabled_blob_storage:
+            self._blob_storage = await get_storage()
+        if self.enabled_shard_manager:
+            self._shard_manager = await setup_cluster()
+        if self.enabled_partitioning:
+            self._partitioning = start_partitioning_utility()
+        if not in_standalone_mode() and self.enabled_nats_manager:
+            self._nats_manager = await start_nats_manager(
                 self.service_name,
                 indexing_settings.index_jetstream_servers,
                 indexing_settings.index_jetstream_auth,
             )
-        self.transaction = await start_transaction_utility(self.service_name)
-        self.nidx = await start_nidx_utility()
+        if self.enabled_transaction:
+            self._transaction = await start_transaction_utility(self.service_name)
+        if self.enabled_nidx:
+            self._nidx = await start_nidx_utility()
+
+    @property
+    def kv_driver(self) -> Driver:
+        assert self._kv_driver is not None, "Driver not initialized"
+        return self._kv_driver
+
+    @property
+    def shard_manager(self) -> KBShardManager:
+        assert self._shard_manager is not None, "Shard manager not initialized"
+        return self._shard_manager
+
+    @property
+    def blob_storage(self) -> Storage:
+        assert self._blob_storage is not None, "Blob storage not initialized"
+        return self._blob_storage
+
+    @property
+    def partitioning(self) -> PartitionUtility:
+        assert self._partitioning is not None, "Partitioning not initialized"
+        return self._partitioning
+
+    @property
+    def nats_manager(self) -> NatsConnectionManager:
+        assert self._nats_manager is not None, "NATS manager not initialized"
+        return self._nats_manager
+
+    @property
+    def transaction(self) -> TransactionUtility:
+        assert self._transaction is not None, "Transaction utility not initialized"
+        return self._transaction
+
+    @property
+    def nidx(self) -> NidxUtility:
+        assert self._nidx is not None, "Nidx utility not initialized"
+        return self._nidx
 
     async def finalize(self) -> None:
         if not self._initialized:
             return
-
-        await stop_nidx_utility()
-        await stop_transaction_utility()
-        if not in_standalone_mode():
+        if self.enabled_nidx:
+            await stop_nidx_utility()
+        if self.enabled_transaction:
+            await stop_transaction_utility()
+        if not in_standalone_mode() and self.enabled_nats_manager:
             await stop_nats_manager()
-
-        stop_partitioning_utility()
-        await teardown_cluster()
-        await teardown_driver()
-        await teardown_storage()
+        if self.enabled_partitioning:
+            stop_partitioning_utility()
+        if self.enabled_shard_manager:
+            await teardown_cluster()
+        if self.enabled_blob_storage:
+            await teardown_storage()
+        if self.enabled_kv_driver:
+            await teardown_driver()
         self._initialized = False
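
The three hunks above make every `ApplicationContext` component opt-in: the constructor flags record what should be set up, `_initialize`/`finalize` only touch the enabled components, and the former class attributes become properties that assert the component was actually started. A minimal usage sketch of the new API (the service name and flag choices here are illustrative, not from the diff):

```python
# Sketch, assuming the constructor flags and property names shown above.
from nucliadb.common.context import ApplicationContext

async def run() -> None:
    # Enable only the KV driver and blob storage.
    context = ApplicationContext(
        "my-service",
        shard_manager=False,
        partitioning=False,
        nats_manager=False,
        transaction=False,
        nidx=False,
    )
    await context.initialize()
    try:
        driver = context.kv_driver  # OK: enabled and initialized
        # context.nidx  # would raise AssertionError("Nidx utility not initialized")
    finally:
        await context.finalize()
```
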
@@ -19,6 +19,7 @@
 #
 
 from contextlib import asynccontextmanager
+from typing import Optional
 
 from fastapi import FastAPI
 from starlette.routing import Mount
@@ -27,8 +28,9 @@ from nucliadb.common.context import ApplicationContext
 
 
 @asynccontextmanager
-async def inject_app_context(app: FastAPI):
-    context = ApplicationContext()
+async def inject_app_context(app: FastAPI, context: Optional[ApplicationContext] = None):
+    if context is None:
+        context = ApplicationContext()
 
     app.state.context = context
 
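
With the new optional parameter, callers can hand `inject_app_context` a context they built themselves (for example, one with unused components disabled) instead of the default fully-enabled one. A hedged sketch; the import path of `inject_app_context` is not shown in this diff:

```python
# Sketch: supply a caller-built context; inject_app_context falls back to
# ApplicationContext() when none is given, exactly as in the hunk above.
from fastapi import FastAPI
from nucliadb.common.context import ApplicationContext

async def serve(app: FastAPI) -> None:
    custom = ApplicationContext("api", nidx=False, transaction=False)
    async with inject_app_context(app, custom):  # imported from its (unshown) module
        ...  # app.state.context is now the caller-built instance
```
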
@@ -160,6 +160,8 @@ class IngestConsumer:
             logger.warning("Could not delete blob reference", exc_info=True)
 
     async def subscription_worker(self, msg: Msg):
+        context.clear_context()
+
         kbid: Optional[str] = None
         subject = msg.subject
         reply = msg.reply
@@ -182,7 +184,6 @@
             MessageProgressUpdater(msg, nats_consumer_settings.nats_ack_wait * 0.66),
             self.lock,
         ):
-            logger.info(f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}")
             try:
                 pb = await self.get_broker_message(msg)
                 if pb.source == pb.MessageSource.PROCESSOR:
@@ -194,10 +195,8 @@
                 else:
                     audit_time = ""
 
-                logger.debug(
-                    f"Received from {message_source} on {pb.kbid}/{pb.uuid} seq {seqid} partition {self.partition} at {time}"  # noqa
-                )
                 context.add_context({"kbid": pb.kbid, "rid": pb.uuid})
+                logger.info(f"Message processing: subject:{subject}, seqid: {seqid}, reply: {reply}")
                 kbid = pb.kbid
                 try:
                     source = "writer" if pb.source == pb.MessageSource.WRITER else "processor"
@@ -40,6 +40,7 @@ from nucliadb_models.configuration import FindConfig
 from nucliadb_models.filters import FilterExpression
 from nucliadb_models.resource import ExtractedDataTypeName, NucliaDBRoles
 from nucliadb_models.search import (
+    FindOptions,
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
@@ -47,7 +48,6 @@ from nucliadb_models.search import (
     Reranker,
     RerankerName,
     ResourceProperties,
-    SearchOptions,
     SearchParamDefaults,
 )
 from nucliadb_models.security import RequestSecurity
@@ -61,7 +61,7 @@ FIND_EXAMPLES = {
         description="Perform a hybrid search that will return text and semantic results matching the query",
         value={
             "query": "How can I be an effective product manager?",
-            "features": [SearchOptions.KEYWORD, SearchOptions.SEMANTIC],
+            "features": [FindOptions.KEYWORD, FindOptions.SEMANTIC],
         },
     )
 }
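
`SearchOptions` becomes `FindOptions` throughout the find endpoint; the documented example is unchanged on the wire. Assuming the enum members keep the lowercase string values the old `SearchOptions` members had, the example above serializes to this request body:

```python
# Hypothetical JSON body for the find endpoint, mirroring FIND_EXAMPLES above.
find_body = {
    "query": "How can I be an effective product manager?",
    "features": ["keyword", "semantic"],  # FindOptions.KEYWORD, FindOptions.SEMANTIC
}
```
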
@@ -110,11 +110,11 @@ async def find_knowledgebox(
     range_modification_end: Optional[DateTime] = fastapi_query(
         SearchParamDefaults.range_modification_end
     ),
-    features: list[SearchOptions] = fastapi_query(
+    features: list[FindOptions] = fastapi_query(
         SearchParamDefaults.search_features,
         default=[
-            SearchOptions.KEYWORD,
-            SearchOptions.SEMANTIC,
+            FindOptions.KEYWORD,
+            FindOptions.SEMANTIC,
         ],
     ),
     debug: bool = fastapi_query(SearchParamDefaults.debug),
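
The GET variant takes the same features as repeated query parameters, defaulting to keyword plus semantic. A hedged client-side example; the host, port, and path prefix are assumptions, not taken from the diff:

```python
# Hypothetical call against a local nucliadb; repeated "features" params map to
# the list[FindOptions] parameter of find_knowledgebox above.
import httpx

async def find(kbid: str, query: str) -> httpx.Response:
    async with httpx.AsyncClient(base_url="http://localhost:8080/api/v1") as client:
        return await client.get(
            f"/kb/{kbid}/find",
            params=[("query", query), ("features", "keyword"), ("features", "semantic")],
        )
```
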
@@ -37,11 +37,9 @@ from nucliadb.search.search import cache
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_results
 from nucliadb.search.search.query_parser.parsers.search import parse_search
-from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto
 from nucliadb.search.search.utils import (
-    min_score_from_payload,
     min_score_from_query_params,
-    should_disable_vector_search,
 )
 from nucliadb_models.common import FieldTypeName
 from nucliadb_models.filters import FilterExpression
@@ -263,14 +261,8 @@ async def search(
     audit = get_audit()
     start_time = time()
 
-    item.min_score = min_score_from_payload(item.min_score)
-
-    if SearchOptions.SEMANTIC in item.features:
-        if should_disable_vector_search(item):
-            item.features.remove(SearchOptions.SEMANTIC)
-
     parsed = await parse_search(kbid, item)
-    pb_query, incomplete_results, autofilters, _ = await convert_retrieval_to_proto(parsed)
+    pb_query, incomplete_results, autofilters, _ = await legacy_convert_retrieval_to_proto(parsed)
 
     # We need to query all nodes
     results, query_incomplete_results, queried_nodes = await node_query(kbid, Method.SEARCH, pb_query)
@@ -80,6 +80,7 @@ from nucliadb_models.search import (
     CitationsAskResponseItem,
     DebugAskResponseItem,
     ErrorAskResponseItem,
+    FindOptions,
     FindParagraph,
     FindRequest,
     GraphStrategy,
@@ -97,7 +98,6 @@ from nucliadb_models.search import (
     Relations,
     RelationsAskResponseItem,
     RetrievalAskResponseItem,
-    SearchOptions,
     StatusAskResponseItem,
     SyncAskMetadata,
     SyncAskResponse,
@@ -755,6 +755,9 @@ async def retrieval_in_kb(
     )
 
     if graph_strategy is not None:
+        assert parsed_query.retrieval.reranker is not None, (
+            "find parser must provide a reranking algorithm"
+        )
         reranker = get_reranker(parsed_query.retrieval.reranker)
         graph_results, graph_request = await get_graph_results(
             kbid=kbid,
@@ -952,9 +955,9 @@ def calculate_prequeries_for_json_schema(
     json_schema = ask_request.answer_json_schema or {}
     features = []
     if ChatOptions.SEMANTIC in ask_request.features:
-        features.append(SearchOptions.SEMANTIC)
+        features.append(FindOptions.SEMANTIC)
     if ChatOptions.KEYWORD in ask_request.features:
-        features.append(SearchOptions.KEYWORD)
+        features.append(FindOptions.KEYWORD)
 
     properties = json_schema.get("parameters", {}).get("properties", {})
     if len(properties) == 0:  # pragma: no cover
@@ -29,7 +29,8 @@ from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find
 from nucliadb.search.search.merge import merge_relations_results
 from nucliadb.search.search.metrics import RAGMetrics
-from nucliadb.search.search.query_parser.models import ParsedQuery
+from nucliadb.search.search.query_parser.models import ParsedQuery, Query, RelationQuery, UnitRetrieval
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
 from nucliadb.search.settings import settings
 from nucliadb.search.utilities import get_predict
 from nucliadb_models import filters
@@ -37,6 +38,7 @@ from nucliadb_models.search import (
     AskRequest,
     ChatContextMessage,
     ChatOptions,
+    FindOptions,
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
@@ -47,14 +49,11 @@ from nucliadb_models.search import (
     PromptContextOrder,
     Relations,
     RephraseModel,
-    SearchOptions,
     parse_rephrase_prompt,
 )
 from nucliadb_protos import audit_pb2
 from nucliadb_protos.nodereader_pb2 import (
-    EntitiesSubgraphRequest,
-    RelationSearchResponse,
-    SearchRequest,
+    GraphSearchResponse,
     SearchResponse,
 )
 from nucliadb_protos.utils_pb2 import RelationNode
@@ -181,11 +180,11 @@ def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
     find_request.resource_filters = item.resource_filters
     find_request.features = []
     if ChatOptions.SEMANTIC in item.features:
-        find_request.features.append(SearchOptions.SEMANTIC)
+        find_request.features.append(FindOptions.SEMANTIC)
     if ChatOptions.KEYWORD in item.features:
-        find_request.features.append(SearchOptions.KEYWORD)
+        find_request.features.append(FindOptions.KEYWORD)
     if ChatOptions.RELATIONS in item.features:
-        find_request.features.append(SearchOptions.RELATIONS)
+        find_request.features.append(FindOptions.RELATIONS)
     find_request.query = query
     find_request.fields = item.fields
     find_request.filters = item.filters
@@ -274,13 +273,18 @@ async def get_relations_results_from_entities(
     only_entity_to_entity: bool = False,
     deleted_entities: set[str] = set(),
 ) -> Relations:
-    request = SearchRequest()
-    request.relation_subgraph.entry_points.extend(entities)
-    request.relation_subgraph.depth = 1
-
-    deleted = EntitiesSubgraphRequest.DeletedEntities()
-    deleted.node_values.extend(deleted_entities)
-    request.relation_subgraph.deleted_entities.append(deleted)
+    entry_points = list(entities)
+    retrieval = UnitRetrieval(
+        query=Query(
+            relation=RelationQuery(
+                entry_points=entry_points,
+                deleted_entities={"": list(deleted_entities)},
+                deleted_entity_groups=[],
+            )
+        ),
+        top_k=50,
+    )
+    request = convert_retrieval_to_proto(retrieval)
 
     results: list[SearchResponse]
     (
@@ -293,10 +297,10 @@
         request,
         timeout=timeout,
     )
-    relations_results: list[RelationSearchResponse] = [result.relation for result in results]
+    relations_results: list[GraphSearchResponse] = [result.graph for result in results]
     return await merge_relations_results(
         relations_results,
-        request.relation_subgraph.entry_points,
+        entry_points,
         only_with_metadata,
         only_agentic_relations,
         only_entity_to_entity,
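
Two converters now coexist in `unit_retrieval`: endpoints that still need the `(pb_query, incomplete_results, autofilters, rephrased_query)` tuple await `legacy_convert_retrieval_to_proto(parsed)`, while `get_relations_results_from_entities` builds a `UnitRetrieval` and converts it with `convert_retrieval_to_proto`, which is not awaited at this call site. A sketch of the new construction path, reusing only names that appear in the hunks above; treat the exact signatures as inferred from these call sites:

```python
from nucliadb.search.search.query_parser.models import Query, RelationQuery, UnitRetrieval
from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto

def build_relation_request(entities, deleted_entities=()):
    # Mirrors the body of get_relations_results_from_entities above
    # (top_k=50 is taken from the diff).
    retrieval = UnitRetrieval(
        query=Query(
            relation=RelationQuery(
                entry_points=list(entities),
                deleted_entities={"": list(deleted_entities)},
                deleted_entity_groups=[],
            )
        ),
        top_k=50,
    )
    return convert_retrieval_to_proto(retrieval)  # synchronous, unlike the legacy variant
```
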
@@ -38,7 +38,7 @@ from nucliadb.search.search.metrics import (
 )
 from nucliadb.search.search.query_parser.models import ParsedQuery
 from nucliadb.search.search.query_parser.parsers import parse_find
-from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import legacy_convert_retrieval_to_proto
 from nucliadb.search.search.rank_fusion import (
     get_rank_fusion,
 )
@@ -92,11 +92,17 @@
 
     with metrics.time("query_parse"):
         parsed = await parse_find(kbid, item)
+        assert parsed.retrieval.rank_fusion is not None and parsed.retrieval.reranker is not None, (
+            "find parser must provide rank fusion and reranker algorithms"
+        )
         rank_fusion = get_rank_fusion(parsed.retrieval.rank_fusion)
         reranker = get_reranker(parsed.retrieval.reranker)
-        pb_query, incomplete_results, autofilters, rephrased_query = await convert_retrieval_to_proto(
-            parsed
-        )
+        (
+            pb_query,
+            incomplete_results,
+            autofilters,
+            rephrased_query,
+        ) = await legacy_convert_retrieval_to_proto(parsed)
 
     with metrics.time("node_query"):
         results, query_incomplete_results, queried_nodes = await node_query(
@@ -181,8 +187,11 @@
     """
     # Parse query
    parsed = await parse_find(kbid, item)
+    assert parsed.retrieval.reranker is not None, "find parser must provide a reranking algorithm"
     reranker = get_reranker(parsed.retrieval.reranker)
-    search_request, incomplete_results, _, rephrased_query = await convert_retrieval_to_proto(parsed)
+    search_request, incomplete_results, _, rephrased_query = await legacy_convert_retrieval_to_proto(
+        parsed
+    )
 
     # Query index
     query_results = await external_index_manager.query(search_request)  # noqa
@@ -52,9 +52,9 @@ from nucliadb_models.search import (
 )
 from nucliadb_protos.nodereader_pb2 import (
     DocumentScored,
+    GraphSearchResponse,
     ParagraphResult,
     ParagraphSearchResponse,
-    RelationSearchResponse,
     SearchResponse,
     VectorSearchResponse,
 )
@@ -142,8 +142,8 @@
     # build relations graph
     entry_points = []
     if retrieval.query.relation is not None:
-        entry_points = retrieval.query.relation.detected_entities
-    relations = await merge_relations_results([search_response.relation], entry_points)
+        entry_points = retrieval.query.relation.entry_points
+    relations = await merge_relations_results([search_response.graph], entry_points)
 
     # compose response
     find_resources = compose_find_resources(text_blocks, resources)
@@ -178,16 +178,16 @@
     """
     paragraphs = []
     vectors = []
-    relations = []
+    graphs = []
     for response in responses:
         paragraphs.append(response.paragraph)
         vectors.append(response.vector)
-        relations.append(response.relation)
+        graphs.append(response.graph)
 
     merged = SearchResponse(
         paragraph=merge_shards_keyword_responses(paragraphs),
         vector=merge_shards_semantic_responses(vectors),
-        relation=merge_shards_relation_responses(relations),
+        graph=merge_shards_graph_responses(graphs),
     )
     return merged
 
@@ -230,13 +230,27 @@
     return merged
 
 
-def merge_shards_relation_responses(
-    relation_responses: list[RelationSearchResponse],
-) -> RelationSearchResponse:
-    merged = RelationSearchResponse()
-    for response in relation_responses:
-        merged.prefix.nodes.extend(response.prefix.nodes)
-        merged.subgraph.relations.extend(response.subgraph.relations)
+def merge_shards_graph_responses(
+    graph_responses: list[GraphSearchResponse],
+):
+    merged = GraphSearchResponse()
+
+    for response in graph_responses:
+        nodes_offset = len(merged.nodes)
+        relations_offset = len(merged.relations)
+
+        # paths contain indexes to nodes and relations, we must offset them
+        # while merging responses to maintain valid data
+        for path in response.graph:
+            merged_path = GraphSearchResponse.Path()
+            merged_path.CopyFrom(path)
+            merged_path.source += nodes_offset
+            merged_path.relation += relations_offset
+            merged_path.destination += nodes_offset
+            merged.graph.append(merged_path)
+
+        merged.nodes.extend(response.nodes)
+        merged.relations.extend(response.relations)
 
     return merged
 
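
The rewritten merge is longer than the old one because `GraphSearchResponse` stores each path as indexes into that response's `nodes` and `relations` arrays; when concatenating shard responses, every index in a later response must be shifted by the totals accumulated so far. A toy worked example of the same offset logic, with plain tuples standing in for the protos:

```python
# Each path is (source_node_idx, relation_idx, destination_node_idx).
resp_a = {"nodes": ["A", "B"], "relations": ["knows"], "paths": [(0, 0, 1)]}
resp_b = {"nodes": ["C", "D"], "relations": ["likes"], "paths": [(0, 0, 1)]}

merged = {"nodes": [], "relations": [], "paths": []}
for resp in (resp_a, resp_b):
    nodes_offset = len(merged["nodes"])
    relations_offset = len(merged["relations"])
    for src, rel, dst in resp["paths"]:
        merged["paths"].append((src + nodes_offset, rel + relations_offset, dst + nodes_offset))
    merged["nodes"].extend(resp["nodes"])
    merged["relations"].extend(resp["relations"])

# resp_b's path now resolves against the merged arrays: C --likes--> D
assert merged["paths"] == [(0, 0, 1), (2, 1, 3)]
```
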
@@ -65,9 +65,9 @@ from nucliadb_protos.nodereader_pb2 import (
     DocumentResult,
     DocumentScored,
     DocumentSearchResponse,
+    GraphSearchResponse,
     ParagraphResult,
     ParagraphSearchResponse,
-    RelationSearchResponse,
     SearchResponse,
     SuggestResponse,
     VectorSearchResponse,
@@ -438,7 +438,7 @@
 
 @merge_observer.wrap({"type": "merge_relations"})
 async def merge_relations_results(
-    relations_responses: list[RelationSearchResponse],
+    graph_responses: list[GraphSearchResponse],
     query_entry_points: Iterable[RelationNode],
     only_with_metadata: bool = False,
     only_agentic: bool = False,
@@ -448,7 +448,7 @@
     return await loop.run_in_executor(
         None,
         _merge_relations_results,
-        relations_responses,
+        graph_responses,
         query_entry_points,
         only_with_metadata,
         only_agentic,
@@ -457,7 +457,7 @@
 
 
 def _merge_relations_results(
-    relations_responses: list[RelationSearchResponse],
+    graph_responses: list[GraphSearchResponse],
     query_entry_points: Iterable[RelationNode],
     only_with_metadata: bool,
     only_agentic: bool,
@@ -480,17 +480,16 @@
     for entry_point in query_entry_points:
         relations.entities[entry_point.value] = EntitySubgraph(related_to=[])
 
-    for relation_response in relations_responses:
-        for index_relation in relation_response.subgraph.relations:
-            relation = index_relation.relation
-            origin = relation.source
-            destination = relation.to
-            relation_type = RelationTypePbMap[relation.relation]  # type: ignore
-            relation_label = relation.relation_label
-            metadata = relation.metadata if relation.HasField("metadata") else None
-
-            if index_relation.resource_field_id is not None:
-                resource_id = index_relation.resource_field_id.split("/")[0]
+    for graph_response in graph_responses:
+        for path in graph_response.graph:
+            relation = graph_response.relations[path.relation]
+            origin = graph_response.nodes[path.source]
+            destination = graph_response.nodes[path.destination]
+            relation_type = RelationTypePbMap[relation.relation_type]
+            relation_label = relation.label
+            metadata = path.metadata if path.HasField("metadata") else None
+            if path.resource_field_id is not None:
+                resource_id = path.resource_field_id.split("/")[0]
 
             # If only_with_metadata is True, we check that metadata for the relation is not None
             # If only_agentic is True, we check that metadata for the relation is not None and that it has a data_augmentation_task_id
@@ -547,13 +546,13 @@
     paragraphs = []
     documents = []
     vectors = []
-    relations = []
+    graphs = []
 
     for response in search_responses:
         paragraphs.append(response.paragraph)
         documents.append(response.document)
         vectors.append(response.vector)
-        relations.append(response.relation)
+        graphs.append(response.graph)
 
     api_results = KnowledgeboxSearchResults()
 
@@ -595,7 +594,7 @@
 
     if retrieval.query.relation is not None:
         api_results.relations = await merge_relations_results(
-            relations, retrieval.query.relation.detected_entities
+            graphs, retrieval.query.relation.entry_points
        )
 
     api_results.resources = await fetch_resources(resources, kbid, show, field_type_filter, extracted)