nucliadb 6.2.0.post2675__py3-none-any.whl → 6.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those two published versions.
- migrations/0028_extracted_vectors_reference.py +61 -0
- migrations/0029_backfill_field_status.py +149 -0
- migrations/0030_label_deduplication.py +60 -0
- nucliadb/common/cluster/manager.py +41 -331
- nucliadb/common/cluster/rebalance.py +2 -2
- nucliadb/common/cluster/rollover.py +12 -71
- nucliadb/common/cluster/settings.py +3 -0
- nucliadb/common/cluster/standalone/utils.py +0 -43
- nucliadb/common/cluster/utils.py +0 -16
- nucliadb/common/counters.py +1 -0
- nucliadb/common/datamanagers/fields.py +48 -7
- nucliadb/common/datamanagers/vectorsets.py +11 -2
- nucliadb/common/external_index_providers/base.py +2 -1
- nucliadb/common/external_index_providers/pinecone.py +3 -5
- nucliadb/common/ids.py +18 -4
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +76 -37
- nucliadb/export_import/models.py +3 -3
- nucliadb/health.py +0 -7
- nucliadb/ingest/app.py +0 -8
- nucliadb/ingest/consumer/auditing.py +1 -1
- nucliadb/ingest/consumer/shard_creator.py +1 -1
- nucliadb/ingest/fields/base.py +83 -21
- nucliadb/ingest/orm/brain.py +55 -56
- nucliadb/ingest/orm/broker_message.py +12 -2
- nucliadb/ingest/orm/entities.py +6 -17
- nucliadb/ingest/orm/knowledgebox.py +44 -22
- nucliadb/ingest/orm/processor/data_augmentation.py +7 -29
- nucliadb/ingest/orm/processor/processor.py +5 -2
- nucliadb/ingest/orm/resource.py +222 -413
- nucliadb/ingest/processing.py +8 -2
- nucliadb/ingest/serialize.py +77 -46
- nucliadb/ingest/service/writer.py +2 -56
- nucliadb/ingest/settings.py +1 -4
- nucliadb/learning_proxy.py +6 -4
- nucliadb/purge/__init__.py +102 -12
- nucliadb/purge/orphan_shards.py +6 -4
- nucliadb/reader/api/models.py +3 -3
- nucliadb/reader/api/v1/__init__.py +1 -0
- nucliadb/reader/api/v1/download.py +2 -2
- nucliadb/reader/api/v1/knowledgebox.py +3 -3
- nucliadb/reader/api/v1/resource.py +23 -12
- nucliadb/reader/api/v1/services.py +4 -4
- nucliadb/reader/api/v1/vectorsets.py +48 -0
- nucliadb/search/api/v1/ask.py +11 -1
- nucliadb/search/api/v1/feedback.py +3 -3
- nucliadb/search/api/v1/knowledgebox.py +8 -13
- nucliadb/search/api/v1/search.py +3 -2
- nucliadb/search/api/v1/suggest.py +0 -2
- nucliadb/search/predict.py +6 -4
- nucliadb/search/requesters/utils.py +1 -2
- nucliadb/search/search/chat/ask.py +77 -13
- nucliadb/search/search/chat/prompt.py +16 -5
- nucliadb/search/search/chat/query.py +74 -34
- nucliadb/search/search/exceptions.py +2 -7
- nucliadb/search/search/find.py +9 -5
- nucliadb/search/search/find_merge.py +10 -4
- nucliadb/search/search/graph_strategy.py +884 -0
- nucliadb/search/search/hydrator.py +6 -0
- nucliadb/search/search/merge.py +79 -24
- nucliadb/search/search/query.py +74 -245
- nucliadb/search/search/query_parser/exceptions.py +11 -1
- nucliadb/search/search/query_parser/fetcher.py +405 -0
- nucliadb/search/search/query_parser/models.py +0 -3
- nucliadb/search/search/query_parser/parser.py +22 -21
- nucliadb/search/search/rerankers.py +1 -42
- nucliadb/search/search/shards.py +19 -0
- nucliadb/standalone/api_router.py +2 -14
- nucliadb/standalone/settings.py +4 -0
- nucliadb/train/generators/field_streaming.py +7 -3
- nucliadb/train/lifecycle.py +3 -6
- nucliadb/train/nodes.py +14 -12
- nucliadb/train/resource.py +380 -0
- nucliadb/writer/api/constants.py +20 -16
- nucliadb/writer/api/v1/__init__.py +1 -0
- nucliadb/writer/api/v1/export_import.py +1 -1
- nucliadb/writer/api/v1/field.py +13 -7
- nucliadb/writer/api/v1/knowledgebox.py +3 -46
- nucliadb/writer/api/v1/resource.py +20 -13
- nucliadb/writer/api/v1/services.py +10 -1
- nucliadb/writer/api/v1/upload.py +61 -34
- nucliadb/writer/{vectorsets.py → api/v1/vectorsets.py} +99 -47
- nucliadb/writer/back_pressure.py +17 -46
- nucliadb/writer/resource/basic.py +9 -7
- nucliadb/writer/resource/field.py +42 -9
- nucliadb/writer/settings.py +2 -2
- nucliadb/writer/tus/gcs.py +11 -10
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/METADATA +11 -14
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/RECORD +94 -96
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/WHEEL +1 -1
- nucliadb/common/cluster/discovery/base.py +0 -178
- nucliadb/common/cluster/discovery/k8s.py +0 -301
- nucliadb/common/cluster/discovery/manual.py +0 -57
- nucliadb/common/cluster/discovery/single.py +0 -51
- nucliadb/common/cluster/discovery/types.py +0 -32
- nucliadb/common/cluster/discovery/utils.py +0 -67
- nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
- nucliadb/common/cluster/standalone/index_node.py +0 -123
- nucliadb/common/cluster/standalone/service.py +0 -84
- nucliadb/standalone/introspect.py +0 -208
- nucliadb-6.2.0.post2675.dist-info/zip-safe +0 -1
- /nucliadb/common/{cluster/discovery → models_utils}/__init__.py +0 -0
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/top_level.txt +0 -0
nucliadb/search/search/chat/ask.py
CHANGED
@@ -49,6 +49,7 @@ from nucliadb.search.search.chat.query import (
     ChatAuditor,
     get_find_results,
     get_relations_results,
+    maybe_audit_chat,
     rephrase_query,
     sorted_prompt_context_list,
     tokens_to_chars,
@@ -57,6 +58,7 @@ from nucliadb.search.search.exceptions import (
     IncompleteFindResultsError,
     InvalidQueryError,
 )
+from nucliadb.search.search.graph_strategy import get_graph_results
 from nucliadb.search.search.metrics import RAGMetrics
 from nucliadb.search.search.query import QueryParser
 from nucliadb.search.utilities import get_predict
@@ -75,6 +77,7 @@ from nucliadb_models.search import (
     ErrorAskResponseItem,
     FindParagraph,
     FindRequest,
+    GraphStrategy,
     JSONAskResponseItem,
     KnowledgeboxFindResults,
     MetadataAskResponseItem,
@@ -126,7 +129,7 @@ class AskResult:
         main_results: KnowledgeboxFindResults,
         prequeries_results: Optional[list[PreQueryResult]],
         nuclia_learning_id: Optional[str],
-        predict_answer_stream: AsyncGenerator[GenerativeChunk, None],
+        predict_answer_stream: Optional[AsyncGenerator[GenerativeChunk, None]],
         prompt_context: PromptContext,
         prompt_context_order: PromptContextOrder,
         auditor: ChatAuditor,
@@ -393,6 +396,9 @@ class AskResult:
         This method does not assume any order in the stream of items, but it assumes that at least
         the answer text is streamed in order.
         """
+        if self.predict_answer_stream is None:
+            # In some cases, clients may want to skip the answer generation step
+            return
         async for generative_chunk in self.predict_answer_stream:
             item = generative_chunk.chunk
             if isinstance(item, TextGenerativeResponse):
@@ -431,14 +437,14 @@ class NotEnoughContextAskResult(AskResult):
         """
         yield self._ndjson_encode(RetrievalAskResponseItem(results=self.main_results))
         yield self._ndjson_encode(AnswerAskResponseItem(text=NOT_ENOUGH_CONTEXT_ANSWER))
-        status = AnswerStatusCode.NO_CONTEXT
+        status = AnswerStatusCode.NO_RETRIEVAL_DATA
        yield self._ndjson_encode(StatusAskResponseItem(code=status.value, status=status.prettify()))

     async def json(self) -> str:
         return SyncAskResponse(
             answer=NOT_ENOUGH_CONTEXT_ANSWER,
             retrieval_results=self.main_results,
-            status=AnswerStatusCode.NO_CONTEXT.prettify(),
+            status=AnswerStatusCode.NO_RETRIEVAL_DATA.prettify(),
         ).model_dump_json()


@@ -485,6 +491,31 @@ async def ask(
             resource=resource,
         )
     except NoRetrievalResultsError as err:
+        try:
+            rephrase_time = metrics.elapsed("rephrase")
+        except KeyError:
+            # Not all ask requests have a rephrase step
+            rephrase_time = None
+
+        maybe_audit_chat(
+            kbid=kbid,
+            user_id=user_id,
+            client_type=client_type,
+            origin=origin,
+            generative_answer_time=0,
+            generative_answer_first_chunk_time=0,
+            rephrase_time=rephrase_time,
+            user_query=user_query,
+            rephrased_query=rephrased_query,
+            text_answer=b"",
+            status_code=AnswerStatusCode.NO_RETRIEVAL_DATA,
+            chat_history=chat_history,
+            query_context={},
+            query_context_order={},
+            learning_id=None,
+            model=ask_request.generative_model,
+        )
+
         # If a retrieval was attempted but no results were found,
         # early return the ask endpoint without querying the generative model
         return NotEnoughContextAskResult(
@@ -503,6 +534,7 @@ async def ask(
         ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
         resource=resource,
         user_context=user_context,
+        user_image_context=ask_request.extra_context_images,
         strategies=ask_request.rag_strategies,
         image_strategies=ask_request.rag_images_strategies,
         max_context_characters=tokens_to_chars(max_tokens_context),
@@ -534,14 +566,18 @@ async def ask(
         rerank_context=False,
         top_k=ask_request.top_k,
     )
-
-    with metrics.time("stream_start"):
-        predict = get_predict()
-        (
-            nuclia_learning_id,
-            nuclia_learning_model,
-            predict_answer_stream,
-        ) = await predict.chat_query_ndjson(kbid, chat_model)
+
+    nuclia_learning_id = None
+    nuclia_learning_model = None
+    predict_answer_stream = None
+    if ask_request.generate_answer:
+        with metrics.time("stream_start"):
+            predict = get_predict()
+            (
+                nuclia_learning_id,
+                nuclia_learning_model,
+                predict_answer_stream,
+            ) = await predict.chat_query_ndjson(kbid, chat_model)

     auditor = ChatAuditor(
         kbid=kbid,
@@ -562,13 +598,13 @@ async def ask(
         main_results=retrieval_results.main_query,
         prequeries_results=retrieval_results.prequeries,
         nuclia_learning_id=nuclia_learning_id,
-        predict_answer_stream=predict_answer_stream,
+        predict_answer_stream=predict_answer_stream,
         prompt_context=prompt_context,
         prompt_context_order=prompt_context_order,
         auditor=auditor,
         metrics=metrics,
         best_matches=retrieval_results.best_matches,
-        debug_chat_model=
+        debug_chat_model=chat_model,
     )


@@ -629,6 +665,13 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
     return None


+def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]:
+    for rag_strategy in ask_request.rag_strategies:
+        if rag_strategy.name == RagStrategyName.GRAPH:
+            return cast(GraphStrategy, rag_strategy)
+    return None
+
+
 async def retrieval_step(
     kbid: str,
     main_query: str,
@@ -675,6 +718,7 @@ async def retrieval_in_kb(
     metrics: RAGMetrics,
 ) -> RetrievalResults:
     prequeries = parse_prequeries(ask_request)
+    graph_strategy = parse_graph_strategy(ask_request)
     with metrics.time("retrieval"):
         main_results, prequeries_results, query_parser = await get_find_results(
             kbid=kbid,
@@ -686,6 +730,26 @@ async def retrieval_in_kb(
             metrics=metrics,
             prequeries_strategy=prequeries,
         )
+
+        if graph_strategy is not None:
+            graph_results, graph_request = await get_graph_results(
+                kbid=kbid,
+                query=main_query,
+                item=ask_request,
+                ndb_client=client_type,
+                user=user_id,
+                origin=origin,
+                graph_strategy=graph_strategy,
+                metrics=metrics,
+                shards=ask_request.shards,
+            )
+
+            if prequeries_results is None:
+                prequeries_results = []
+
+            prequery = PreQuery(id="graph", request=graph_request, weight=graph_strategy.weight)
+            prequeries_results.append((prequery, graph_results))
+
     if len(main_results.resources) == 0 and all(
         len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
     ):
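Example: the new graph RAG strategy and the generate_answer flag introduced in ask.py above can be combined in a single request. A minimal sketch, not taken from the package, assuming AskRequest only requires query and that GraphStrategy can be constructed with just a weight (all other fields left to their defaults):

from nucliadb_models.search import AskRequest, GraphStrategy

from nucliadb.search.search.chat.ask import parse_graph_strategy

ask_request = AskRequest(
    query="Who collaborated on the ingest pipeline?",
    rag_strategies=[GraphStrategy(weight=1.0)],  # picked up by parse_graph_strategy()
    generate_answer=False,  # ask() then builds AskResult with predict_answer_stream=None
)

strategy = parse_graph_strategy(ask_request)
assert strategy is not None and strategy.weight == 1.0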
nucliadb/search/search/chat/prompt.py
CHANGED
@@ -28,6 +28,7 @@ from pydantic import BaseModel

 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
 from nucliadb.common.maindb.utils import get_driver
+from nucliadb.common.models_utils import from_proto
 from nucliadb.ingest.fields.base import Field
 from nucliadb.ingest.fields.conversation import Conversation
 from nucliadb.ingest.fields.file import File
@@ -41,6 +42,7 @@ from nucliadb.search.search.chat.images import (
 )
 from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
 from nucliadb.search.search.paragraphs import get_paragraph_text
+from nucliadb_models.labels import translate_alias_to_system_label
 from nucliadb_models.metadata import Extra, Origin
 from nucliadb_models.search import (
     SCORE_TYPE,
@@ -49,6 +51,7 @@ from nucliadb_models.search import (
     FindParagraph,
     FullResourceStrategy,
     HierarchyResourceStrategy,
+    Image,
     ImageRagStrategy,
     ImageRagStrategyName,
     MetadataExtensionStrategy,
@@ -266,7 +269,9 @@ async def full_resource_prompt_context(
         if strategy.apply_to is not None:
             # decide whether the resource should be extended or not
             for label in strategy.apply_to.exclude:
-                skip = skip or (label in (paragraph.labels or []))
+                skip = skip or (
+                    translate_alias_to_system_label(label) in (paragraph.labels or [])
+                )

         if not skip:
             ordered_resources.append(resource_uuid)
@@ -346,7 +351,7 @@ async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_i
         if resource is not None:
             pb_origin = await resource.get_origin()
             if pb_origin is not None:
-                origin =
+                origin = from_proto.origin(pb_origin)
         return rid, origin

     rids = {tb_id.rid for tb_id in text_block_ids}
@@ -433,7 +438,7 @@ async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_id
         if resource is not None:
             pb_extra = await resource.get_extra()
             if pb_extra is not None:
-                extra =
+                extra = from_proto.extra(pb_extra)
         return rid, extra

     rids = {tb_id.rid for tb_id in text_block_ids}
@@ -876,6 +881,7 @@ class PromptContextBuilder:
         ordered_paragraphs: list[FindParagraph],
         resource: Optional[str] = None,
         user_context: Optional[list[str]] = None,
+        user_image_context: Optional[list[Image]] = None,
         strategies: Optional[Sequence[RagStrategy]] = None,
         image_strategies: Optional[Sequence[ImageRagStrategy]] = None,
         max_context_characters: Optional[int] = None,
@@ -885,6 +891,7 @@ class PromptContextBuilder:
         self.ordered_paragraphs = ordered_paragraphs
         self.resource = resource
         self.user_context = user_context
+        self.user_image_context = user_image_context
         self.strategies = strategies
         self.image_strategies = image_strategies
         self.max_context_characters = max_context_characters
@@ -895,6 +902,8 @@ class PromptContextBuilder:
         # it is added first, followed by the found text blocks in order of relevance
         for i, text_block in enumerate(self.user_context or []):
             context[f"USER_CONTEXT_{i}"] = text_block
+        for i, image in enumerate(self.user_image_context or []):
+            context.images[f"USER_IMAGE_CONTEXT_{i}"] = image

     async def build(
         self,
@@ -1012,8 +1021,10 @@ class PromptContextBuilder:
             neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
         elif strategy.name == RagStrategyName.METADATA_EXTENSION:
             metadata_extension = cast(MetadataExtensionStrategy, strategy)
-        elif strategy.name != RagStrategyName.PREQUERIES:  # pragma: no cover
-            # Prequeries are not handled here
+        elif (
+            strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
+        ):  # pragma: no cover
+            # Prequeries and graph are not handled here
             logger.warning(
                 "Unknown rag strategy",
                 extra={"strategy": strategy.name, "kbid": self.kbid},
nucliadb/search/search/chat/query.py
CHANGED
@@ -18,8 +18,9 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from typing import Optional
+from typing import Iterable, Optional

+from nucliadb.common.models_utils import to_proto
 from nucliadb.search import logger
 from nucliadb.search.predict import AnswerStatusCode
 from nucliadb.search.requesters.utils import Method, node_query
@@ -49,7 +50,13 @@ from nucliadb_models.search import (
     parse_rephrase_prompt,
 )
 from nucliadb_protos import audit_pb2
-from nucliadb_protos.nodereader_pb2 import
+from nucliadb_protos.nodereader_pb2 import (
+    EntitiesSubgraphRequest,
+    RelationSearchResponse,
+    SearchRequest,
+    SearchResponse,
+)
+from nucliadb_protos.utils_pb2 import RelationNode
 from nucliadb_telemetry.errors import capture_exception
 from nucliadb_utils.utilities import get_audit

@@ -144,15 +151,7 @@ async def get_find_results(
     return main_results, prequeries_results, query_parser


-async def run_main_query(
-    kbid: str,
-    query: str,
-    item: AskRequest,
-    ndb_client: NucliaDBClientType,
-    user: str,
-    origin: str,
-    metrics: RAGMetrics = RAGMetrics(),
-) -> tuple[KnowledgeboxFindResults, QueryParser]:
+def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
     find_request = FindRequest()
     find_request.resource_filters = item.resource_filters
     find_request.features = []
@@ -188,7 +187,19 @@ async def run_main_query(
     find_request.show_hidden = item.show_hidden

     # this executes the model validators, that can tweak some fields
-    FindRequest.model_validate(find_request)
+    return FindRequest.model_validate(find_request)
+
+
+async def run_main_query(
+    kbid: str,
+    query: str,
+    item: AskRequest,
+    ndb_client: NucliaDBClientType,
+    user: str,
+    origin: str,
+    metrics: RAGMetrics = RAGMetrics(),
+) -> tuple[KnowledgeboxFindResults, QueryParser]:
+    find_request = find_request_from_ask_request(item, query)

     find_results, incomplete, query_parser = await find(
         kbid,
@@ -210,36 +221,65 @@ async def get_relations_results(
     text_answer: str,
     target_shard_replicas: Optional[list[str]],
     timeout: Optional[float] = None,
+    only_with_metadata: bool = False,
+    only_agentic_relations: bool = False,
 ) -> Relations:
     try:
         predict = get_predict()
         detected_entities = await predict.detect_entities(kbid, text_answer)
-        request = SearchRequest()
-        request.relation_subgraph.entry_points.extend(detected_entities)
-        request.relation_subgraph.depth = 1
-
-        results: list[SearchResponse]
-        (
-            results,
-            _,
-            _,
-        ) = await node_query(
-            kbid,
-            Method.SEARCH,
-            request,
+
+        return await get_relations_results_from_entities(
+            kbid=kbid,
+            entities=detected_entities,
             target_shard_replicas=target_shard_replicas,
             timeout=timeout,
-            use_read_replica_nodes=True,
-            retry_on_primary=False,
+            only_with_metadata=only_with_metadata,
+            only_agentic_relations=only_agentic_relations,
         )
-        relations_results: list[RelationSearchResponse] = [result.relation for result in results]
-        return await merge_relations_results(relations_results, request.relation_subgraph)
     except Exception as exc:
         capture_exception(exc)
         logger.exception("Error getting relations results")
         return Relations(entities={})


+async def get_relations_results_from_entities(
+    *,
+    kbid: str,
+    entities: Iterable[RelationNode],
+    target_shard_replicas: Optional[list[str]],
+    timeout: Optional[float] = None,
+    only_with_metadata: bool = False,
+    only_agentic_relations: bool = False,
+    deleted_entities: set[str] = set(),
+) -> Relations:
+    request = SearchRequest()
+    request.relation_subgraph.entry_points.extend(entities)
+    request.relation_subgraph.depth = 1
+
+    deleted = EntitiesSubgraphRequest.DeletedEntities()
+    deleted.node_values.extend(deleted_entities)
+    request.relation_subgraph.deleted_entities.append(deleted)
+
+    results: list[SearchResponse]
+    (
+        results,
+        _,
+        _,
+    ) = await node_query(
+        kbid,
+        Method.SEARCH,
+        request,
+        target_shard_replicas=target_shard_replicas,
+        timeout=timeout,
+        use_read_replica_nodes=True,
+        retry_on_primary=False,
+    )
+    relations_results: list[RelationSearchResponse] = [result.relation for result in results]
+    return await merge_relations_results(
+        relations_results, request.relation_subgraph, only_with_metadata, only_agentic_relations
+    )
+
+
 def maybe_audit_chat(
     *,
     kbid: str,
@@ -256,8 +296,8 @@ def maybe_audit_chat(
     chat_history: list[ChatContextMessage],
     query_context: PromptContext,
     query_context_order: PromptContextOrder,
-    learning_id: str,
-    model: str,
+    learning_id: Optional[str],
+    model: Optional[str],
 ):
     audit = get_audit()
     if audit is None:
@@ -278,7 +318,7 @@ def maybe_audit_chat(
     audit.chat(
         kbid,
         user_id,
-        client_type,
+        to_proto.client_type(client_type),
         origin,
         question=user_query,
         generative_answer_time=generative_answer_time,
@@ -295,7 +335,7 @@ def maybe_audit_chat(


 def parse_audit_answer(raw_text_answer: bytes, status_code: AnswerStatusCode) -> Optional[str]:
-    if status_code == AnswerStatusCode.NO_CONTEXT:
+    if status_code == AnswerStatusCode.NO_CONTEXT or status_code == AnswerStatusCode.NO_RETRIEVAL_DATA:
         # We don't want to audit "Not enough context to answer this." and instead set a None.
         return None
     return raw_text_answer.decode()
@@ -320,7 +360,7 @@ class ChatAuditor:
         learning_id: Optional[str],
         query_context: PromptContext,
         query_context_order: PromptContextOrder,
-        model: str,
+        model: Optional[str],
     ):
         self.kbid = kbid
         self.user_id = user_id
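Example: the extracted get_relations_results_from_entities helper can be driven directly from a set of entity nodes, which is how the new graph strategy reuses it. A sketch under assumptions (placeholder knowledge box id and entity values; a running NucliaDB index is needed for node_query to succeed):

import asyncio

from nucliadb_protos.utils_pb2 import RelationNode

from nucliadb.search.search.chat.query import get_relations_results_from_entities


async def related_to(kbid: str) -> None:
    # Entry points for the subgraph query; the value and subtype are made up for the example.
    entities = [
        RelationNode(value="Marie Curie", ntype=RelationNode.NodeType.ENTITY, subtype="PERSON"),
    ]
    relations = await get_relations_results_from_entities(
        kbid=kbid,
        entities=entities,
        target_shard_replicas=None,
        only_with_metadata=True,  # keep only relations that carry paragraph metadata
    )
    print(relations.entities.keys())


# asyncio.run(related_to("my-kb"))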
nucliadb/search/search/exceptions.py
CHANGED
@@ -17,6 +17,8 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

+from nucliadb.search.search.query_parser.exceptions import InvalidQueryError as InvalidQueryError
+

 class IncompleteFindResultsError(Exception):
     pass
@@ -24,10 +26,3 @@ class IncompleteFindResultsError(Exception):

 class ResourceNotFoundError(Exception):
     pass
-
-
-class InvalidQueryError(Exception):
-    def __init__(self, param: str, reason: str):
-        self.param = param
-        self.reason = reason
-        super().__init__(f"Invalid query. Error in {param}: {reason}")
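Note: InvalidQueryError now lives in nucliadb.search.search.query_parser.exceptions and is only re-exported from this module, so existing imports and except clauses keep working. A small sketch, assuming the moved class keeps the (param, reason) constructor shown in the removed code:

from nucliadb.search.search.exceptions import InvalidQueryError  # re-exported alias

try:
    raise InvalidQueryError("filters", "unknown label alias")
except InvalidQueryError as exc:
    print(exc.param, exc.reason)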
nucliadb/search/search/find.py
CHANGED
@@ -24,6 +24,7 @@ from typing import Optional

 from nucliadb.common.external_index_providers.base import ExternalIndexManager
 from nucliadb.common.external_index_providers.manager import get_external_index_manager
+from nucliadb.common.models_utils import to_proto
 from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
 from nucliadb.search.search.find_merge import (
     build_find_response,
@@ -105,7 +106,7 @@ async def _index_node_retrieval(
         kbid, item, generative_model=generative_model
     )
     with metrics.time("query_parse"):
-        pb_query, incomplete_results, autofilters = await query_parser.parse()
+        pb_query, incomplete_results, autofilters, rephrased_query = await query_parser.parse()

     with metrics.time("node_query"):
         results, query_incomplete_results, queried_nodes = await node_query(
@@ -119,7 +120,8 @@ async def _index_node_retrieval(
             results,
             kbid=kbid,
             query=pb_query.body,
-            relation_subgraph_query=pb_query.relation_subgraph,
+            rephrased_query=rephrased_query,
+            relation_subgraph_query=pb_query.relation_subgraph,
             min_score_bm25=pb_query.min_score_bm25,
             min_score_semantic=pb_query.min_score_semantic,
             top_k=item.top_k,
@@ -136,7 +138,7 @@ async def _index_node_retrieval(
         audit.search(
             kbid,
             x_nucliadb_user,
-
+            to_proto.client_type(x_ndb_client),
             x_forwarded_for,
             pb_query,
             search_time,
@@ -193,7 +195,7 @@ async def _external_index_retrieval(
     query_parser, _, reranker = await query_parser_from_find_request(
         kbid, item, generative_model=generative_model
     )
-    search_request, incomplete_results, _ = await query_parser.parse()
+    search_request, incomplete_results, _, rephrased_query = await query_parser.parse()

     # Query index
     query_results = await external_index_manager.query(search_request)  # noqa
@@ -224,6 +226,7 @@ async def _external_index_retrieval(
     retrieval_results = KnowledgeboxFindResults(
         resources=find_resources,
         query=item.query,
+        rephrased_query=rephrased_query,
         total=0,
         page_number=0,
         page_size=item.top_k,
@@ -259,7 +262,7 @@ async def query_parser_from_find_request(
     # XXX this is becoming the new /find query parsing, this should be moved to
     # a cleaner abstraction

-    parsed = parse_find(item)
+    parsed = await parse_find(kbid, item)

     rank_fusion = get_rank_fusion(parsed.rank_fusion)
     reranker = get_reranker(parsed.reranker)
@@ -268,6 +271,7 @@ async def query_parser_from_find_request(
         kbid=kbid,
         features=item.features,
         query=item.query,
+        query_entities=item.query_entities,
         label_filters=item.filters,
         keyword_filters=item.keyword_filters,
         faceted=None,
nucliadb/search/search/find_merge.py
CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from typing import Iterable, Union
+from typing import Iterable, Optional, Union

 from nucliadb.common.external_index_providers.base import TextBlockMatch
 from nucliadb.common.ids import ParagraphId, VectorId
@@ -74,6 +74,7 @@ async def build_find_response(
     *,
     kbid: str,
     query: str,
+    rephrased_query: Optional[str],
     relation_subgraph_query: EntitiesSubgraphRequest,
     top_k: int,
     min_score_bm25: float,
@@ -96,9 +97,13 @@ async def build_find_response(
         )
     )

-    merged_text_blocks: list[TextBlockMatch]
-
-
+    merged_text_blocks: list[TextBlockMatch]
+    if len(keyword_results) == 0:
+        merged_text_blocks = semantic_results
+    elif len(semantic_results) == 0:
+        merged_text_blocks = keyword_results
+    else:
+        merged_text_blocks = rank_fusion_algorithm.fuse(keyword_results, semantic_results)

     # cut
     # we assume pagination + predict reranker is forbidden and has been already
@@ -139,6 +144,7 @@ async def build_find_response(

     find_results = KnowledgeboxFindResults(
         query=query,
+        rephrased_query=rephrased_query,
         resources=find_resources,
         best_matches=best_matches,
         relations=relations,