nucliadb 6.5.0.post4413__py3-none-any.whl → 6.5.0.post4415__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,7 +20,7 @@
20
20
  import json
21
21
  from typing import Union
22
22
 
23
- from fastapi import Request
23
+ from fastapi import Header, Request
24
24
  from fastapi.responses import Response, StreamingResponse
25
25
  from fastapi_versioning import version
26
26
 
@@ -30,10 +30,17 @@ from nucliadb.search import predict
30
30
  from nucliadb.search.api.v1.router import KB_PREFIX, api
31
31
  from nucliadb.search.search.predict_proxy import PredictProxiedEndpoints, predict_proxy
32
32
  from nucliadb_models.resource import NucliaDBRoles
33
+ from nucliadb_models.search import NucliaDBClientType
33
34
  from nucliadb_utils.authentication import requires
34
35
  from nucliadb_utils.exceptions import LimitsExceededError
35
36
 
36
- DESCRIPTION = "Convenience endpoint that proxies requests to the Predict API. It adds the Knowledge Box configuration settings as headers to the predict API request. Refer to the Predict API documentation for more details about the request and response models: https://docs.nuclia.dev/docs/nua-api#tag/Predict" # noqa: E501
37
+ DESCRIPTION = (
38
+ "Convenience endpoint that proxies requests to the Predict API."
39
+ " It adds the Knowledge Box configuration settings as headers to"
40
+ " the predict API request. Refer to the Predict API documentation"
41
+ " for more details about the request and response models:"
42
+ " https://docs.nuclia.dev/docs/nua-api#tag/Predict"
43
+ )
37
44
 
38
45
 
39
46
  @api.get(
@@ -58,20 +65,28 @@ async def predict_proxy_endpoint(
58
65
  request: Request,
59
66
  kbid: str,
60
67
  endpoint: PredictProxiedEndpoints,
68
+ x_nucliadb_user: str = Header(""),
69
+ x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
70
+ x_forwarded_for: str = Header(""),
61
71
  ) -> Union[Response, StreamingResponse, HTTPClientError]:
62
72
  try:
63
73
  payload = await request.json()
64
74
  except json.JSONDecodeError:
65
75
  payload = None
66
76
  try:
67
- return await predict_proxy(
77
+ response = await predict_proxy(
68
78
  kbid,
69
79
  endpoint,
70
80
  request.method,
71
81
  params=request.query_params,
72
82
  json=payload,
73
83
  headers=dict(request.headers),
84
+ user_id=x_nucliadb_user,
85
+ client_type=x_ndb_client,
86
+ origin=x_forwarded_for,
74
87
  )
88
+
89
+ return response
75
90
  except KnowledgeBoxNotFound:
76
91
  return HTTPClientError(status_code=404, detail="Knowledge box not found")
77
92
  except LimitsExceededError as exc:
@@ -17,15 +17,34 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
+ import json
20
21
  from enum import Enum
21
22
  from typing import Any, Optional, Union
22
23
 
24
+ import aiohttp
23
25
  from fastapi.datastructures import QueryParams
24
26
  from fastapi.responses import Response, StreamingResponse
27
+ from multidict import CIMultiDictProxy
28
+ from nuclia_models.predict.generative_responses import (
29
+ GenerativeChunk,
30
+ JSONGenerativeResponse,
31
+ StatusGenerativeResponse,
32
+ TextGenerativeResponse,
33
+ )
34
+ from pydantic import ValidationError
25
35
 
26
36
  from nucliadb.common import datamanagers
27
- from nucliadb.search.predict import PredictEngine
37
+ from nucliadb.search import logger
38
+ from nucliadb.search.predict import (
39
+ NUCLIA_LEARNING_ID_HEADER,
40
+ NUCLIA_LEARNING_MODEL_HEADER,
41
+ AnswerStatusCode,
42
+ PredictEngine,
43
+ )
44
+ from nucliadb.search.search.chat.query import maybe_audit_chat
45
+ from nucliadb.search.search.metrics import AskMetrics
28
46
  from nucliadb.search.utilities import get_predict
47
+ from nucliadb_models.search import NucliaDBClientType
29
48
 
30
49
 
31
50
  class PredictProxiedEndpoints(str, Enum):
@@ -46,12 +65,17 @@ ALLOWED_HEADERS = [
46
65
  "Accept", # To allow 'application/x-ndjson' on the /chat endpoint
47
66
  ]
48
67
 
68
+ PREDICT_ANSWER_METRIC = "predict_answer_proxy_metric"
69
+
49
70
 
50
71
  async def predict_proxy(
51
72
  kbid: str,
52
73
  endpoint: PredictProxiedEndpoints,
53
74
  method: str,
54
75
  params: QueryParams,
76
+ user_id: str,
77
+ client_type: NucliaDBClientType,
78
+ origin: str,
55
79
  json: Optional[Any] = None,
56
80
  headers: dict[str, str] = {},
57
81
  ) -> Union[Response, StreamingResponse]:
@@ -71,22 +95,61 @@ async def predict_proxy(
71
95
  headers={**user_headers, **predict_headers},
72
96
  )
73
97
 
74
- # Proxy the response back to the client
75
98
  status_code = predict_response.status
76
99
  media_type = predict_response.headers.get("Content-Type")
77
100
  response: Union[Response, StreamingResponse]
101
+ user_query = json.get("question") if json is not None else ""
78
102
  if predict_response.headers.get("Transfer-Encoding") == "chunked":
103
+ if endpoint == PredictProxiedEndpoints.CHAT:
104
+ streaming_generator = chat_streaming_generator(
105
+ predict_response=predict_response,
106
+ kbid=kbid,
107
+ user_id=user_id,
108
+ client_type=client_type,
109
+ origin=origin,
110
+ user_query=user_query,
111
+ is_json="json" in (media_type or ""),
112
+ )
113
+ else:
114
+ streaming_generator = predict_response.content.iter_any()
115
+
79
116
  response = StreamingResponse(
80
- content=predict_response.content.iter_any(),
117
+ content=streaming_generator,
81
118
  status_code=status_code,
82
119
  media_type=media_type,
83
120
  )
84
121
  else:
122
+ metrics = AskMetrics()
123
+ with metrics.time(PREDICT_ANSWER_METRIC):
124
+ content = await predict_response.read()
125
+
126
+ if endpoint == PredictProxiedEndpoints.CHAT:
127
+ try:
128
+ llm_status_code = int(content[-1:].decode()) # Decode just the last char
129
+ if llm_status_code != 0:
130
+ llm_status_code = -llm_status_code
131
+ except ValueError:
132
+ llm_status_code = -1
133
+
134
+ audit_predict_proxy_endpoint(
135
+ predict_response.headers,
136
+ kbid=kbid,
137
+ user_id=user_id,
138
+ user_query=user_query,
139
+ client_type=client_type,
140
+ origin=origin,
141
+ text_answer=content,
142
+ generative_answer_time=metrics[PREDICT_ANSWER_METRIC],
143
+ generative_answer_first_chunk_time=None,
144
+ status_code=AnswerStatusCode(str(llm_status_code)),
145
+ )
146
+
85
147
  response = Response(
86
- content=await predict_response.read(),
148
+ content=content,
87
149
  status_code=status_code,
88
150
  media_type=media_type,
89
151
  )
152
+
90
153
  nuclia_learning_id = predict_response.headers.get("NUCLIA-LEARNING-ID")
91
154
  if nuclia_learning_id:
92
155
  response.headers["NUCLIA-LEARNING-ID"] = nuclia_learning_id
@@ -97,3 +160,94 @@ async def predict_proxy(
97
160
  async def exists_kb(kbid: str) -> bool:
98
161
  async with datamanagers.with_ro_transaction() as txn:
99
162
  return await datamanagers.kb.exists_kb(txn, kbid=kbid)
163
+
164
+
165
async def chat_streaming_generator(
    predict_response: aiohttp.ClientResponse,
    kbid: str,
    user_id: str,
    client_type: NucliaDBClientType,
    origin: str,
    user_query: str,
    is_json: bool,
):
    """Relay a chunked Predict chat answer to the client and audit it at the end.

    Every chunk read from ``predict_response`` is yielded unchanged, so the
    client receives exactly what Predict sent. Meanwhile the answer is
    accumulated. Once the stream is exhausted, a single
    ``audit_predict_proxy_endpoint`` call records the answer, the timings and
    the answer status code.

    Args:
        predict_response: chunked response returned by the Predict API.
        kbid: knowledge box id, used for auditing and as log context.
        user_id: forwarded to the audit.
        client_type: forwarded to the audit.
        origin: forwarded to the audit.
        user_query: the user's question, forwarded to the audit.
        is_json: True when the answer is a stream of newline-delimited
            ``GenerativeChunk`` JSON objects. False when it is plain text whose
            trailing character carries the answer status code.
    """
    metrics = AskMetrics()
    first = True
    # Default used in JSON mode when no status item is ever received.
    status_code = AnswerStatusCode.ERROR.value
    text_parts: list[str] = []
    json_object = None
    # Network chunks do not follow NDJSON line boundaries. A line can be split
    # across chunks, or several lines can arrive in one chunk, so the trailing
    # incomplete line is kept until the rest of it arrives.
    pending = b""
    # Text mode: decode once at the end so a multi-byte UTF-8 character split
    # across two chunks is not broken.
    raw_text = bytearray()

    def consume_json_line(line: bytes) -> None:
        # Parse one NDJSON line and fold it into the accumulators above.
        nonlocal json_object, status_code
        line = line.strip()
        if not line:
            return
        try:
            # The payload is wrapped: the actual generative item is in `.chunk`.
            item = GenerativeChunk.model_validate_json(line).chunk
        except ValidationError:
            logger.warning(
                f"Unexpected item in predict answer stream: {line.decode(errors='replace')}",
                extra={"kbid": kbid},
            )
            return
        if isinstance(item, TextGenerativeResponse):
            text_parts.append(item.text)
        elif isinstance(item, JSONGenerativeResponse):
            json_object = item.object
        elif isinstance(item, StatusGenerativeResponse):
            status_code = str(item.code)

    with metrics.time(PREDICT_ANSWER_METRIC):
        async for chunk in predict_response.content.iter_any():
            if first:
                metrics.record_first_chunk_yielded()
                first = False

            yield chunk

            if is_json:
                pending += chunk
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    consume_json_line(line)
            else:
                raw_text += chunk

        if is_json:
            # The final NDJSON line may not be newline-terminated.
            consume_json_line(pending)

    if is_json:
        if json_object is None:
            answer = "".join(text_parts).encode()
        else:
            answer = json.dumps(json_object).encode()
        answer_status = _to_answer_status_code(status_code)
    else:
        answer = bytes(raw_text)
        answer_status = _text_answer_status_code(answer)

    audit_predict_proxy_endpoint(
        headers=predict_response.headers,
        kbid=kbid,
        user_id=user_id,
        user_query=user_query,
        client_type=client_type,
        origin=origin,
        text_answer=answer,
        generative_answer_time=metrics[PREDICT_ANSWER_METRIC],
        generative_answer_first_chunk_time=metrics.get_first_chunk_time(),
        status_code=answer_status,
    )


def _to_answer_status_code(code: str) -> AnswerStatusCode:
    """Convert a raw status value to ``AnswerStatusCode``; unknown values map to ERROR."""
    try:
        return AnswerStatusCode(code)
    except ValueError:
        return AnswerStatusCode.ERROR


def _text_answer_status_code(answer: bytes) -> AnswerStatusCode:
    """Read the status code carried by the last character of a plain-text answer.

    Mirrors the non-streaming proxy branch: the last character is parsed as an
    integer and negated, so ``...-1`` gives ``"-1"`` and ``...0`` gives ``"0"``.
    An empty answer, a non-digit last character, or an unknown code maps to
    ERROR.
    """
    try:
        # UnicodeDecodeError is a ValueError subclass.
        code = int(answer[-1:].decode())
    except ValueError:
        return AnswerStatusCode.ERROR
    return _to_answer_status_code(str(-code))
221
+
222
+
223
def audit_predict_proxy_endpoint(
    headers: CIMultiDictProxy,
    kbid: str,
    user_id: str,
    user_query: str,
    client_type: NucliaDBClientType,
    origin: str,
    text_answer: bytes,
    generative_answer_time: float,
    generative_answer_first_chunk_time: Optional[float],
    status_code: AnswerStatusCode,
):
    """Pass an answer proxied from the Predict API to ``maybe_audit_chat``.

    The learning id and model name are taken from the Predict response
    ``headers``. The rephrase, chat-history and query-context fields are sent
    empty (or ``None``), because this proxy does not produce them. A missing or
    falsy first-chunk time, as passed for non-streamed answers, is reported
    as ``0``.
    """
    learning_id = headers.get(NUCLIA_LEARNING_ID_HEADER)
    model_name = headers.get(NUCLIA_LEARNING_MODEL_HEADER)
    first_chunk_time = generative_answer_first_chunk_time if generative_answer_first_chunk_time else 0

    maybe_audit_chat(
        kbid=kbid,
        user_id=user_id,
        client_type=client_type,
        origin=origin,
        user_query=user_query,
        rephrased_query=None,
        retrieval_rephrase_query=None,
        chat_history=[],
        learning_id=learning_id,
        query_context={},
        query_context_order={},
        model=model_name,
        text_answer=text_answer,
        generative_answer_time=generative_answer_time,
        generative_answer_first_chunk_time=first_chunk_time,
        rephrase_time=None,
        status_code=status_code,
    )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nucliadb
3
- Version: 6.5.0.post4413
3
+ Version: 6.5.0.post4415
4
4
  Summary: NucliaDB
5
5
  Author-email: Nuclia <nucliadb@nuclia.com>
6
6
  License-Expression: AGPL-3.0-or-later
@@ -19,11 +19,11 @@ Classifier: Programming Language :: Python :: 3.12
19
19
  Classifier: Programming Language :: Python :: 3 :: Only
20
20
  Requires-Python: <4,>=3.9
21
21
  Description-Content-Type: text/markdown
22
- Requires-Dist: nucliadb-telemetry[all]>=6.5.0.post4413
23
- Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.5.0.post4413
24
- Requires-Dist: nucliadb-protos>=6.5.0.post4413
25
- Requires-Dist: nucliadb-models>=6.5.0.post4413
26
- Requires-Dist: nidx-protos>=6.5.0.post4413
22
+ Requires-Dist: nucliadb-telemetry[all]>=6.5.0.post4415
23
+ Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.5.0.post4415
24
+ Requires-Dist: nucliadb-protos>=6.5.0.post4415
25
+ Requires-Dist: nucliadb-models>=6.5.0.post4415
26
+ Requires-Dist: nidx-protos>=6.5.0.post4415
27
27
  Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
28
28
  Requires-Dist: nuclia-models>=0.24.2
29
29
  Requires-Dist: uvicorn[standard]
@@ -221,7 +221,7 @@ nucliadb/search/api/v1/feedback.py,sha256=kNLc4dHz2SXHzV0PwC1WiRAwY88fDptPcP-kO0
221
221
  nucliadb/search/api/v1/find.py,sha256=iMjyq4y0JOMC_x1B8kUfVdkCoc9G9Ark58kPLLY4HDw,10824
222
222
  nucliadb/search/api/v1/graph.py,sha256=gthqxCOn9biE6D6s93jRGLglk0ono8U7OyS390kWiI8,4178
223
223
  nucliadb/search/api/v1/knowledgebox.py,sha256=e9xeLPUqnQTx33i4A8xuV93ENvtJGrpjPlLRbGJtAI8,8415
224
- nucliadb/search/api/v1/predict_proxy.py,sha256=Q03ZTvWp7Sq0x71t5Br4LHxTiYsRd6-GCb4YuKqhynM,3131
224
+ nucliadb/search/api/v1/predict_proxy.py,sha256=TnXKAqf_Go-9QVi6L5z4cXjnuNRe7XLJjF5QH_uwA1I,3504
225
225
  nucliadb/search/api/v1/router.py,sha256=mtT07rBZcVfpa49doaw9b1tj3sdi3qLH0gn9Io6NYM0,988
226
226
  nucliadb/search/api/v1/search.py,sha256=eqlrvRE7IlMpunNwD1RJwt6RgMV01sIDJLgxxE7CFcE,12297
227
227
  nucliadb/search/api/v1/suggest.py,sha256=gaJE60r8-z6TVO05mQRKBITwXn2_ofM3B4-OtpOgZEk,6343
@@ -250,7 +250,7 @@ nucliadb/search/search/merge.py,sha256=XiRBsxhYPshPV7lZXD-9E259KZOPIf4I2tKosY0lP
250
250
  nucliadb/search/search/metrics.py,sha256=3I6IN0qDSmqIvUaWJmT3rt-Jyjs6LcvnKI8ZqCiuJPY,3501
251
251
  nucliadb/search/search/paragraphs.py,sha256=pNAEiYqJGGUVcEf7xf-PFMVqz0PX4Qb-WNG-_zPGN2o,7799
252
252
  nucliadb/search/search/pgcatalog.py,sha256=QtgArjoM-dW_B1oO0aXqp5au7GlLG8jAct9jevUHatw,10997
253
- nucliadb/search/search/predict_proxy.py,sha256=JwgBeEg1j4LnCjPCvTUrnmOd9LceJAt3iAu4m9cmJBo,3390
253
+ nucliadb/search/search/predict_proxy.py,sha256=2451FR2DJLedA4eV8KbMbY4la5WZCgjBUHO_iQVaA8I,8634
254
254
  nucliadb/search/search/query.py,sha256=0qIQdt548L3jtKOyKo06aGJ73SLBxAW3N38_Hc1M3Uw,11528
255
255
  nucliadb/search/search/rank_fusion.py,sha256=xZtXhbmKb_56gs73u6KkFm2efvTATOSMmpOV2wrAIqE,9613
256
256
  nucliadb/search/search/rerankers.py,sha256=E2J1QdKAojqbhHM3KAyaOXKf6tJyETUxKs4tf_BEyqk,7472
@@ -370,8 +370,8 @@ nucliadb/writer/tus/local.py,sha256=7jYa_w9b-N90jWgN2sQKkNcomqn6JMVBOVeDOVYJHto,
370
370
  nucliadb/writer/tus/s3.py,sha256=vF0NkFTXiXhXq3bCVXXVV-ED38ECVoUeeYViP8uMqcU,8357
371
371
  nucliadb/writer/tus/storage.py,sha256=ToqwjoYnjI4oIcwzkhha_MPxi-k4Jk3Lt55zRwaC1SM,2903
372
372
  nucliadb/writer/tus/utils.py,sha256=MSdVbRsRSZVdkaum69_0wku7X3p5wlZf4nr6E0GMKbw,2556
373
- nucliadb-6.5.0.post4413.dist-info/METADATA,sha256=B1pDmizFK3zLtNJQUCZAOb4jLeQgnyw3vcryDfJO5y4,4152
374
- nucliadb-6.5.0.post4413.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
375
- nucliadb-6.5.0.post4413.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
376
- nucliadb-6.5.0.post4413.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
377
- nucliadb-6.5.0.post4413.dist-info/RECORD,,
373
+ nucliadb-6.5.0.post4415.dist-info/METADATA,sha256=lSbgbK7GgTJXjm5IIzYAeFt7-W8LdME3CKarZjmQwbg,4152
374
+ nucliadb-6.5.0.post4415.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
375
+ nucliadb-6.5.0.post4415.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
376
+ nucliadb-6.5.0.post4415.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
377
+ nucliadb-6.5.0.post4415.dist-info/RECORD,,