nucliadb 6.2.0.post2679__py3-none-any.whl → 6.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. migrations/0028_extracted_vectors_reference.py +61 -0
  2. migrations/0029_backfill_field_status.py +149 -0
  3. migrations/0030_label_deduplication.py +60 -0
  4. nucliadb/common/cluster/manager.py +41 -331
  5. nucliadb/common/cluster/rebalance.py +2 -2
  6. nucliadb/common/cluster/rollover.py +12 -71
  7. nucliadb/common/cluster/settings.py +3 -0
  8. nucliadb/common/cluster/standalone/utils.py +0 -43
  9. nucliadb/common/cluster/utils.py +0 -16
  10. nucliadb/common/counters.py +1 -0
  11. nucliadb/common/datamanagers/fields.py +48 -7
  12. nucliadb/common/datamanagers/vectorsets.py +11 -2
  13. nucliadb/common/external_index_providers/base.py +2 -1
  14. nucliadb/common/external_index_providers/pinecone.py +3 -5
  15. nucliadb/common/ids.py +18 -4
  16. nucliadb/common/models_utils/from_proto.py +479 -0
  17. nucliadb/common/models_utils/to_proto.py +60 -0
  18. nucliadb/common/nidx.py +76 -37
  19. nucliadb/export_import/models.py +3 -3
  20. nucliadb/health.py +0 -7
  21. nucliadb/ingest/app.py +0 -8
  22. nucliadb/ingest/consumer/auditing.py +1 -1
  23. nucliadb/ingest/consumer/shard_creator.py +1 -1
  24. nucliadb/ingest/fields/base.py +83 -21
  25. nucliadb/ingest/orm/brain.py +55 -56
  26. nucliadb/ingest/orm/broker_message.py +12 -2
  27. nucliadb/ingest/orm/entities.py +6 -17
  28. nucliadb/ingest/orm/knowledgebox.py +44 -22
  29. nucliadb/ingest/orm/processor/data_augmentation.py +7 -29
  30. nucliadb/ingest/orm/processor/processor.py +5 -2
  31. nucliadb/ingest/orm/resource.py +222 -413
  32. nucliadb/ingest/processing.py +8 -2
  33. nucliadb/ingest/serialize.py +77 -46
  34. nucliadb/ingest/service/writer.py +2 -56
  35. nucliadb/ingest/settings.py +1 -4
  36. nucliadb/learning_proxy.py +6 -4
  37. nucliadb/purge/__init__.py +102 -12
  38. nucliadb/purge/orphan_shards.py +6 -4
  39. nucliadb/reader/api/models.py +3 -3
  40. nucliadb/reader/api/v1/__init__.py +1 -0
  41. nucliadb/reader/api/v1/download.py +2 -2
  42. nucliadb/reader/api/v1/knowledgebox.py +3 -3
  43. nucliadb/reader/api/v1/resource.py +23 -12
  44. nucliadb/reader/api/v1/services.py +4 -4
  45. nucliadb/reader/api/v1/vectorsets.py +48 -0
  46. nucliadb/search/api/v1/ask.py +11 -1
  47. nucliadb/search/api/v1/feedback.py +3 -3
  48. nucliadb/search/api/v1/knowledgebox.py +8 -13
  49. nucliadb/search/api/v1/search.py +3 -2
  50. nucliadb/search/api/v1/suggest.py +0 -2
  51. nucliadb/search/predict.py +6 -4
  52. nucliadb/search/requesters/utils.py +1 -2
  53. nucliadb/search/search/chat/ask.py +77 -13
  54. nucliadb/search/search/chat/prompt.py +16 -5
  55. nucliadb/search/search/chat/query.py +74 -34
  56. nucliadb/search/search/exceptions.py +2 -7
  57. nucliadb/search/search/find.py +9 -5
  58. nucliadb/search/search/find_merge.py +10 -4
  59. nucliadb/search/search/graph_strategy.py +884 -0
  60. nucliadb/search/search/hydrator.py +6 -0
  61. nucliadb/search/search/merge.py +79 -24
  62. nucliadb/search/search/query.py +74 -245
  63. nucliadb/search/search/query_parser/exceptions.py +11 -1
  64. nucliadb/search/search/query_parser/fetcher.py +405 -0
  65. nucliadb/search/search/query_parser/models.py +0 -3
  66. nucliadb/search/search/query_parser/parser.py +22 -21
  67. nucliadb/search/search/rerankers.py +1 -42
  68. nucliadb/search/search/shards.py +19 -0
  69. nucliadb/standalone/api_router.py +2 -14
  70. nucliadb/standalone/settings.py +4 -0
  71. nucliadb/train/generators/field_streaming.py +7 -3
  72. nucliadb/train/lifecycle.py +3 -6
  73. nucliadb/train/nodes.py +14 -12
  74. nucliadb/train/resource.py +380 -0
  75. nucliadb/writer/api/constants.py +20 -16
  76. nucliadb/writer/api/v1/__init__.py +1 -0
  77. nucliadb/writer/api/v1/export_import.py +1 -1
  78. nucliadb/writer/api/v1/field.py +13 -7
  79. nucliadb/writer/api/v1/knowledgebox.py +3 -46
  80. nucliadb/writer/api/v1/resource.py +20 -13
  81. nucliadb/writer/api/v1/services.py +10 -1
  82. nucliadb/writer/api/v1/upload.py +61 -34
  83. nucliadb/writer/{vectorsets.py → api/v1/vectorsets.py} +99 -47
  84. nucliadb/writer/back_pressure.py +17 -46
  85. nucliadb/writer/resource/basic.py +9 -7
  86. nucliadb/writer/resource/field.py +42 -9
  87. nucliadb/writer/settings.py +2 -2
  88. nucliadb/writer/tus/gcs.py +11 -10
  89. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/METADATA +11 -14
  90. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/RECORD +94 -96
  91. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/WHEEL +1 -1
  92. nucliadb/common/cluster/discovery/base.py +0 -178
  93. nucliadb/common/cluster/discovery/k8s.py +0 -301
  94. nucliadb/common/cluster/discovery/manual.py +0 -57
  95. nucliadb/common/cluster/discovery/single.py +0 -51
  96. nucliadb/common/cluster/discovery/types.py +0 -32
  97. nucliadb/common/cluster/discovery/utils.py +0 -67
  98. nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
  99. nucliadb/common/cluster/standalone/index_node.py +0 -123
  100. nucliadb/common/cluster/standalone/service.py +0 -84
  101. nucliadb/standalone/introspect.py +0 -208
  102. nucliadb-6.2.0.post2679.dist-info/zip-safe +0 -1
  103. /nucliadb/common/{cluster/discovery → models_utils}/__init__.py +0 -0
  104. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/entry_points.txt +0 -0
  105. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/top_level.txt +0 -0
@@ -49,6 +49,7 @@ from nucliadb.search.search.chat.query import (
49
49
  ChatAuditor,
50
50
  get_find_results,
51
51
  get_relations_results,
52
+ maybe_audit_chat,
52
53
  rephrase_query,
53
54
  sorted_prompt_context_list,
54
55
  tokens_to_chars,
@@ -57,6 +58,7 @@ from nucliadb.search.search.exceptions import (
57
58
  IncompleteFindResultsError,
58
59
  InvalidQueryError,
59
60
  )
61
+ from nucliadb.search.search.graph_strategy import get_graph_results
60
62
  from nucliadb.search.search.metrics import RAGMetrics
61
63
  from nucliadb.search.search.query import QueryParser
62
64
  from nucliadb.search.utilities import get_predict
@@ -75,6 +77,7 @@ from nucliadb_models.search import (
75
77
  ErrorAskResponseItem,
76
78
  FindParagraph,
77
79
  FindRequest,
80
+ GraphStrategy,
78
81
  JSONAskResponseItem,
79
82
  KnowledgeboxFindResults,
80
83
  MetadataAskResponseItem,
@@ -126,7 +129,7 @@ class AskResult:
126
129
  main_results: KnowledgeboxFindResults,
127
130
  prequeries_results: Optional[list[PreQueryResult]],
128
131
  nuclia_learning_id: Optional[str],
129
- predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
132
+ predict_answer_stream: Optional[AsyncGenerator[GenerativeChunk, None]],
130
133
  prompt_context: PromptContext,
131
134
  prompt_context_order: PromptContextOrder,
132
135
  auditor: ChatAuditor,
@@ -393,6 +396,9 @@ class AskResult:
393
396
  This method does not assume any order in the stream of items, but it assumes that at least
394
397
  the answer text is streamed in order.
395
398
  """
399
+ if self.predict_answer_stream is None:
400
+ # In some cases, clients may want to skip the answer generation step
401
+ return
396
402
  async for generative_chunk in self.predict_answer_stream:
397
403
  item = generative_chunk.chunk
398
404
  if isinstance(item, TextGenerativeResponse):
@@ -431,14 +437,14 @@ class NotEnoughContextAskResult(AskResult):
431
437
  """
432
438
  yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
433
439
  yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
434
- status = AnswerStatusCode.NO_CONTEXT
440
+ status = AnswerStatusCode.NO_RETRIEVAL_DATA
435
441
  yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))
436
442
 
437
443
  async def json(self) -> str:
438
444
  return SyncAskResponse(
439
445
  answer=NOT_ENOUGH_CONTEXT_ANSWER,
440
446
  retrieval_results=self.main_results,
441
- status=AnswerStatusCode.NO_CONTEXT,
447
+ status=AnswerStatusCode.NO_RETRIEVAL_DATA.prettify(),
442
448
  ).model_dump_json()
443
449
 
444
450
 
@@ -485,6 +491,31 @@ async def ask(
485
491
  resource=resource,
486
492
  )
487
493
  except NoRetrievalResultsError as err:
494
+ try:
495
+ rephrase_time = metrics.elapsed("rephrase")
496
+ except KeyError:
497
+ # Not all ask requests have a rephrase step
498
+ rephrase_time = None
499
+
500
+ maybe_audit_chat(
501
+ kbid=kbid,
502
+ user_id=user_id,
503
+ client_type=client_type,
504
+ origin=origin,
505
+ generative_answer_time=0,
506
+ generative_answer_first_chunk_time=0,
507
+ rephrase_time=rephrase_time,
508
+ user_query=user_query,
509
+ rephrased_query=rephrased_query,
510
+ text_answer=b"",
511
+ status_code=AnswerStatusCode.NO_RETRIEVAL_DATA,
512
+ chat_history=chat_history,
513
+ query_context={},
514
+ query_context_order={},
515
+ learning_id=None,
516
+ model=ask_request.generative_model,
517
+ )
518
+
488
519
  # If a retrieval was attempted but no results were found,
489
520
  # early return the ask endpoint without querying the generative model
490
521
  return NotEnoughContextAskResult(
@@ -503,6 +534,7 @@ async def ask(
503
534
  ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
504
535
  resource=resource,
505
536
  user_context=user_context,
537
+ user_image_context=ask_request.extra_context_images,
506
538
  strategies=ask_request.rag_strategies,
507
539
  image_strategies=ask_request.rag_images_strategies,
508
540
  max_context_characters=tokens_to_chars(max_tokens_context),
@@ -534,14 +566,18 @@ async def ask(
534
566
  rerank_context=False,
535
567
  top_k=ask_request.top_k,
536
568
  )
537
- with metrics.time("stream_start"):
538
- predict = get_predict()
539
- (
540
- nuclia_learning_id,
541
- nuclia_learning_model,
542
- predict_answer_stream,
543
- ) = await predict.chat_query_ndjson(kbid, chat_model)
544
- debug_chat_model = chat_model
569
+
570
+ nuclia_learning_id = None
571
+ nuclia_learning_model = None
572
+ predict_answer_stream = None
573
+ if ask_request.generate_answer:
574
+ with metrics.time("stream_start"):
575
+ predict = get_predict()
576
+ (
577
+ nuclia_learning_id,
578
+ nuclia_learning_model,
579
+ predict_answer_stream,
580
+ ) = await predict.chat_query_ndjson(kbid, chat_model)
545
581
 
546
582
  auditor = ChatAuditor(
547
583
  kbid=kbid,
@@ -562,13 +598,13 @@ async def ask(
562
598
  main_results=retrieval_results.main_query,
563
599
  prequeries_results=retrieval_results.prequeries,
564
600
  nuclia_learning_id=nuclia_learning_id,
565
- predict_answer_stream=predict_answer_stream, # type: ignore
601
+ predict_answer_stream=predict_answer_stream,
566
602
  prompt_context=prompt_context,
567
603
  prompt_context_order=prompt_context_order,
568
604
  auditor=auditor,
569
605
  metrics=metrics,
570
606
  best_matches=retrieval_results.best_matches,
571
- debug_chat_model=debug_chat_model,
607
+ debug_chat_model=chat_model,
572
608
  )
573
609
 
574
610
 
@@ -629,6 +665,13 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
629
665
  return None
630
666
 
631
667
 
668
+ def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]:
669
+ for rag_strategy in ask_request.rag_strategies:
670
+ if rag_strategy.name == RagStrategyName.GRAPH:
671
+ return cast(GraphStrategy, rag_strategy)
672
+ return None
673
+
674
+
632
675
  async def retrieval_step(
633
676
  kbid: str,
634
677
  main_query: str,
@@ -675,6 +718,7 @@ async def retrieval_in_kb(
675
718
  metrics: RAGMetrics,
676
719
  ) -> RetrievalResults:
677
720
  prequeries = parse_prequeries(ask_request)
721
+ graph_strategy = parse_graph_strategy(ask_request)
678
722
  with metrics.time("retrieval"):
679
723
  main_results, prequeries_results, query_parser = await get_find_results(
680
724
  kbid=kbid,
@@ -686,6 +730,26 @@ async def retrieval_in_kb(
686
730
  metrics=metrics,
687
731
  prequeries_strategy=prequeries,
688
732
  )
733
+
734
+ if graph_strategy is not None:
735
+ graph_results, graph_request = await get_graph_results(
736
+ kbid=kbid,
737
+ query=main_query,
738
+ item=ask_request,
739
+ ndb_client=client_type,
740
+ user=user_id,
741
+ origin=origin,
742
+ graph_strategy=graph_strategy,
743
+ metrics=metrics,
744
+ shards=ask_request.shards,
745
+ )
746
+
747
+ if prequeries_results is None:
748
+ prequeries_results = []
749
+
750
+ prequery = PreQuery(id="graph", request=graph_request, weight=graph_strategy.weight)
751
+ prequeries_results.append((prequery, graph_results))
752
+
689
753
  if len(main_results.resources) == 0 and all(
690
754
  len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
691
755
  ):
@@ -28,6 +28,7 @@ from pydantic import BaseModel
28
28
 
29
29
  from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
30
30
  from nucliadb.common.maindb.utils import get_driver
31
+ from nucliadb.common.models_utils import from_proto
31
32
  from nucliadb.ingest.fields.base import Field
32
33
  from nucliadb.ingest.fields.conversation import Conversation
33
34
  from nucliadb.ingest.fields.file import File
@@ -41,6 +42,7 @@ from nucliadb.search.search.chat.images import (
41
42
  )
42
43
  from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
43
44
  from nucliadb.search.search.paragraphs import get_paragraph_text
45
+ from nucliadb_models.labels import translate_alias_to_system_label
44
46
  from nucliadb_models.metadata import Extra, Origin
45
47
  from nucliadb_models.search import (
46
48
  SCORE_TYPE,
@@ -49,6 +51,7 @@ from nucliadb_models.search import (
49
51
  FindParagraph,
50
52
  FullResourceStrategy,
51
53
  HierarchyResourceStrategy,
54
+ Image,
52
55
  ImageRagStrategy,
53
56
  ImageRagStrategyName,
54
57
  MetadataExtensionStrategy,
@@ -266,7 +269,9 @@ async def full_resource_prompt_context(
266
269
  if strategy.apply_to is not None:
267
270
  # decide whether the resource should be extended or not
268
271
  for label in strategy.apply_to.exclude:
269
- skip = skip or (label in (paragraph.labels or []))
272
+ skip = skip or (
273
+ translate_alias_to_system_label(label) in (paragraph.labels or [])
274
+ )
270
275
 
271
276
  if not skip:
272
277
  ordered_resources.append(resource_uuid)
@@ -346,7 +351,7 @@ async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_i
346
351
  if resource is not None:
347
352
  pb_origin = await resource.get_origin()
348
353
  if pb_origin is not None:
349
- origin = Origin.from_message(pb_origin)
354
+ origin = from_proto.origin(pb_origin)
350
355
  return rid, origin
351
356
 
352
357
  rids = {tb_id.rid for tb_id in text_block_ids}
@@ -433,7 +438,7 @@ async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_id
433
438
  if resource is not None:
434
439
  pb_extra = await resource.get_extra()
435
440
  if pb_extra is not None:
436
- extra = Extra.from_message(pb_extra)
441
+ extra = from_proto.extra(pb_extra)
437
442
  return rid, extra
438
443
 
439
444
  rids = {tb_id.rid for tb_id in text_block_ids}
@@ -876,6 +881,7 @@ class PromptContextBuilder:
876
881
  ordered_paragraphs: list[FindParagraph],
877
882
  resource: Optional[str] = None,
878
883
  user_context: Optional[list[str]] = None,
884
+ user_image_context: Optional[list[Image]] = None,
879
885
  strategies: Optional[Sequence[RagStrategy]] = None,
880
886
  image_strategies: Optional[Sequence[ImageRagStrategy]] = None,
881
887
  max_context_characters: Optional[int] = None,
@@ -885,6 +891,7 @@ class PromptContextBuilder:
885
891
  self.ordered_paragraphs = ordered_paragraphs
886
892
  self.resource = resource
887
893
  self.user_context = user_context
894
+ self.user_image_context = user_image_context
888
895
  self.strategies = strategies
889
896
  self.image_strategies = image_strategies
890
897
  self.max_context_characters = max_context_characters
@@ -895,6 +902,8 @@ class PromptContextBuilder:
895
902
  # it is added first, followed by the found text blocks in order of relevance
896
903
  for i, text_block in enumerate(self.user_context or []):
897
904
  context[f"USER_CONTEXT_{i}"] = text_block
905
+ for i, image in enumerate(self.user_image_context or []):
906
+ context.images[f"USER_IMAGE_CONTEXT_{i}"] = image
898
907
 
899
908
  async def build(
900
909
  self,
@@ -1012,8 +1021,10 @@ class PromptContextBuilder:
1012
1021
  neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
1013
1022
  elif strategy.name == RagStrategyName.METADATA_EXTENSION:
1014
1023
  metadata_extension = cast(MetadataExtensionStrategy, strategy)
1015
- elif strategy.name != RagStrategyName.PREQUERIES: # pragma: no cover
1016
- # Prequeries are not handled here
1024
+ elif (
1025
+ strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
1026
+ ): # pragma: no cover
1027
+ # Prequeries and graph are not handled here
1017
1028
  logger.warning(
1018
1029
  "Unknown rag strategy",
1019
1030
  extra={"strategy": strategy.name, "kbid": self.kbid},
@@ -18,8 +18,9 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from typing import Optional
21
+ from typing import Iterable, Optional
22
22
 
23
+ from nucliadb.common.models_utils import to_proto
23
24
  from nucliadb.search import logger
24
25
  from nucliadb.search.predict import AnswerStatusCode
25
26
  from nucliadb.search.requesters.utils import Method, node_query
@@ -49,7 +50,13 @@ from nucliadb_models.search import (
49
50
  parse_rephrase_prompt,
50
51
  )
51
52
  from nucliadb_protos import audit_pb2
52
- from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
53
+ from nucliadb_protos.nodereader_pb2 import (
54
+ EntitiesSubgraphRequest,
55
+ RelationSearchResponse,
56
+ SearchRequest,
57
+ SearchResponse,
58
+ )
59
+ from nucliadb_protos.utils_pb2 import RelationNode
53
60
  from nucliadb_telemetry.errors import capture_exception
54
61
  from nucliadb_utils.utilities import get_audit
55
62
 
@@ -144,15 +151,7 @@ async def get_find_results(
144
151
  return main_results, prequeries_results, query_parser
145
152
 
146
153
 
147
- async def run_main_query(
148
- kbid: str,
149
- query: str,
150
- item: AskRequest,
151
- ndb_client: NucliaDBClientType,
152
- user: str,
153
- origin: str,
154
- metrics: RAGMetrics = RAGMetrics(),
155
- ) -> tuple[KnowledgeboxFindResults, QueryParser]:
154
+ def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
156
155
  find_request = FindRequest()
157
156
  find_request.resource_filters = item.resource_filters
158
157
  find_request.features = []
@@ -188,7 +187,19 @@ async def run_main_query(
188
187
  find_request.show_hidden = item.show_hidden
189
188
 
190
189
  # this executes the model validators, that can tweak some fields
191
- FindRequest.model_validate(find_request)
190
+ return FindRequest.model_validate(find_request)
191
+
192
+
193
+ async def run_main_query(
194
+ kbid: str,
195
+ query: str,
196
+ item: AskRequest,
197
+ ndb_client: NucliaDBClientType,
198
+ user: str,
199
+ origin: str,
200
+ metrics: RAGMetrics = RAGMetrics(),
201
+ ) -> tuple[KnowledgeboxFindResults, QueryParser]:
202
+ find_request = find_request_from_ask_request(item, query)
192
203
 
193
204
  find_results, incomplete, query_parser = await find(
194
205
  kbid,
@@ -210,36 +221,65 @@ async def get_relations_results(
210
221
  text_answer: str,
211
222
  target_shard_replicas: Optional[list[str]],
212
223
  timeout: Optional[float] = None,
224
+ only_with_metadata: bool = False,
225
+ only_agentic_relations: bool = False,
213
226
  ) -> Relations:
214
227
  try:
215
228
  predict = get_predict()
216
229
  detected_entities = await predict.detect_entities(kbid, text_answer)
217
- request = SearchRequest()
218
- request.relation_subgraph.entry_points.extend(detected_entities)
219
- request.relation_subgraph.depth = 1
220
-
221
- results: list[SearchResponse]
222
- (
223
- results,
224
- _,
225
- _,
226
- ) = await node_query(
227
- kbid,
228
- Method.SEARCH,
229
- request,
230
+
231
+ return await get_relations_results_from_entities(
232
+ kbid=kbid,
233
+ entities=detected_entities,
230
234
  target_shard_replicas=target_shard_replicas,
231
235
  timeout=timeout,
232
- use_read_replica_nodes=True,
233
- retry_on_primary=False,
236
+ only_with_metadata=only_with_metadata,
237
+ only_agentic_relations=only_agentic_relations,
234
238
  )
235
- relations_results: list[RelationSearchResponse] = [result.relation for result in results]
236
- return await merge_relations_results(relations_results, request.relation_subgraph)
237
239
  except Exception as exc:
238
240
  capture_exception(exc)
239
241
  logger.exception("Error getting relations results")
240
242
  return Relations(entities={})
241
243
 
242
244
 
245
+ async def get_relations_results_from_entities(
246
+ *,
247
+ kbid: str,
248
+ entities: Iterable[RelationNode],
249
+ target_shard_replicas: Optional[list[str]],
250
+ timeout: Optional[float] = None,
251
+ only_with_metadata: bool = False,
252
+ only_agentic_relations: bool = False,
253
+ deleted_entities: set[str] = set(),
254
+ ) -> Relations:
255
+ request = SearchRequest()
256
+ request.relation_subgraph.entry_points.extend(entities)
257
+ request.relation_subgraph.depth = 1
258
+
259
+ deleted = EntitiesSubgraphRequest.DeletedEntities()
260
+ deleted.node_values.extend(deleted_entities)
261
+ request.relation_subgraph.deleted_entities.append(deleted)
262
+
263
+ results: list[SearchResponse]
264
+ (
265
+ results,
266
+ _,
267
+ _,
268
+ ) = await node_query(
269
+ kbid,
270
+ Method.SEARCH,
271
+ request,
272
+ target_shard_replicas=target_shard_replicas,
273
+ timeout=timeout,
274
+ use_read_replica_nodes=True,
275
+ retry_on_primary=False,
276
+ )
277
+ relations_results: list[RelationSearchResponse] = [result.relation for result in results]
278
+ return await merge_relations_results(
279
+ relations_results, request.relation_subgraph, only_with_metadata, only_agentic_relations
280
+ )
281
+
282
+
243
283
  def maybe_audit_chat(
244
284
  *,
245
285
  kbid: str,
@@ -256,8 +296,8 @@ def maybe_audit_chat(
256
296
  chat_history: list[ChatContextMessage],
257
297
  query_context: PromptContext,
258
298
  query_context_order: PromptContextOrder,
259
- learning_id: str,
260
- model: str,
299
+ learning_id: Optional[str],
300
+ model: Optional[str],
261
301
  ):
262
302
  audit = get_audit()
263
303
  if audit is None:
@@ -278,7 +318,7 @@ def maybe_audit_chat(
278
318
  audit.chat(
279
319
  kbid,
280
320
  user_id,
281
- client_type.to_proto(),
321
+ to_proto.client_type(client_type),
282
322
  origin,
283
323
  question=user_query,
284
324
  generative_answer_time=generative_answer_time,
@@ -295,7 +335,7 @@ def maybe_audit_chat(
295
335
 
296
336
 
297
337
  def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
298
- if status_code == AnswerStatusCode.NO_CONTEXT:
338
+ if status_code == AnswerStatusCode.NO_CONTEXT or status_code == AnswerStatusCode.NO_RETRIEVAL_DATA:
299
339
  # We don't want to audit "Not enough context to answer this." and instead set a None.
300
340
  return None
301
341
  return raw_text_answer.decode()
@@ -320,7 +360,7 @@ class ChatAuditor:
320
360
  learning_id: Optional[str],
321
361
  query_context: PromptContext,
322
362
  query_context_order: PromptContextOrder,
323
- model: str,
363
+ model: Optional[str],
324
364
  ):
325
365
  self.kbid = kbid
326
366
  self.user_id = user_id
@@ -17,6 +17,8 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
 
20
+ from nucliadb.search.search.query_parser.exceptions import InvalidQueryError as InvalidQueryError
21
+
20
22
 
21
23
  class IncompleteFindResultsError(Exception):
22
24
  pass
@@ -24,10 +26,3 @@ class IncompleteFindResultsError(Exception):
24
26
 
25
27
  class ResourceNotFoundError(Exception):
26
28
  pass
27
-
28
-
29
- class InvalidQueryError(Exception):
30
- def __init__(self, param: str, reason: str):
31
- self.param = param
32
- self.reason = reason
33
- super().__init__(f"Invalid query. Error in {param}: {reason}")
@@ -24,6 +24,7 @@ from typing import Optional
24
24
 
25
25
  from nucliadb.common.external_index_providers.base import ExternalIndexManager
26
26
  from nucliadb.common.external_index_providers.manager import get_external_index_manager
27
+ from nucliadb.common.models_utils import to_proto
27
28
  from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
28
29
  from nucliadb.search.search.find_merge import (
29
30
  build_find_response,
@@ -105,7 +106,7 @@ async def _index_node_retrieval(
105
106
  kbid, item, generative_model=generative_model
106
107
  )
107
108
  with metrics.time("query_parse"):
108
- pb_query, incomplete_results, autofilters = await query_parser.parse()
109
+ pb_query, incomplete_results, autofilters, rephrased_query = await query_parser.parse()
109
110
 
110
111
  with metrics.time("node_query"):
111
112
  results, query_incomplete_results, queried_nodes = await node_query(
@@ -119,7 +120,8 @@ async def _index_node_retrieval(
119
120
  results,
120
121
  kbid=kbid,
121
122
  query=pb_query.body,
122
- relation_subgraph_query=pb_query.relations.subgraph,
123
+ rephrased_query=rephrased_query,
124
+ relation_subgraph_query=pb_query.relation_subgraph,
123
125
  min_score_bm25=pb_query.min_score_bm25,
124
126
  min_score_semantic=pb_query.min_score_semantic,
125
127
  top_k=item.top_k,
@@ -136,7 +138,7 @@ async def _index_node_retrieval(
136
138
  audit.search(
137
139
  kbid,
138
140
  x_nucliadb_user,
139
- x_ndb_client.to_proto(),
141
+ to_proto.client_type(x_ndb_client),
140
142
  x_forwarded_for,
141
143
  pb_query,
142
144
  search_time,
@@ -193,7 +195,7 @@ async def _external_index_retrieval(
193
195
  query_parser, _, reranker = await query_parser_from_find_request(
194
196
  kbid, item, generative_model=generative_model
195
197
  )
196
- search_request, incomplete_results, _ = await query_parser.parse()
198
+ search_request, incomplete_results, _, rephrased_query = await query_parser.parse()
197
199
 
198
200
  # Query index
199
201
  query_results = await external_index_manager.query(search_request) # noqa
@@ -224,6 +226,7 @@ async def _external_index_retrieval(
224
226
  retrieval_results = KnowledgeboxFindResults(
225
227
  resources=find_resources,
226
228
  query=item.query,
229
+ rephrased_query=rephrased_query,
227
230
  total=0,
228
231
  page_number=0,
229
232
  page_size=item.top_k,
@@ -259,7 +262,7 @@ async def query_parser_from_find_request(
259
262
  # XXX this is becoming the new /find query parsing, this should be moved to
260
263
  # a cleaner abstraction
261
264
 
262
- parsed = parse_find(item)
265
+ parsed = await parse_find(kbid, item)
263
266
 
264
267
  rank_fusion = get_rank_fusion(parsed.rank_fusion)
265
268
  reranker = get_reranker(parsed.reranker)
@@ -268,6 +271,7 @@ async def query_parser_from_find_request(
268
271
  kbid=kbid,
269
272
  features=item.features,
270
273
  query=item.query,
274
+ query_entities=item.query_entities,
271
275
  label_filters=item.filters,
272
276
  keyword_filters=item.keyword_filters,
273
277
  faceted=None,
@@ -18,7 +18,7 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from typing import Iterable, Union
21
+ from typing import Iterable, Optional, Union
22
22
 
23
23
  from nucliadb.common.external_index_providers.base import TextBlockMatch
24
24
  from nucliadb.common.ids import ParagraphId, VectorId
@@ -74,6 +74,7 @@ async def build_find_response(
74
74
  *,
75
75
  kbid: str,
76
76
  query: str,
77
+ rephrased_query: Optional[str],
77
78
  relation_subgraph_query: EntitiesSubgraphRequest,
78
79
  top_k: int,
79
80
  min_score_bm25: float,
@@ -96,9 +97,13 @@ async def build_find_response(
96
97
  )
97
98
  )
98
99
 
99
- merged_text_blocks: list[TextBlockMatch] = rank_fusion_algorithm.fuse(
100
- keyword_results, semantic_results
101
- )
100
+ merged_text_blocks: list[TextBlockMatch]
101
+ if len(keyword_results) == 0:
102
+ merged_text_blocks = semantic_results
103
+ elif len(semantic_results) == 0:
104
+ merged_text_blocks = keyword_results
105
+ else:
106
+ merged_text_blocks = rank_fusion_algorithm.fuse(keyword_results, semantic_results)
102
107
 
103
108
  # cut
104
109
  # we assume pagination + predict reranker is forbidden and has been already
@@ -139,6 +144,7 @@ async def build_find_response(
139
144
 
140
145
  find_results = KnowledgeboxFindResults(
141
146
  query=query,
147
+ rephrased_query=rephrased_query,
142
148
  resources=find_resources,
143
149
  best_matches=best_matches,
144
150
  relations=relations,