nucliadb 6.4.0.post4265__py3-none-any.whl → 6.4.0.post4276__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/search/predict.py +19 -9
- nucliadb/search/search/chat/ask.py +63 -56
- nucliadb/search/search/chat/prompt.py +46 -6
- nucliadb/search/search/chat/query.py +4 -2
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4276.dist-info}/METADATA +6 -6
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4276.dist-info}/RECORD +9 -9
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4276.dist-info}/WHEEL +0 -0
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4276.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4276.dist-info}/top_level.txt +0 -0
nucliadb/search/predict.py
CHANGED
@@ -22,6 +22,7 @@ import json
|
|
22
22
|
import logging
|
23
23
|
import os
|
24
24
|
import random
|
25
|
+
from dataclasses import dataclass
|
25
26
|
from enum import Enum
|
26
27
|
from typing import Any, AsyncGenerator, Optional
|
27
28
|
from unittest.mock import AsyncMock, Mock
|
@@ -108,7 +109,7 @@ RERANK = "/rerank"
|
|
108
109
|
|
109
110
|
NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"
|
110
111
|
NUCLIA_LEARNING_MODEL_HEADER = "NUCLIA-LEARNING-MODEL"
|
111
|
-
|
112
|
+
NUCLIA_LEARNING_CHAT_HISTORY_HEADER = "NUCLIA-LEARNING-CHAT-HISTORY"
|
112
113
|
|
113
114
|
predict_observer = metrics.Observer(
|
114
115
|
"predict_engine",
|
@@ -139,6 +140,12 @@ class AnswerStatusCode(str, Enum):
|
|
139
140
|
}[self]
|
140
141
|
|
141
142
|
|
143
|
+
@dataclass
|
144
|
+
class RephraseResponse:
|
145
|
+
rephrased_query: str
|
146
|
+
use_chat_history: Optional[bool]
|
147
|
+
|
148
|
+
|
142
149
|
async def start_predict_engine():
|
143
150
|
if nuclia_settings.dummy_predict:
|
144
151
|
predict_util = DummyPredictEngine()
|
@@ -267,7 +274,7 @@ class PredictEngine:
|
|
267
274
|
return await func(**request_args)
|
268
275
|
|
269
276
|
@predict_observer.wrap({"type": "rephrase"})
|
270
|
-
async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
|
277
|
+
async def rephrase_query(self, kbid: str, item: RephraseModel) -> RephraseResponse:
|
271
278
|
try:
|
272
279
|
self.check_nua_key_is_configured_for_onprem()
|
273
280
|
except NUAKeyMissingError:
|
@@ -477,9 +484,9 @@ class DummyPredictEngine(PredictEngine):
|
|
477
484
|
response.headers = {NUCLIA_LEARNING_ID_HEADER: DUMMY_LEARNING_ID}
|
478
485
|
return response
|
479
486
|
|
480
|
-
async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
|
487
|
+
async def rephrase_query(self, kbid: str, item: RephraseModel) -> RephraseResponse:
|
481
488
|
self.calls.append(("rephrase_query", item))
|
482
|
-
return DUMMY_REPHRASE_QUERY
|
489
|
+
return RephraseResponse(rephrased_query=DUMMY_REPHRASE_QUERY, use_chat_history=None)
|
483
490
|
|
484
491
|
async def chat_query_ndjson(
|
485
492
|
self, kbid: str, item: ChatModel
|
@@ -624,7 +631,7 @@ def get_chat_ndjson_generator(
|
|
624
631
|
|
625
632
|
async def _parse_rephrase_response(
|
626
633
|
resp: aiohttp.ClientResponse,
|
627
|
-
) -> str:
|
634
|
+
) -> RephraseResponse:
|
628
635
|
"""
|
629
636
|
Predict api is returning a json payload that is a string with the following format:
|
630
637
|
<rephrased_query><status_code>
|
@@ -632,12 +639,15 @@ async def _parse_rephrase_response(
|
|
632
639
|
it will raise an exception if the status code is not 0
|
633
640
|
"""
|
634
641
|
content = await resp.json()
|
642
|
+
|
635
643
|
if content.endswith("0"):
|
636
|
-
|
644
|
+
content = content[:-1]
|
637
645
|
elif content.endswith("-1"):
|
638
646
|
raise RephraseError(content[:-2])
|
639
647
|
elif content.endswith("-2"):
|
640
648
|
raise RephraseMissingContextError(content[:-2])
|
641
|
-
|
642
|
-
|
643
|
-
|
649
|
+
|
650
|
+
use_chat_history = None
|
651
|
+
if NUCLIA_LEARNING_CHAT_HISTORY_HEADER in resp.headers:
|
652
|
+
use_chat_history = resp.headers[NUCLIA_LEARNING_CHAT_HISTORY_HEADER] == "true"
|
653
|
+
return RephraseResponse(rephrased_query=content, use_chat_history=use_chat_history)
|
@@ -488,29 +488,37 @@ async def ask(
|
|
488
488
|
if len(chat_history) > 0 or len(user_context) > 0:
|
489
489
|
try:
|
490
490
|
with metrics.time("rephrase"):
|
491
|
-
|
491
|
+
rephrase_response = await rephrase_query(
|
492
492
|
kbid,
|
493
493
|
chat_history=chat_history,
|
494
494
|
query=user_query,
|
495
495
|
user_id=user_id,
|
496
496
|
user_context=user_context,
|
497
497
|
generative_model=ask_request.generative_model,
|
498
|
+
chat_history_relevance_threshold=ask_request.chat_history_relevance_threshold,
|
498
499
|
)
|
500
|
+
rephrased_query = rephrase_response.rephrased_query
|
501
|
+
if rephrase_response.use_chat_history is False:
|
502
|
+
# Ignored if the question is not relevant enough with the chat history
|
503
|
+
logger.info("Chat history was ignored for this request")
|
504
|
+
chat_history = []
|
505
|
+
|
499
506
|
except RephraseMissingContextError:
|
500
507
|
logger.info("Failed to rephrase ask query, using original")
|
501
508
|
|
502
509
|
try:
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
510
|
+
with metrics.time("retrieval"):
|
511
|
+
retrieval_results = await retrieval_step(
|
512
|
+
kbid=kbid,
|
513
|
+
# Prefer the rephrased query for retrieval if available
|
514
|
+
main_query=rephrased_query or user_query,
|
515
|
+
ask_request=ask_request,
|
516
|
+
client_type=client_type,
|
517
|
+
user_id=user_id,
|
518
|
+
origin=origin,
|
519
|
+
metrics=metrics,
|
520
|
+
resource=resource,
|
521
|
+
)
|
514
522
|
except NoRetrievalResultsError as err:
|
515
523
|
maybe_audit_chat(
|
516
524
|
kbid=kbid,
|
@@ -555,6 +563,7 @@ async def ask(
|
|
555
563
|
image_strategies=ask_request.rag_images_strategies,
|
556
564
|
max_context_characters=tokens_to_chars(generation.max_context_tokens),
|
557
565
|
visual_llm=generation.use_visual_llm,
|
566
|
+
metrics=metrics.child_span("context_building"),
|
558
567
|
)
|
559
568
|
(
|
560
569
|
prompt_context,
|
@@ -740,45 +749,44 @@ async def retrieval_in_kb(
|
|
740
749
|
) -> RetrievalResults:
|
741
750
|
prequeries = parse_prequeries(ask_request)
|
742
751
|
graph_strategy = parse_graph_strategy(ask_request)
|
743
|
-
|
744
|
-
|
752
|
+
main_results, prequeries_results, parsed_query = await get_find_results(
|
753
|
+
kbid=kbid,
|
754
|
+
query=main_query,
|
755
|
+
item=ask_request,
|
756
|
+
ndb_client=client_type,
|
757
|
+
user=user_id,
|
758
|
+
origin=origin,
|
759
|
+
metrics=metrics.child_span("hybrid_retrieval"),
|
760
|
+
prequeries_strategy=prequeries,
|
761
|
+
)
|
762
|
+
|
763
|
+
if graph_strategy is not None:
|
764
|
+
assert parsed_query.retrieval.reranker is not None, (
|
765
|
+
"find parser must provide a reranking algorithm"
|
766
|
+
)
|
767
|
+
reranker = get_reranker(parsed_query.retrieval.reranker)
|
768
|
+
graph_results, graph_request = await get_graph_results(
|
745
769
|
kbid=kbid,
|
746
770
|
query=main_query,
|
747
771
|
item=ask_request,
|
748
772
|
ndb_client=client_type,
|
749
773
|
user=user_id,
|
750
774
|
origin=origin,
|
751
|
-
|
752
|
-
|
775
|
+
graph_strategy=graph_strategy,
|
776
|
+
metrics=metrics.child_span("graph_retrieval"),
|
777
|
+
text_block_reranker=reranker,
|
753
778
|
)
|
754
779
|
|
755
|
-
if
|
756
|
-
|
757
|
-
"find parser must provide a reranking algorithm"
|
758
|
-
)
|
759
|
-
reranker = get_reranker(parsed_query.retrieval.reranker)
|
760
|
-
graph_results, graph_request = await get_graph_results(
|
761
|
-
kbid=kbid,
|
762
|
-
query=main_query,
|
763
|
-
item=ask_request,
|
764
|
-
ndb_client=client_type,
|
765
|
-
user=user_id,
|
766
|
-
origin=origin,
|
767
|
-
graph_strategy=graph_strategy,
|
768
|
-
metrics=metrics,
|
769
|
-
text_block_reranker=reranker,
|
770
|
-
)
|
771
|
-
|
772
|
-
if prequeries_results is None:
|
773
|
-
prequeries_results = []
|
780
|
+
if prequeries_results is None:
|
781
|
+
prequeries_results = []
|
774
782
|
|
775
|
-
|
776
|
-
|
783
|
+
prequery = PreQuery(id="graph", request=graph_request, weight=graph_strategy.weight)
|
784
|
+
prequeries_results.append((prequery, graph_results))
|
777
785
|
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
786
|
+
if len(main_results.resources) == 0 and all(
|
787
|
+
len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
|
788
|
+
):
|
789
|
+
raise NoRetrievalResultsError(main_results, prequeries_results)
|
782
790
|
|
783
791
|
main_query_weight = prequeries.main_query_weight if prequeries is not None else 1.0
|
784
792
|
best_matches = compute_best_matches(
|
@@ -829,21 +837,20 @@ async def retrieval_in_resource(
|
|
829
837
|
)
|
830
838
|
add_resource_filter(prequery.request, [resource])
|
831
839
|
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
)
|
846
|
-
raise NoRetrievalResultsError(main_results, prequeries_results)
|
840
|
+
main_results, prequeries_results, parsed_query = await get_find_results(
|
841
|
+
kbid=kbid,
|
842
|
+
query=main_query,
|
843
|
+
item=ask_request,
|
844
|
+
ndb_client=client_type,
|
845
|
+
user=user_id,
|
846
|
+
origin=origin,
|
847
|
+
metrics=metrics.child_span("hybrid_retrieval"),
|
848
|
+
prequeries_strategy=prequeries,
|
849
|
+
)
|
850
|
+
if len(main_results.resources) == 0 and all(
|
851
|
+
len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
|
852
|
+
):
|
853
|
+
raise NoRetrievalResultsError(main_results, prequeries_results)
|
847
854
|
main_query_weight = prequeries.main_query_weight if prequeries is not None else 1.0
|
848
855
|
best_matches = compute_best_matches(
|
849
856
|
main_results=main_results,
|
@@ -41,6 +41,7 @@ from nucliadb.search.search.chat.images import (
|
|
41
41
|
get_paragraph_image,
|
42
42
|
)
|
43
43
|
from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
|
44
|
+
from nucliadb.search.search.metrics import Metrics
|
44
45
|
from nucliadb.search.search.paragraphs import get_paragraph_text
|
45
46
|
from nucliadb_models.labels import translate_alias_to_system_label
|
46
47
|
from nucliadb_models.metadata import Extra, Origin
|
@@ -244,6 +245,7 @@ async def full_resource_prompt_context(
|
|
244
245
|
ordered_paragraphs: list[FindParagraph],
|
245
246
|
resource: Optional[str],
|
246
247
|
strategy: FullResourceStrategy,
|
248
|
+
metrics: Metrics,
|
247
249
|
) -> None:
|
248
250
|
"""
|
249
251
|
Algorithm steps:
|
@@ -298,6 +300,8 @@ async def full_resource_prompt_context(
|
|
298
300
|
context[field.full()] = extracted_text
|
299
301
|
added_fields.add(field.full())
|
300
302
|
|
303
|
+
metrics.set("full_resource_ops", len(added_fields))
|
304
|
+
|
301
305
|
if strategy.include_remaining_text_blocks:
|
302
306
|
for paragraph in ordered_paragraphs:
|
303
307
|
pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
|
@@ -309,6 +313,7 @@ async def extend_prompt_context_with_metadata(
|
|
309
313
|
context: CappedPromptContext,
|
310
314
|
kbid: str,
|
311
315
|
strategy: MetadataExtensionStrategy,
|
316
|
+
metrics: Metrics,
|
312
317
|
) -> None:
|
313
318
|
text_block_ids: list[TextBlockId] = []
|
314
319
|
for text_block_id in context.text_block_ids():
|
@@ -321,18 +326,25 @@ async def extend_prompt_context_with_metadata(
|
|
321
326
|
if len(text_block_ids) == 0: # pragma: no cover
|
322
327
|
return
|
323
328
|
|
329
|
+
ops = 0
|
324
330
|
if MetadataExtensionType.ORIGIN in strategy.types:
|
331
|
+
ops += 1
|
325
332
|
await extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids)
|
326
333
|
|
327
334
|
if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
|
335
|
+
ops += 1
|
328
336
|
await extend_prompt_context_with_classification_labels(context, kbid, text_block_ids)
|
329
337
|
|
330
338
|
if MetadataExtensionType.NERS in strategy.types:
|
339
|
+
ops += 1
|
331
340
|
await extend_prompt_context_with_ner(context, kbid, text_block_ids)
|
332
341
|
|
333
342
|
if MetadataExtensionType.EXTRA_METADATA in strategy.types:
|
343
|
+
ops += 1
|
334
344
|
await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
|
335
345
|
|
346
|
+
metrics.set("metadata_extension_ops", ops * len(text_block_ids))
|
347
|
+
|
336
348
|
|
337
349
|
def parse_text_block_id(text_block_id: str) -> TextBlockId:
|
338
350
|
try:
|
@@ -464,6 +476,7 @@ async def field_extension_prompt_context(
|
|
464
476
|
kbid: str,
|
465
477
|
ordered_paragraphs: list[FindParagraph],
|
466
478
|
strategy: FieldExtensionStrategy,
|
479
|
+
metrics: Metrics,
|
467
480
|
) -> None:
|
468
481
|
"""
|
469
482
|
Algorithm steps:
|
@@ -493,6 +506,8 @@ async def field_extension_prompt_context(
|
|
493
506
|
tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
|
494
507
|
field_extracted_texts = await run_concurrently(tasks)
|
495
508
|
|
509
|
+
metrics.set("field_extension_ops", len(field_extracted_texts))
|
510
|
+
|
496
511
|
for result in field_extracted_texts:
|
497
512
|
if result is None: # pragma: no cover
|
498
513
|
continue
|
@@ -619,6 +634,7 @@ async def neighbouring_paragraphs_prompt_context(
|
|
619
634
|
kbid: str,
|
620
635
|
ordered_text_blocks: list[FindParagraph],
|
621
636
|
strategy: NeighbouringParagraphsStrategy,
|
637
|
+
metrics: Metrics,
|
622
638
|
) -> None:
|
623
639
|
"""
|
624
640
|
This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
|
@@ -658,6 +674,9 @@ async def neighbouring_paragraphs_prompt_context(
|
|
658
674
|
return
|
659
675
|
|
660
676
|
results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
|
677
|
+
|
678
|
+
metrics.set("neighbouring_paragraphs_ops", len(results))
|
679
|
+
|
661
680
|
# Add the paragraph texts to the context
|
662
681
|
for pid, text in results:
|
663
682
|
if text != "":
|
@@ -670,8 +689,10 @@ async def conversation_prompt_context(
|
|
670
689
|
ordered_paragraphs: list[FindParagraph],
|
671
690
|
conversational_strategy: ConversationalStrategy,
|
672
691
|
visual_llm: bool,
|
692
|
+
metrics: Metrics,
|
673
693
|
):
|
674
694
|
analyzed_fields: List[str] = []
|
695
|
+
ops = 0
|
675
696
|
async with get_driver().transaction(read_only=True) as txn:
|
676
697
|
storage = await get_storage()
|
677
698
|
kb = KnowledgeBoxORM(txn, storage, kbid)
|
@@ -701,6 +722,7 @@ async def conversation_prompt_context(
|
|
701
722
|
|
702
723
|
attachments: List[resources_pb2.FieldRef] = []
|
703
724
|
if conversational_strategy.full:
|
725
|
+
ops += 5
|
704
726
|
extracted_text = await field_obj.get_extracted_text()
|
705
727
|
for current_page in range(1, cmetadata.pages + 1):
|
706
728
|
conv = await field_obj.db_get_value(current_page)
|
@@ -749,6 +771,7 @@ async def conversation_prompt_context(
|
|
749
771
|
break
|
750
772
|
|
751
773
|
for message in messages:
|
774
|
+
ops += 1
|
752
775
|
text = message.content.text.strip()
|
753
776
|
pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
|
754
777
|
context[pid] = text
|
@@ -757,6 +780,7 @@ async def conversation_prompt_context(
|
|
757
780
|
if conversational_strategy.attachments_text:
|
758
781
|
# add on the context the images if vlm enabled
|
759
782
|
for attachment in attachments:
|
783
|
+
ops += 1
|
760
784
|
field: File = await resource.get_field(
|
761
785
|
attachment.field_id, attachment.field_type, load=True
|
762
786
|
) # type: ignore
|
@@ -767,6 +791,7 @@ async def conversation_prompt_context(
|
|
767
791
|
|
768
792
|
if conversational_strategy.attachments_images and visual_llm:
|
769
793
|
for attachment in attachments:
|
794
|
+
ops += 1
|
770
795
|
file_field: File = await resource.get_field(
|
771
796
|
attachment.field_id, attachment.field_type, load=True
|
772
797
|
) # type: ignore
|
@@ -776,6 +801,7 @@ async def conversation_prompt_context(
|
|
776
801
|
context.images[pid] = image
|
777
802
|
|
778
803
|
analyzed_fields.append(field_unique_id)
|
804
|
+
metrics.set("conversation_ops", ops)
|
779
805
|
|
780
806
|
|
781
807
|
async def hierarchy_prompt_context(
|
@@ -783,6 +809,7 @@ async def hierarchy_prompt_context(
|
|
783
809
|
kbid: str,
|
784
810
|
ordered_paragraphs: list[FindParagraph],
|
785
811
|
strategy: HierarchyResourceStrategy,
|
812
|
+
metrics: Metrics,
|
786
813
|
) -> None:
|
787
814
|
"""
|
788
815
|
This function will get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
|
@@ -842,6 +869,8 @@ async def hierarchy_prompt_context(
|
|
842
869
|
else:
|
843
870
|
resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
|
844
871
|
|
872
|
+
metrics.set("hierarchy_ops", len(resources))
|
873
|
+
|
845
874
|
# Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
|
846
875
|
# extended paragraph text of all the paragraphs in the resource.
|
847
876
|
for values in resources.values():
|
@@ -886,6 +915,7 @@ class PromptContextBuilder:
|
|
886
915
|
image_strategies: Optional[Sequence[ImageRagStrategy]] = None,
|
887
916
|
max_context_characters: Optional[int] = None,
|
888
917
|
visual_llm: bool = False,
|
918
|
+
metrics: Metrics = Metrics("prompt_context_builder"),
|
889
919
|
):
|
890
920
|
self.kbid = kbid
|
891
921
|
self.ordered_paragraphs = ordered_paragraphs
|
@@ -896,6 +926,7 @@ class PromptContextBuilder:
|
|
896
926
|
self.image_strategies = image_strategies
|
897
927
|
self.max_context_characters = max_context_characters
|
898
928
|
self.visual_llm = visual_llm
|
929
|
+
self.metrics = metrics
|
899
930
|
|
900
931
|
def prepend_user_context(self, context: CappedPromptContext):
|
901
932
|
# Chat extra context passed by the user is the most important, therefore
|
@@ -920,6 +951,7 @@ class PromptContextBuilder:
|
|
920
951
|
return context, context_order, context_images
|
921
952
|
|
922
953
|
async def _build_context_images(self, context: CappedPromptContext) -> None:
|
954
|
+
ops = 0
|
923
955
|
if self.image_strategies is None or len(self.image_strategies) == 0:
|
924
956
|
# Nothing to do
|
925
957
|
return
|
@@ -958,6 +990,7 @@ class PromptContextBuilder:
|
|
958
990
|
if page_image_id not in context.images:
|
959
991
|
image = await get_page_image(self.kbid, pid, paragraph_page_number)
|
960
992
|
if image is not None:
|
993
|
+
ops += 1
|
961
994
|
context.images[page_image_id] = image
|
962
995
|
page_images_added += 1
|
963
996
|
else:
|
@@ -977,6 +1010,7 @@ class PromptContextBuilder:
|
|
977
1010
|
):
|
978
1011
|
pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
|
979
1012
|
if pimage is not None:
|
1013
|
+
ops += 1
|
980
1014
|
context.images[paragraph.id] = pimage
|
981
1015
|
else:
|
982
1016
|
logger.warning(
|
@@ -987,6 +1021,7 @@ class PromptContextBuilder:
|
|
987
1021
|
"reference": paragraph.reference,
|
988
1022
|
},
|
989
1023
|
)
|
1024
|
+
self.metrics.set("image_ops", ops)
|
990
1025
|
|
991
1026
|
async def _build_context(self, context: CappedPromptContext) -> None:
|
992
1027
|
if self.strategies is None or len(self.strategies) == 0:
|
@@ -1038,17 +1073,17 @@ class PromptContextBuilder:
|
|
1038
1073
|
self.ordered_paragraphs,
|
1039
1074
|
self.resource,
|
1040
1075
|
full_resource,
|
1076
|
+
self.metrics,
|
1041
1077
|
)
|
1042
1078
|
if metadata_extension:
|
1043
|
-
await extend_prompt_context_with_metadata(
|
1079
|
+
await extend_prompt_context_with_metadata(
|
1080
|
+
context, self.kbid, metadata_extension, self.metrics
|
1081
|
+
)
|
1044
1082
|
return
|
1045
1083
|
|
1046
1084
|
if hierarchy:
|
1047
1085
|
await hierarchy_prompt_context(
|
1048
|
-
context,
|
1049
|
-
self.kbid,
|
1050
|
-
self.ordered_paragraphs,
|
1051
|
-
hierarchy,
|
1086
|
+
context, self.kbid, self.ordered_paragraphs, hierarchy, self.metrics
|
1052
1087
|
)
|
1053
1088
|
if neighbouring_paragraphs:
|
1054
1089
|
await neighbouring_paragraphs_prompt_context(
|
@@ -1056,6 +1091,7 @@ class PromptContextBuilder:
|
|
1056
1091
|
self.kbid,
|
1057
1092
|
self.ordered_paragraphs,
|
1058
1093
|
neighbouring_paragraphs,
|
1094
|
+
self.metrics,
|
1059
1095
|
)
|
1060
1096
|
if field_extension:
|
1061
1097
|
await field_extension_prompt_context(
|
@@ -1063,6 +1099,7 @@ class PromptContextBuilder:
|
|
1063
1099
|
self.kbid,
|
1064
1100
|
self.ordered_paragraphs,
|
1065
1101
|
field_extension,
|
1102
|
+
self.metrics,
|
1066
1103
|
)
|
1067
1104
|
if conversational_strategy:
|
1068
1105
|
await conversation_prompt_context(
|
@@ -1071,9 +1108,12 @@ class PromptContextBuilder:
|
|
1071
1108
|
self.ordered_paragraphs,
|
1072
1109
|
conversational_strategy,
|
1073
1110
|
self.visual_llm,
|
1111
|
+
self.metrics,
|
1074
1112
|
)
|
1075
1113
|
if metadata_extension:
|
1076
|
-
await extend_prompt_context_with_metadata(
|
1114
|
+
await extend_prompt_context_with_metadata(
|
1115
|
+
context, self.kbid, metadata_extension, self.metrics
|
1116
|
+
)
|
1077
1117
|
|
1078
1118
|
|
1079
1119
|
def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
|
@@ -27,7 +27,7 @@ from nidx_protos.nodereader_pb2 import (
|
|
27
27
|
|
28
28
|
from nucliadb.common.models_utils import to_proto
|
29
29
|
from nucliadb.search import logger
|
30
|
-
from nucliadb.search.predict import AnswerStatusCode
|
30
|
+
from nucliadb.search.predict import AnswerStatusCode, RephraseResponse
|
31
31
|
from nucliadb.search.requesters.utils import Method, node_query
|
32
32
|
from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
|
33
33
|
from nucliadb.search.search.exceptions import IncompleteFindResultsError
|
@@ -71,7 +71,8 @@ async def rephrase_query(
|
|
71
71
|
user_id: str,
|
72
72
|
user_context: list[str],
|
73
73
|
generative_model: Optional[str] = None,
|
74
|
-
|
74
|
+
chat_history_relevance_threshold: Optional[float] = None,
|
75
|
+
) -> RephraseResponse:
|
75
76
|
predict = get_predict()
|
76
77
|
req = RephraseModel(
|
77
78
|
question=query,
|
@@ -79,6 +80,7 @@ async def rephrase_query(
|
|
79
80
|
user_id=user_id,
|
80
81
|
user_context=user_context,
|
81
82
|
generative_model=generative_model,
|
83
|
+
chat_history_relevance_threshold=chat_history_relevance_threshold,
|
82
84
|
)
|
83
85
|
return await predict.rephrase_query(kbid, req)
|
84
86
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nucliadb
|
3
|
-
Version: 6.4.0.post4265
|
3
|
+
Version: 6.4.0.post4276
|
4
4
|
Summary: NucliaDB
|
5
5
|
Author-email: Nuclia <nucliadb@nuclia.com>
|
6
6
|
License: AGPL
|
@@ -20,11 +20,11 @@ Classifier: Programming Language :: Python :: 3.12
|
|
20
20
|
Classifier: Programming Language :: Python :: 3 :: Only
|
21
21
|
Requires-Python: <4,>=3.9
|
22
22
|
Description-Content-Type: text/markdown
|
23
|
-
Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4265
|
24
|
-
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4265
|
25
|
-
Requires-Dist: nucliadb-protos>=6.4.0.post4265
|
26
|
-
Requires-Dist: nucliadb-models>=6.4.0.post4265
|
27
|
-
Requires-Dist: nidx-protos>=6.4.0.post4265
|
23
|
+
Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4276
|
24
|
+
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4276
|
25
|
+
Requires-Dist: nucliadb-protos>=6.4.0.post4276
|
26
|
+
Requires-Dist: nucliadb-models>=6.4.0.post4276
|
27
|
+
Requires-Dist: nidx-protos>=6.4.0.post4276
|
28
28
|
Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
|
29
29
|
Requires-Dist: nuclia-models>=0.24.2
|
30
30
|
Requires-Dist: uvicorn[standard]
|
@@ -204,7 +204,7 @@ nucliadb/search/__init__.py,sha256=tnypbqcH4nBHbGpkINudhKgdLKpwXQCvDtPchUlsyY4,1
|
|
204
204
|
nucliadb/search/app.py,sha256=-WEX1AZRA8R_9aeOo9ovOTwjXW_7VfwWN7N2ccSoqXg,3387
|
205
205
|
nucliadb/search/lifecycle.py,sha256=hiylV-lxsAWkqTCulXBg0EIfMQdejSr8Zar0L_GLFT8,2218
|
206
206
|
nucliadb/search/openapi.py,sha256=t3Wo_4baTrfPftg2BHsyLWNZ1MYn7ZRdW7ht-wFOgRs,1016
|
207
|
-
nucliadb/search/predict.py,sha256=
|
207
|
+
nucliadb/search/predict.py,sha256=__0qwIU2CIRYRTYsbG9zZEjXXrxNe8puZWYJIyOT6dg,23492
|
208
208
|
nucliadb/search/predict_models.py,sha256=ZAe0dneUsPmV9uBar57cCFADCGOrYDsJHuqKlA5zWag,5937
|
209
209
|
nucliadb/search/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
210
210
|
nucliadb/search/run.py,sha256=aFb-CXRi_C8YMpP_ivNj8KW1BYhADj88y8K9Lr_nUPI,1402
|
@@ -255,11 +255,11 @@ nucliadb/search/search/shards.py,sha256=mc5DK-MoCv9AFhlXlOFHbPvetcyNDzTFOJ5rimK8
|
|
255
255
|
nucliadb/search/search/summarize.py,sha256=ksmYPubEQvAQgfPdZHfzB_rR19B2ci4IYZ6jLdHxZo8,4996
|
256
256
|
nucliadb/search/search/utils.py,sha256=ajRIXfdTF67dBVahQCXW-rSv6gJpUMPt3QhJrWqArTQ,2175
|
257
257
|
nucliadb/search/search/chat/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
258
|
-
nucliadb/search/search/chat/ask.py,sha256=
|
258
|
+
nucliadb/search/search/chat/ask.py,sha256=aaNj0MeAbx9dyeKpQJdm3VsHMq9OmcCESxahbgSxvCk,37805
|
259
259
|
nucliadb/search/search/chat/exceptions.py,sha256=Siy4GXW2L7oPhIR86H3WHBhE9lkV4A4YaAszuGGUf54,1356
|
260
260
|
nucliadb/search/search/chat/images.py,sha256=PA8VWxT5_HUGfW1ULhKTK46UBsVyINtWWqEM1ulzX1E,3095
|
261
|
-
nucliadb/search/search/chat/prompt.py,sha256=
|
262
|
-
nucliadb/search/search/chat/query.py,sha256=
|
261
|
+
nucliadb/search/search/chat/prompt.py,sha256=e8C7_MPr6Cn3nJHA4hWpeW3629KVI1ZUQA_wZf9Kiu4,48503
|
262
|
+
nucliadb/search/search/chat/query.py,sha256=6v6twBUTWfUUzklVV6xqJSYPkAshnIrBH9wbTcjQvkI,17063
|
263
263
|
nucliadb/search/search/query_parser/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
264
264
|
nucliadb/search/search/query_parser/exceptions.py,sha256=szAOXUZ27oNY-OSa9t2hQ5HHkQQC0EX1FZz_LluJHJE,1224
|
265
265
|
nucliadb/search/search/query_parser/fetcher.py,sha256=SkvBRDfSKmuz-QygNKLAU4AhZhhDo1dnOZmt1zA28RA,16851
|
@@ -368,8 +368,8 @@ nucliadb/writer/tus/local.py,sha256=7jYa_w9b-N90jWgN2sQKkNcomqn6JMVBOVeDOVYJHto,
|
|
368
368
|
nucliadb/writer/tus/s3.py,sha256=vF0NkFTXiXhXq3bCVXXVV-ED38ECVoUeeYViP8uMqcU,8357
|
369
369
|
nucliadb/writer/tus/storage.py,sha256=ToqwjoYnjI4oIcwzkhha_MPxi-k4Jk3Lt55zRwaC1SM,2903
|
370
370
|
nucliadb/writer/tus/utils.py,sha256=MSdVbRsRSZVdkaum69_0wku7X3p5wlZf4nr6E0GMKbw,2556
|
371
|
-
nucliadb-6.4.0.
|
372
|
-
nucliadb-6.4.0.
|
373
|
-
nucliadb-6.4.0.
|
374
|
-
nucliadb-6.4.0.
|
375
|
-
nucliadb-6.4.0.
|
371
|
+
nucliadb-6.4.0.post4276.dist-info/METADATA,sha256=tOuAZanwYJGX1BhhJIuqi8rnoevYyz3fey7vfyQ_sJ4,4223
|
372
|
+
nucliadb-6.4.0.post4276.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
|
373
|
+
nucliadb-6.4.0.post4276.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
|
374
|
+
nucliadb-6.4.0.post4276.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
|
375
|
+
nucliadb-6.4.0.post4276.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|