nucliadb 6.4.0.post4265__py3-none-any.whl → 6.4.0.post4271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/search/predict.py +19 -9
- nucliadb/search/search/chat/ask.py +8 -1
- nucliadb/search/search/chat/query.py +4 -2
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4271.dist-info}/METADATA +6 -6
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4271.dist-info}/RECORD +8 -8
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4271.dist-info}/WHEEL +0 -0
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4271.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.4.0.post4265.dist-info → nucliadb-6.4.0.post4271.dist-info}/top_level.txt +0 -0
nucliadb/search/predict.py
CHANGED
@@ -22,6 +22,7 @@ import json
|
|
22
22
|
import logging
|
23
23
|
import os
|
24
24
|
import random
|
25
|
+
from dataclasses import dataclass
|
25
26
|
from enum import Enum
|
26
27
|
from typing import Any, AsyncGenerator, Optional
|
27
28
|
from unittest.mock import AsyncMock, Mock
|
@@ -108,7 +109,7 @@ RERANK = "/rerank"
|
|
108
109
|
|
109
110
|
NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"
|
110
111
|
NUCLIA_LEARNING_MODEL_HEADER = "NUCLIA-LEARNING-MODEL"
|
111
|
-
|
112
|
+
NUCLIA_LEARNING_CHAT_HISTORY_HEADER = "NUCLIA-LEARNING-CHAT-HISTORY"
|
112
113
|
|
113
114
|
predict_observer = metrics.Observer(
|
114
115
|
"predict_engine",
|
@@ -139,6 +140,12 @@ class AnswerStatusCode(str, Enum):
|
|
139
140
|
}[self]
|
140
141
|
|
141
142
|
|
143
|
+
@dataclass
|
144
|
+
class RephraseResponse:
|
145
|
+
rephrased_query: str
|
146
|
+
use_chat_history: Optional[bool]
|
147
|
+
|
148
|
+
|
142
149
|
async def start_predict_engine():
|
143
150
|
if nuclia_settings.dummy_predict:
|
144
151
|
predict_util = DummyPredictEngine()
|
@@ -267,7 +274,7 @@ class PredictEngine:
|
|
267
274
|
return await func(**request_args)
|
268
275
|
|
269
276
|
@predict_observer.wrap({"type": "rephrase"})
|
270
|
-
async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
|
277
|
+
async def rephrase_query(self, kbid: str, item: RephraseModel) -> RephraseResponse:
|
271
278
|
try:
|
272
279
|
self.check_nua_key_is_configured_for_onprem()
|
273
280
|
except NUAKeyMissingError:
|
@@ -477,9 +484,9 @@ class DummyPredictEngine(PredictEngine):
|
|
477
484
|
response.headers = {NUCLIA_LEARNING_ID_HEADER: DUMMY_LEARNING_ID}
|
478
485
|
return response
|
479
486
|
|
480
|
-
async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
|
487
|
+
async def rephrase_query(self, kbid: str, item: RephraseModel) -> RephraseResponse:
|
481
488
|
self.calls.append(("rephrase_query", item))
|
482
|
-
return DUMMY_REPHRASE_QUERY
|
489
|
+
return RephraseResponse(rephrased_query=DUMMY_REPHRASE_QUERY, use_chat_history=None)
|
483
490
|
|
484
491
|
async def chat_query_ndjson(
|
485
492
|
self, kbid: str, item: ChatModel
|
@@ -624,7 +631,7 @@ def get_chat_ndjson_generator(
|
|
624
631
|
|
625
632
|
async def _parse_rephrase_response(
|
626
633
|
resp: aiohttp.ClientResponse,
|
627
|
-
) -> str:
|
634
|
+
) -> RephraseResponse:
|
628
635
|
"""
|
629
636
|
Predict api is returning a json payload that is a string with the following format:
|
630
637
|
<rephrased_query><status_code>
|
@@ -632,12 +639,15 @@ async def _parse_rephrase_response(
|
|
632
639
|
it will raise an exception if the status code is not 0
|
633
640
|
"""
|
634
641
|
content = await resp.json()
|
642
|
+
|
635
643
|
if content.endswith("0"):
|
636
|
-
|
644
|
+
content = content[:-1]
|
637
645
|
elif content.endswith("-1"):
|
638
646
|
raise RephraseError(content[:-2])
|
639
647
|
elif content.endswith("-2"):
|
640
648
|
raise RephraseMissingContextError(content[:-2])
|
641
|
-
|
642
|
-
|
643
|
-
|
649
|
+
|
650
|
+
use_chat_history = None
|
651
|
+
if NUCLIA_LEARNING_CHAT_HISTORY_HEADER in resp.headers:
|
652
|
+
use_chat_history = resp.headers[NUCLIA_LEARNING_CHAT_HISTORY_HEADER] == "true"
|
653
|
+
return RephraseResponse(rephrased_query=content, use_chat_history=use_chat_history)
|
@@ -488,14 +488,21 @@ async def ask(
|
|
488
488
|
if len(chat_history) > 0 or len(user_context) > 0:
|
489
489
|
try:
|
490
490
|
with metrics.time("rephrase"):
|
491
|
rephrased_query = await rephrase_query(
|
491
|
+
rephrase_response = await rephrase_query(
|
492
492
|
kbid,
|
493
493
|
chat_history=chat_history,
|
494
494
|
query=user_query,
|
495
495
|
user_id=user_id,
|
496
496
|
user_context=user_context,
|
497
497
|
generative_model=ask_request.generative_model,
|
498
|
+
chat_history_relevance_threshold=ask_request.chat_history_relevance_threshold,
|
498
499
|
)
|
500
|
+
rephrased_query = rephrase_response.rephrased_query
|
501
|
+
if rephrase_response.use_chat_history is False:
|
502
|
+
# Ignored if the question is not relevant enough with the chat history
|
503
|
+
logger.info("Chat history was ignored for this request")
|
504
|
+
chat_history = []
|
505
|
+
|
499
506
|
except RephraseMissingContextError:
|
500
507
|
logger.info("Failed to rephrase ask query, using original")
|
501
508
|
|
@@ -27,7 +27,7 @@ from nidx_protos.nodereader_pb2 import (
|
|
27
27
|
|
28
28
|
from nucliadb.common.models_utils import to_proto
|
29
29
|
from nucliadb.search import logger
|
30
|
-
from nucliadb.search.predict import AnswerStatusCode
|
30
|
+
from nucliadb.search.predict import AnswerStatusCode, RephraseResponse
|
31
31
|
from nucliadb.search.requesters.utils import Method, node_query
|
32
32
|
from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
|
33
33
|
from nucliadb.search.search.exceptions import IncompleteFindResultsError
|
@@ -71,7 +71,8 @@ async def rephrase_query(
|
|
71
71
|
user_id: str,
|
72
72
|
user_context: list[str],
|
73
73
|
generative_model: Optional[str] = None,
|
74
|
-
|
74
|
+
chat_history_relevance_threshold: Optional[float] = None,
|
75
|
+
) -> RephraseResponse:
|
75
76
|
predict = get_predict()
|
76
77
|
req = RephraseModel(
|
77
78
|
question=query,
|
@@ -79,6 +80,7 @@ async def rephrase_query(
|
|
79
80
|
user_id=user_id,
|
80
81
|
user_context=user_context,
|
81
82
|
generative_model=generative_model,
|
83
|
+
chat_history_relevance_threshold=chat_history_relevance_threshold,
|
82
84
|
)
|
83
85
|
return await predict.rephrase_query(kbid, req)
|
84
86
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nucliadb
|
3
|
-
Version: 6.4.0.post4265
|
3
|
+
Version: 6.4.0.post4271
|
4
4
|
Summary: NucliaDB
|
5
5
|
Author-email: Nuclia <nucliadb@nuclia.com>
|
6
6
|
License: AGPL
|
@@ -20,11 +20,11 @@ Classifier: Programming Language :: Python :: 3.12
|
|
20
20
|
Classifier: Programming Language :: Python :: 3 :: Only
|
21
21
|
Requires-Python: <4,>=3.9
|
22
22
|
Description-Content-Type: text/markdown
|
23
|
-
Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4265
|
24
|
-
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4265
|
25
|
-
Requires-Dist: nucliadb-protos>=6.4.0.post4265
|
26
|
-
Requires-Dist: nucliadb-models>=6.4.0.post4265
|
27
|
-
Requires-Dist: nidx-protos>=6.4.0.post4265
|
23
|
+
Requires-Dist: nucliadb-telemetry[all]>=6.4.0.post4271
|
24
|
+
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.4.0.post4271
|
25
|
+
Requires-Dist: nucliadb-protos>=6.4.0.post4271
|
26
|
+
Requires-Dist: nucliadb-models>=6.4.0.post4271
|
27
|
+
Requires-Dist: nidx-protos>=6.4.0.post4271
|
28
28
|
Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
|
29
29
|
Requires-Dist: nuclia-models>=0.24.2
|
30
30
|
Requires-Dist: uvicorn[standard]
|
@@ -204,7 +204,7 @@ nucliadb/search/__init__.py,sha256=tnypbqcH4nBHbGpkINudhKgdLKpwXQCvDtPchUlsyY4,1
|
|
204
204
|
nucliadb/search/app.py,sha256=-WEX1AZRA8R_9aeOo9ovOTwjXW_7VfwWN7N2ccSoqXg,3387
|
205
205
|
nucliadb/search/lifecycle.py,sha256=hiylV-lxsAWkqTCulXBg0EIfMQdejSr8Zar0L_GLFT8,2218
|
206
206
|
nucliadb/search/openapi.py,sha256=t3Wo_4baTrfPftg2BHsyLWNZ1MYn7ZRdW7ht-wFOgRs,1016
|
207
|
-
nucliadb/search/predict.py,sha256=
|
207
|
+
nucliadb/search/predict.py,sha256=__0qwIU2CIRYRTYsbG9zZEjXXrxNe8puZWYJIyOT6dg,23492
|
208
208
|
nucliadb/search/predict_models.py,sha256=ZAe0dneUsPmV9uBar57cCFADCGOrYDsJHuqKlA5zWag,5937
|
209
209
|
nucliadb/search/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
210
210
|
nucliadb/search/run.py,sha256=aFb-CXRi_C8YMpP_ivNj8KW1BYhADj88y8K9Lr_nUPI,1402
|
@@ -255,11 +255,11 @@ nucliadb/search/search/shards.py,sha256=mc5DK-MoCv9AFhlXlOFHbPvetcyNDzTFOJ5rimK8
|
|
255
255
|
nucliadb/search/search/summarize.py,sha256=ksmYPubEQvAQgfPdZHfzB_rR19B2ci4IYZ6jLdHxZo8,4996
|
256
256
|
nucliadb/search/search/utils.py,sha256=ajRIXfdTF67dBVahQCXW-rSv6gJpUMPt3QhJrWqArTQ,2175
|
257
257
|
nucliadb/search/search/chat/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
258
|
-
nucliadb/search/search/chat/ask.py,sha256=
|
258
|
+
nucliadb/search/search/chat/ask.py,sha256=4cpeWC7Q4NI30vrOeq22N0Vw0PQS0Ko9b47R5ISlgUw,37833
|
259
259
|
nucliadb/search/search/chat/exceptions.py,sha256=Siy4GXW2L7oPhIR86H3WHBhE9lkV4A4YaAszuGGUf54,1356
|
260
260
|
nucliadb/search/search/chat/images.py,sha256=PA8VWxT5_HUGfW1ULhKTK46UBsVyINtWWqEM1ulzX1E,3095
|
261
261
|
nucliadb/search/search/chat/prompt.py,sha256=Jnja-Ss7skgnnDY8BymVfdeYsFPnIQFL8tEvcRXTKUE,47356
|
262
|
-
nucliadb/search/search/chat/query.py,sha256=
|
262
|
+
nucliadb/search/search/chat/query.py,sha256=6v6twBUTWfUUzklVV6xqJSYPkAshnIrBH9wbTcjQvkI,17063
|
263
263
|
nucliadb/search/search/query_parser/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
264
264
|
nucliadb/search/search/query_parser/exceptions.py,sha256=szAOXUZ27oNY-OSa9t2hQ5HHkQQC0EX1FZz_LluJHJE,1224
|
265
265
|
nucliadb/search/search/query_parser/fetcher.py,sha256=SkvBRDfSKmuz-QygNKLAU4AhZhhDo1dnOZmt1zA28RA,16851
|
@@ -368,8 +368,8 @@ nucliadb/writer/tus/local.py,sha256=7jYa_w9b-N90jWgN2sQKkNcomqn6JMVBOVeDOVYJHto,
|
|
368
368
|
nucliadb/writer/tus/s3.py,sha256=vF0NkFTXiXhXq3bCVXXVV-ED38ECVoUeeYViP8uMqcU,8357
|
369
369
|
nucliadb/writer/tus/storage.py,sha256=ToqwjoYnjI4oIcwzkhha_MPxi-k4Jk3Lt55zRwaC1SM,2903
|
370
370
|
nucliadb/writer/tus/utils.py,sha256=MSdVbRsRSZVdkaum69_0wku7X3p5wlZf4nr6E0GMKbw,2556
|
371
|
-
nucliadb-6.4.0.
|
372
|
-
nucliadb-6.4.0.
|
373
|
-
nucliadb-6.4.0.
|
374
|
-
nucliadb-6.4.0.
|
375
|
-
nucliadb-6.4.0.
|
371
|
+
nucliadb-6.4.0.post4271.dist-info/METADATA,sha256=DKjN0N70XgUrzoOshpgsNU9xMWIweEdjtQgziYbveU0,4223
|
372
|
+
nucliadb-6.4.0.post4271.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
|
373
|
+
nucliadb-6.4.0.post4271.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
|
374
|
+
nucliadb-6.4.0.post4271.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
|
375
|
+
nucliadb-6.4.0.post4271.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|