nucliadb 6.2.1.post2835__py3-none-any.whl → 6.2.1.post2842__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ from nucliadb.common.counters import IndexCounts
28
28
  from nucliadb.common.external_index_providers.exceptions import ExternalIndexingError
29
29
  from nucliadb.common.ids import ParagraphId
30
30
  from nucliadb_models.external_index_providers import ExternalIndexProviderType
31
- from nucliadb_models.search import SCORE_TYPE, TextPosition
31
+ from nucliadb_models.search import SCORE_TYPE, Relations, TextPosition
32
32
  from nucliadb_protos.knowledgebox_pb2 import (
33
33
  CreateExternalIndexProviderMetadata,
34
34
  StoredExternalIndexProviderMetadata,
@@ -73,6 +73,7 @@ class TextBlockMatch(BaseModel):
73
73
  paragraph_labels: list[str] = []
74
74
  field_labels: list[str] = []
75
75
  text: Optional[str] = None
76
+ relevant_relations: Optional[Relations] = None
76
77
 
77
78
 
78
79
  class QueryResults(BaseModel):
nucliadb/common/ids.py CHANGED
@@ -111,13 +111,11 @@ class FieldId:
111
111
  parts = value.split("/")
112
112
  if len(parts) == 3:
113
113
  rid, _type, key = parts
114
- if _type not in FIELD_TYPE_STR_TO_PB:
115
- raise ValueError(f"Invalid FieldId: {value}")
114
+ _type = cls.parse_field_type(_type)
116
115
  return cls(rid=rid, type=_type, key=key)
117
116
  elif len(parts) == 4:
118
117
  rid, _type, key, subfield_id = parts
119
- if _type not in FIELD_TYPE_STR_TO_PB:
120
- raise ValueError(f"Invalid FieldId: {value}")
118
+ _type = cls.parse_field_type(_type)
121
119
  return cls(
122
120
  rid=rid,
123
121
  type=_type,
@@ -127,6 +125,22 @@ class FieldId:
127
125
  else:
128
126
  raise ValueError(f"Invalid FieldId: {value}")
129
127
 
128
+ @classmethod
129
+ def parse_field_type(cls, _type: str) -> str:
130
+ if _type not in FIELD_TYPE_STR_TO_PB:
131
+ # Try to parse the enum value
132
+ # XXX: This is to support field types that are integer values of FieldType
133
+ # Which is how legacy processor relations reported the paragraph_id
134
+ try:
135
+ type_pb = FieldType.ValueType(int(_type))
136
+ except ValueError:
137
+ raise ValueError(f"Invalid FieldId: {_type}")
138
+ if type_pb in FIELD_TYPE_PB_TO_STR:
139
+ return FIELD_TYPE_PB_TO_STR[type_pb]
140
+ else:
141
+ raise ValueError(f"Invalid FieldId: {_type}")
142
+ return _type
143
+
130
144
 
131
145
  @dataclass
132
146
  class ParagraphId:
@@ -151,8 +151,6 @@ async def suggest(
151
151
  search_results = await merge_suggest_results(
152
152
  results,
153
153
  kbid=kbid,
154
- show=show,
155
- field_type_filter=field_type_filter,
156
154
  highlight=highlight,
157
155
  )
158
156
 
@@ -57,6 +57,7 @@ from nucliadb.search.search.exceptions import (
57
57
  IncompleteFindResultsError,
58
58
  InvalidQueryError,
59
59
  )
60
+ from nucliadb.search.search.graph_strategy import get_graph_results
60
61
  from nucliadb.search.search.metrics import RAGMetrics
61
62
  from nucliadb.search.search.query import QueryParser
62
63
  from nucliadb.search.utilities import get_predict
@@ -75,6 +76,7 @@ from nucliadb_models.search import (
75
76
  ErrorAskResponseItem,
76
77
  FindParagraph,
77
78
  FindRequest,
79
+ GraphStrategy,
78
80
  JSONAskResponseItem,
79
81
  KnowledgeboxFindResults,
80
82
  MetadataAskResponseItem,
@@ -629,6 +631,13 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
629
631
  return None
630
632
 
631
633
 
634
+ def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]:
635
+ for rag_strategy in ask_request.rag_strategies:
636
+ if rag_strategy.name == RagStrategyName.GRAPH:
637
+ return cast(GraphStrategy, rag_strategy)
638
+ return None
639
+
640
+
632
641
  async def retrieval_step(
633
642
  kbid: str,
634
643
  main_query: str,
@@ -675,17 +684,33 @@ async def retrieval_in_kb(
675
684
  metrics: RAGMetrics,
676
685
  ) -> RetrievalResults:
677
686
  prequeries = parse_prequeries(ask_request)
687
+ graph_strategy = parse_graph_strategy(ask_request)
678
688
  with metrics.time("retrieval"):
679
- main_results, prequeries_results, query_parser = await get_find_results(
680
- kbid=kbid,
681
- query=main_query,
682
- item=ask_request,
683
- ndb_client=client_type,
684
- user=user_id,
685
- origin=origin,
686
- metrics=metrics,
687
- prequeries_strategy=prequeries,
688
- )
689
+ prequeries_results = None
690
+ if graph_strategy is not None:
691
+ main_results, query_parser = await get_graph_results(
692
+ kbid=kbid,
693
+ query=main_query,
694
+ item=ask_request,
695
+ ndb_client=client_type,
696
+ user=user_id,
697
+ origin=origin,
698
+ graph_strategy=graph_strategy,
699
+ metrics=metrics,
700
+ shards=ask_request.shards,
701
+ )
702
+ # TODO (oni): Fallback to normal retrieval if no graph results are found
703
+ else:
704
+ main_results, prequeries_results, query_parser = await get_find_results(
705
+ kbid=kbid,
706
+ query=main_query,
707
+ item=ask_request,
708
+ ndb_client=client_type,
709
+ user=user_id,
710
+ origin=origin,
711
+ metrics=metrics,
712
+ prequeries_strategy=prequeries,
713
+ )
689
714
  if len(main_results.resources) == 0 and all(
690
715
  len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
691
716
  ):
@@ -1013,8 +1013,10 @@ class PromptContextBuilder:
1013
1013
  neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
1014
1014
  elif strategy.name == RagStrategyName.METADATA_EXTENSION:
1015
1015
  metadata_extension = cast(MetadataExtensionStrategy, strategy)
1016
- elif strategy.name != RagStrategyName.PREQUERIES: # pragma: no cover
1017
- # Prequeries are not handled here
1016
+ elif (
1017
+ strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
1018
+ ): # pragma: no cover
1019
+ # Prequeries and graph are not handled here
1018
1020
  logger.warning(
1019
1021
  "Unknown rag strategy",
1020
1022
  extra={"strategy": strategy.name, "kbid": self.kbid},
@@ -18,7 +18,7 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from typing import Optional
21
+ from typing import Iterable, Optional
22
22
 
23
23
  from nucliadb.common.models_utils import to_proto
24
24
  from nucliadb.search import logger
@@ -51,6 +51,7 @@ from nucliadb_models.search import (
51
51
  )
52
52
  from nucliadb_protos import audit_pb2
53
53
  from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
54
+ from nucliadb_protos.utils_pb2 import RelationNode
54
55
  from nucliadb_telemetry.errors import capture_exception
55
56
  from nucliadb_utils.utilities import get_audit
56
57
 
@@ -145,15 +146,7 @@ async def get_find_results(
145
146
  return main_results, prequeries_results, query_parser
146
147
 
147
148
 
148
- async def run_main_query(
149
- kbid: str,
150
- query: str,
151
- item: AskRequest,
152
- ndb_client: NucliaDBClientType,
153
- user: str,
154
- origin: str,
155
- metrics: RAGMetrics = RAGMetrics(),
156
- ) -> tuple[KnowledgeboxFindResults, QueryParser]:
149
+ def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
157
150
  find_request = FindRequest()
158
151
  find_request.resource_filters = item.resource_filters
159
152
  find_request.features = []
@@ -189,7 +182,19 @@ async def run_main_query(
189
182
  find_request.show_hidden = item.show_hidden
190
183
 
191
184
  # this executes the model validators, that can tweak some fields
192
- FindRequest.model_validate(find_request)
185
+ return FindRequest.model_validate(find_request)
186
+
187
+
188
+ async def run_main_query(
189
+ kbid: str,
190
+ query: str,
191
+ item: AskRequest,
192
+ ndb_client: NucliaDBClientType,
193
+ user: str,
194
+ origin: str,
195
+ metrics: RAGMetrics = RAGMetrics(),
196
+ ) -> tuple[KnowledgeboxFindResults, QueryParser]:
197
+ find_request = find_request_from_ask_request(item, query)
193
198
 
194
199
  find_results, incomplete, query_parser = await find(
195
200
  kbid,
@@ -211,36 +216,59 @@ async def get_relations_results(
211
216
  text_answer: str,
212
217
  target_shard_replicas: Optional[list[str]],
213
218
  timeout: Optional[float] = None,
219
+ only_with_metadata: bool = False,
220
+ only_agentic_relations: bool = False,
214
221
  ) -> Relations:
215
222
  try:
216
223
  predict = get_predict()
217
224
  detected_entities = await predict.detect_entities(kbid, text_answer)
218
- request = SearchRequest()
219
- request.relation_subgraph.entry_points.extend(detected_entities)
220
- request.relation_subgraph.depth = 1
221
-
222
- results: list[SearchResponse]
223
- (
224
- results,
225
- _,
226
- _,
227
- ) = await node_query(
228
- kbid,
229
- Method.SEARCH,
230
- request,
225
+
226
+ return await get_relations_results_from_entities(
227
+ kbid=kbid,
228
+ entities=detected_entities,
231
229
  target_shard_replicas=target_shard_replicas,
232
230
  timeout=timeout,
233
- use_read_replica_nodes=True,
234
- retry_on_primary=False,
231
+ only_with_metadata=only_with_metadata,
232
+ only_agentic_relations=only_agentic_relations,
235
233
  )
236
- relations_results: list[RelationSearchResponse] = [result.relation for result in results]
237
- return await merge_relations_results(relations_results, request.relation_subgraph)
238
234
  except Exception as exc:
239
235
  capture_exception(exc)
240
236
  logger.exception("Error getting relations results")
241
237
  return Relations(entities={})
242
238
 
243
239
 
240
+ async def get_relations_results_from_entities(
241
+ *,
242
+ kbid: str,
243
+ entities: Iterable[RelationNode],
244
+ target_shard_replicas: Optional[list[str]],
245
+ timeout: Optional[float] = None,
246
+ only_with_metadata: bool = False,
247
+ only_agentic_relations: bool = False,
248
+ ) -> Relations:
249
+ request = SearchRequest()
250
+ request.relation_subgraph.entry_points.extend(entities)
251
+ request.relation_subgraph.depth = 1
252
+ results: list[SearchResponse]
253
+ (
254
+ results,
255
+ _,
256
+ _,
257
+ ) = await node_query(
258
+ kbid,
259
+ Method.SEARCH,
260
+ request,
261
+ target_shard_replicas=target_shard_replicas,
262
+ timeout=timeout,
263
+ use_read_replica_nodes=True,
264
+ retry_on_primary=False,
265
+ )
266
+ relations_results: list[RelationSearchResponse] = [result.relation for result in results]
267
+ return await merge_relations_results(
268
+ relations_results, request.relation_subgraph, only_with_metadata, only_agentic_relations
269
+ )
270
+
271
+
244
272
  def maybe_audit_chat(
245
273
  *,
246
274
  kbid: str,