nucliadb 6.5.0.post4413__py3-none-any.whl → 6.5.0.post4415__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/search/api/v1/predict_proxy.py +18 -3
- nucliadb/search/search/predict_proxy.py +158 -4
- {nucliadb-6.5.0.post4413.dist-info → nucliadb-6.5.0.post4415.dist-info}/METADATA +6 -6
- {nucliadb-6.5.0.post4413.dist-info → nucliadb-6.5.0.post4415.dist-info}/RECORD +7 -7
- {nucliadb-6.5.0.post4413.dist-info → nucliadb-6.5.0.post4415.dist-info}/WHEEL +0 -0
- {nucliadb-6.5.0.post4413.dist-info → nucliadb-6.5.0.post4415.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.5.0.post4413.dist-info → nucliadb-6.5.0.post4415.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@
|
|
20
20
|
import json
|
21
21
|
from typing import Union
|
22
22
|
|
23
|
-
from fastapi import Request
|
23
|
+
from fastapi import Header, Request
|
24
24
|
from fastapi.responses import Response, StreamingResponse
|
25
25
|
from fastapi_versioning import version
|
26
26
|
|
@@ -30,10 +30,17 @@ from nucliadb.search import predict
|
|
30
30
|
from nucliadb.search.api.v1.router import KB_PREFIX, api
|
31
31
|
from nucliadb.search.search.predict_proxy import PredictProxiedEndpoints, predict_proxy
|
32
32
|
from nucliadb_models.resource import NucliaDBRoles
|
33
|
+
from nucliadb_models.search import NucliaDBClientType
|
33
34
|
from nucliadb_utils.authentication import requires
|
34
35
|
from nucliadb_utils.exceptions import LimitsExceededError
|
35
36
|
|
36
|
-
DESCRIPTION =
|
37
|
+
DESCRIPTION = (
|
38
|
+
"Convenience endpoint that proxies requests to the Predict API."
|
39
|
+
" It adds the Knowledge Box configuration settings as headers to"
|
40
|
+
" the predict API request. Refer to the Predict API documentation"
|
41
|
+
" for more details about the request and response models:"
|
42
|
+
" https://docs.nuclia.dev/docs/nua-api#tag/Predict"
|
43
|
+
)
|
37
44
|
|
38
45
|
|
39
46
|
@api.get(
|
@@ -58,20 +65,28 @@ async def predict_proxy_endpoint(
|
|
58
65
|
request: Request,
|
59
66
|
kbid: str,
|
60
67
|
endpoint: PredictProxiedEndpoints,
|
68
|
+
x_nucliadb_user: str = Header(""),
|
69
|
+
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
|
70
|
+
x_forwarded_for: str = Header(""),
|
61
71
|
) -> Union[Response, StreamingResponse, HTTPClientError]:
|
62
72
|
try:
|
63
73
|
payload = await request.json()
|
64
74
|
except json.JSONDecodeError:
|
65
75
|
payload = None
|
66
76
|
try:
|
67
|
-
|
77
|
+
response = await predict_proxy(
|
68
78
|
kbid,
|
69
79
|
endpoint,
|
70
80
|
request.method,
|
71
81
|
params=request.query_params,
|
72
82
|
json=payload,
|
73
83
|
headers=dict(request.headers),
|
84
|
+
user_id=x_nucliadb_user,
|
85
|
+
client_type=x_ndb_client,
|
86
|
+
origin=x_forwarded_for,
|
74
87
|
)
|
88
|
+
|
89
|
+
return response
|
75
90
|
except KnowledgeBoxNotFound:
|
76
91
|
return HTTPClientError(status_code=404, detail="Knowledge box not found")
|
77
92
|
except LimitsExceededError as exc:
|
@@ -17,15 +17,34 @@
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
|
+
import json
|
20
21
|
from enum import Enum
|
21
22
|
from typing import Any, Optional, Union
|
22
23
|
|
24
|
+
import aiohttp
|
23
25
|
from fastapi.datastructures import QueryParams
|
24
26
|
from fastapi.responses import Response, StreamingResponse
|
27
|
+
from multidict import CIMultiDictProxy
|
28
|
+
from nuclia_models.predict.generative_responses import (
|
29
|
+
GenerativeChunk,
|
30
|
+
JSONGenerativeResponse,
|
31
|
+
StatusGenerativeResponse,
|
32
|
+
TextGenerativeResponse,
|
33
|
+
)
|
34
|
+
from pydantic import ValidationError
|
25
35
|
|
26
36
|
from nucliadb.common import datamanagers
|
27
|
-
from nucliadb.search
|
37
|
+
from nucliadb.search import logger
|
38
|
+
from nucliadb.search.predict import (
|
39
|
+
NUCLIA_LEARNING_ID_HEADER,
|
40
|
+
NUCLIA_LEARNING_MODEL_HEADER,
|
41
|
+
AnswerStatusCode,
|
42
|
+
PredictEngine,
|
43
|
+
)
|
44
|
+
from nucliadb.search.search.chat.query import maybe_audit_chat
|
45
|
+
from nucliadb.search.search.metrics import AskMetrics
|
28
46
|
from nucliadb.search.utilities import get_predict
|
47
|
+
from nucliadb_models.search import NucliaDBClientType
|
29
48
|
|
30
49
|
|
31
50
|
class PredictProxiedEndpoints(str, Enum):
|
@@ -46,12 +65,17 @@ ALLOWED_HEADERS = [
|
|
46
65
|
"Accept", # To allow 'application/x-ndjson' on the /chat endpoint
|
47
66
|
]
|
48
67
|
|
68
|
+
PREDICT_ANSWER_METRIC = "predict_answer_proxy_metric"
|
69
|
+
|
49
70
|
|
50
71
|
async def predict_proxy(
|
51
72
|
kbid: str,
|
52
73
|
endpoint: PredictProxiedEndpoints,
|
53
74
|
method: str,
|
54
75
|
params: QueryParams,
|
76
|
+
user_id: str,
|
77
|
+
client_type: NucliaDBClientType,
|
78
|
+
origin: str,
|
55
79
|
json: Optional[Any] = None,
|
56
80
|
headers: dict[str, str] = {},
|
57
81
|
) -> Union[Response, StreamingResponse]:
|
@@ -71,22 +95,61 @@ async def predict_proxy(
|
|
71
95
|
headers={**user_headers, **predict_headers},
|
72
96
|
)
|
73
97
|
|
74
|
-
# Proxy the response back to the client
|
75
98
|
status_code = predict_response.status
|
76
99
|
media_type = predict_response.headers.get("Content-Type")
|
77
100
|
response: Union[Response, StreamingResponse]
|
101
|
+
user_query = json.get("question") if json is not None else ""
|
78
102
|
if predict_response.headers.get("Transfer-Encoding") == "chunked":
|
103
|
+
if endpoint == PredictProxiedEndpoints.CHAT:
|
104
|
+
streaming_generator = chat_streaming_generator(
|
105
|
+
predict_response=predict_response,
|
106
|
+
kbid=kbid,
|
107
|
+
user_id=user_id,
|
108
|
+
client_type=client_type,
|
109
|
+
origin=origin,
|
110
|
+
user_query=user_query,
|
111
|
+
is_json="json" in (media_type or ""),
|
112
|
+
)
|
113
|
+
else:
|
114
|
+
streaming_generator = predict_response.content.iter_any()
|
115
|
+
|
79
116
|
response = StreamingResponse(
|
80
|
-
content=
|
117
|
+
content=streaming_generator,
|
81
118
|
status_code=status_code,
|
82
119
|
media_type=media_type,
|
83
120
|
)
|
84
121
|
else:
|
122
|
+
metrics = AskMetrics()
|
123
|
+
with metrics.time(PREDICT_ANSWER_METRIC):
|
124
|
+
content = await predict_response.read()
|
125
|
+
|
126
|
+
if endpoint == PredictProxiedEndpoints.CHAT:
|
127
|
+
try:
|
128
|
+
llm_status_code = int(content[-1:].decode()) # Decode just the last char
|
129
|
+
if llm_status_code != 0:
|
130
|
+
llm_status_code = -llm_status_code
|
131
|
+
except ValueError:
|
132
|
+
llm_status_code = -1
|
133
|
+
|
134
|
+
audit_predict_proxy_endpoint(
|
135
|
+
predict_response.headers,
|
136
|
+
kbid=kbid,
|
137
|
+
user_id=user_id,
|
138
|
+
user_query=user_query,
|
139
|
+
client_type=client_type,
|
140
|
+
origin=origin,
|
141
|
+
text_answer=content,
|
142
|
+
generative_answer_time=metrics[PREDICT_ANSWER_METRIC],
|
143
|
+
generative_answer_first_chunk_time=None,
|
144
|
+
status_code=AnswerStatusCode(str(llm_status_code)),
|
145
|
+
)
|
146
|
+
|
85
147
|
response = Response(
|
86
|
-
content=
|
148
|
+
content=content,
|
87
149
|
status_code=status_code,
|
88
150
|
media_type=media_type,
|
89
151
|
)
|
152
|
+
|
90
153
|
nuclia_learning_id = predict_response.headers.get("NUCLIA-LEARNING-ID")
|
91
154
|
if nuclia_learning_id:
|
92
155
|
response.headers["NUCLIA-LEARNING-ID"] = nuclia_learning_id
|
@@ -97,3 +160,94 @@ async def predict_proxy(
|
|
97
160
|
async def exists_kb(kbid: str) -> bool:
|
98
161
|
async with datamanagers.with_ro_transaction() as txn:
|
99
162
|
return await datamanagers.kb.exists_kb(txn, kbid=kbid)
|
163
|
+
|
164
|
+
|
165
|
+
async def chat_streaming_generator(
|
166
|
+
predict_response: aiohttp.ClientResponse,
|
167
|
+
kbid: str,
|
168
|
+
user_id: str,
|
169
|
+
client_type: NucliaDBClientType,
|
170
|
+
origin: str,
|
171
|
+
user_query: str,
|
172
|
+
is_json: bool,
|
173
|
+
):
|
174
|
+
stream = predict_response.content.iter_any()
|
175
|
+
first = True
|
176
|
+
status_code = AnswerStatusCode.ERROR.value
|
177
|
+
text_answer = ""
|
178
|
+
json_object = None
|
179
|
+
metrics = AskMetrics()
|
180
|
+
with metrics.time(PREDICT_ANSWER_METRIC):
|
181
|
+
async for chunk in stream:
|
182
|
+
if first:
|
183
|
+
metrics.record_first_chunk_yielded()
|
184
|
+
first = False
|
185
|
+
|
186
|
+
yield chunk
|
187
|
+
|
188
|
+
if is_json:
|
189
|
+
try:
|
190
|
+
parsed_chunk = GenerativeChunk.model_validate(chunk)
|
191
|
+
if isinstance(parsed_chunk, TextGenerativeResponse):
|
192
|
+
text_answer += parsed_chunk.text
|
193
|
+
elif isinstance(parsed_chunk, JSONGenerativeResponse):
|
194
|
+
json_object = parsed_chunk.object
|
195
|
+
elif isinstance(parsed_chunk, StatusGenerativeResponse):
|
196
|
+
status_code = parsed_chunk.code
|
197
|
+
except ValidationError:
|
198
|
+
logger.warning(
|
199
|
+
f"Unexpected item in predict answer stream: {chunk.decode()}",
|
200
|
+
extra={"kbid": kbid},
|
201
|
+
)
|
202
|
+
else:
|
203
|
+
text_answer += chunk.decode()
|
204
|
+
|
205
|
+
if is_json is False and chunk: # Ensure chunk is not empty before decoding
|
206
|
+
# If response is text the status_code comes at the last chunk of data
|
207
|
+
status_code = chunk.decode()
|
208
|
+
|
209
|
+
audit_predict_proxy_endpoint(
|
210
|
+
headers=predict_response.headers,
|
211
|
+
kbid=kbid,
|
212
|
+
user_id=user_id,
|
213
|
+
user_query=user_query,
|
214
|
+
client_type=client_type,
|
215
|
+
origin=origin,
|
216
|
+
text_answer=text_answer.encode() if json_object is None else json.dumps(json_object).encode(),
|
217
|
+
generative_answer_time=metrics[PREDICT_ANSWER_METRIC],
|
218
|
+
generative_answer_first_chunk_time=metrics.get_first_chunk_time(),
|
219
|
+
status_code=AnswerStatusCode(status_code),
|
220
|
+
)
|
221
|
+
|
222
|
+
|
223
|
+
def audit_predict_proxy_endpoint(
|
224
|
+
headers: CIMultiDictProxy,
|
225
|
+
kbid: str,
|
226
|
+
user_id: str,
|
227
|
+
user_query: str,
|
228
|
+
client_type: NucliaDBClientType,
|
229
|
+
origin: str,
|
230
|
+
text_answer: bytes,
|
231
|
+
generative_answer_time: float,
|
232
|
+
generative_answer_first_chunk_time: Optional[float],
|
233
|
+
status_code: AnswerStatusCode,
|
234
|
+
):
|
235
|
+
maybe_audit_chat(
|
236
|
+
kbid=kbid,
|
237
|
+
user_id=user_id,
|
238
|
+
client_type=client_type,
|
239
|
+
origin=origin,
|
240
|
+
user_query=user_query,
|
241
|
+
rephrased_query=None,
|
242
|
+
retrieval_rephrase_query=None,
|
243
|
+
chat_history=[],
|
244
|
+
learning_id=headers.get(NUCLIA_LEARNING_ID_HEADER),
|
245
|
+
query_context={},
|
246
|
+
query_context_order={},
|
247
|
+
model=headers.get(NUCLIA_LEARNING_MODEL_HEADER),
|
248
|
+
text_answer=text_answer,
|
249
|
+
generative_answer_time=generative_answer_time,
|
250
|
+
generative_answer_first_chunk_time=generative_answer_first_chunk_time or 0,
|
251
|
+
rephrase_time=None,
|
252
|
+
status_code=status_code,
|
253
|
+
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nucliadb
|
3
|
-
Version: 6.5.0.
|
3
|
+
Version: 6.5.0.post4415
|
4
4
|
Summary: NucliaDB
|
5
5
|
Author-email: Nuclia <nucliadb@nuclia.com>
|
6
6
|
License-Expression: AGPL-3.0-or-later
|
@@ -19,11 +19,11 @@ Classifier: Programming Language :: Python :: 3.12
|
|
19
19
|
Classifier: Programming Language :: Python :: 3 :: Only
|
20
20
|
Requires-Python: <4,>=3.9
|
21
21
|
Description-Content-Type: text/markdown
|
22
|
-
Requires-Dist: nucliadb-telemetry[all]>=6.5.0.
|
23
|
-
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.5.0.
|
24
|
-
Requires-Dist: nucliadb-protos>=6.5.0.
|
25
|
-
Requires-Dist: nucliadb-models>=6.5.0.
|
26
|
-
Requires-Dist: nidx-protos>=6.5.0.
|
22
|
+
Requires-Dist: nucliadb-telemetry[all]>=6.5.0.post4415
|
23
|
+
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.5.0.post4415
|
24
|
+
Requires-Dist: nucliadb-protos>=6.5.0.post4415
|
25
|
+
Requires-Dist: nucliadb-models>=6.5.0.post4415
|
26
|
+
Requires-Dist: nidx-protos>=6.5.0.post4415
|
27
27
|
Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
|
28
28
|
Requires-Dist: nuclia-models>=0.24.2
|
29
29
|
Requires-Dist: uvicorn[standard]
|
@@ -221,7 +221,7 @@ nucliadb/search/api/v1/feedback.py,sha256=kNLc4dHz2SXHzV0PwC1WiRAwY88fDptPcP-kO0
|
|
221
221
|
nucliadb/search/api/v1/find.py,sha256=iMjyq4y0JOMC_x1B8kUfVdkCoc9G9Ark58kPLLY4HDw,10824
|
222
222
|
nucliadb/search/api/v1/graph.py,sha256=gthqxCOn9biE6D6s93jRGLglk0ono8U7OyS390kWiI8,4178
|
223
223
|
nucliadb/search/api/v1/knowledgebox.py,sha256=e9xeLPUqnQTx33i4A8xuV93ENvtJGrpjPlLRbGJtAI8,8415
|
224
|
-
nucliadb/search/api/v1/predict_proxy.py,sha256=
|
224
|
+
nucliadb/search/api/v1/predict_proxy.py,sha256=TnXKAqf_Go-9QVi6L5z4cXjnuNRe7XLJjF5QH_uwA1I,3504
|
225
225
|
nucliadb/search/api/v1/router.py,sha256=mtT07rBZcVfpa49doaw9b1tj3sdi3qLH0gn9Io6NYM0,988
|
226
226
|
nucliadb/search/api/v1/search.py,sha256=eqlrvRE7IlMpunNwD1RJwt6RgMV01sIDJLgxxE7CFcE,12297
|
227
227
|
nucliadb/search/api/v1/suggest.py,sha256=gaJE60r8-z6TVO05mQRKBITwXn2_ofM3B4-OtpOgZEk,6343
|
@@ -250,7 +250,7 @@ nucliadb/search/search/merge.py,sha256=XiRBsxhYPshPV7lZXD-9E259KZOPIf4I2tKosY0lP
|
|
250
250
|
nucliadb/search/search/metrics.py,sha256=3I6IN0qDSmqIvUaWJmT3rt-Jyjs6LcvnKI8ZqCiuJPY,3501
|
251
251
|
nucliadb/search/search/paragraphs.py,sha256=pNAEiYqJGGUVcEf7xf-PFMVqz0PX4Qb-WNG-_zPGN2o,7799
|
252
252
|
nucliadb/search/search/pgcatalog.py,sha256=QtgArjoM-dW_B1oO0aXqp5au7GlLG8jAct9jevUHatw,10997
|
253
|
-
nucliadb/search/search/predict_proxy.py,sha256=
|
253
|
+
nucliadb/search/search/predict_proxy.py,sha256=2451FR2DJLedA4eV8KbMbY4la5WZCgjBUHO_iQVaA8I,8634
|
254
254
|
nucliadb/search/search/query.py,sha256=0qIQdt548L3jtKOyKo06aGJ73SLBxAW3N38_Hc1M3Uw,11528
|
255
255
|
nucliadb/search/search/rank_fusion.py,sha256=xZtXhbmKb_56gs73u6KkFm2efvTATOSMmpOV2wrAIqE,9613
|
256
256
|
nucliadb/search/search/rerankers.py,sha256=E2J1QdKAojqbhHM3KAyaOXKf6tJyETUxKs4tf_BEyqk,7472
|
@@ -370,8 +370,8 @@ nucliadb/writer/tus/local.py,sha256=7jYa_w9b-N90jWgN2sQKkNcomqn6JMVBOVeDOVYJHto,
|
|
370
370
|
nucliadb/writer/tus/s3.py,sha256=vF0NkFTXiXhXq3bCVXXVV-ED38ECVoUeeYViP8uMqcU,8357
|
371
371
|
nucliadb/writer/tus/storage.py,sha256=ToqwjoYnjI4oIcwzkhha_MPxi-k4Jk3Lt55zRwaC1SM,2903
|
372
372
|
nucliadb/writer/tus/utils.py,sha256=MSdVbRsRSZVdkaum69_0wku7X3p5wlZf4nr6E0GMKbw,2556
|
373
|
-
nucliadb-6.5.0.
|
374
|
-
nucliadb-6.5.0.
|
375
|
-
nucliadb-6.5.0.
|
376
|
-
nucliadb-6.5.0.
|
377
|
-
nucliadb-6.5.0.
|
373
|
+
nucliadb-6.5.0.post4415.dist-info/METADATA,sha256=lSbgbK7GgTJXjm5IIzYAeFt7-W8LdME3CKarZjmQwbg,4152
|
374
|
+
nucliadb-6.5.0.post4415.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
375
|
+
nucliadb-6.5.0.post4415.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
|
376
|
+
nucliadb-6.5.0.post4415.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
|
377
|
+
nucliadb-6.5.0.post4415.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|