nuclia 4.9.2__py3-none-any.whl → 4.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nuclia/lib/nua.py +15 -0
- nuclia/lib/nua_responses.py +12 -0
- nuclia/sdk/predict.py +24 -0
- nuclia/tests/test_nua/test_predict.py +16 -1
- {nuclia-4.9.2.dist-info → nuclia-4.9.3.dist-info}/METADATA +1 -1
- {nuclia-4.9.2.dist-info → nuclia-4.9.3.dist-info}/RECORD +10 -10
- {nuclia-4.9.2.dist-info → nuclia-4.9.3.dist-info}/WHEEL +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.3.dist-info}/entry_points.txt +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.3.dist-info}/licenses/LICENSE +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.3.dist-info}/top_level.txt +0 -0
nuclia/lib/nua.py
CHANGED
@@ -42,6 +42,8 @@ from nuclia.lib.nua_responses import (
|
|
42
42
|
PushResponseV2,
|
43
43
|
QueryInfo,
|
44
44
|
RephraseModel,
|
45
|
+
RerankModel,
|
46
|
+
RerankResponse,
|
45
47
|
RestrictedIDString,
|
46
48
|
Sentence,
|
47
49
|
Source,
|
@@ -77,6 +79,7 @@ PUSH_PROCESS = "/api/v2/processing/push"
|
|
77
79
|
SCHEMA = "/api/v1/learning/configuration/schema"
|
78
80
|
SCHEMA_KBID = "/api/v1/schema"
|
79
81
|
CONFIG = "/api/v1/config"
|
82
|
+
RERANK = "/api/v1/predict/rerank"
|
80
83
|
|
81
84
|
ConvertType = TypeVar("ConvertType", bound=BaseModel)
|
82
85
|
|
@@ -410,6 +413,12 @@ class NuaClient:
|
|
410
413
|
activity_endpoint = f"{self.url}{STATUS_PROCESS}/{process_id}"
|
411
414
|
return self._request("GET", activity_endpoint, ProcessRequestStatus)
|
412
415
|
|
416
|
+
def rerank(self, model: RerankModel) -> RerankResponse:
|
417
|
+
endpoint = f"{self.url}{RERANK}"
|
418
|
+
return self._request(
|
419
|
+
"POST", endpoint, payload=model.model_dump(), output=RerankResponse
|
420
|
+
)
|
421
|
+
|
413
422
|
|
414
423
|
class AsyncNuaClient:
|
415
424
|
def __init__(
|
@@ -792,3 +801,9 @@ class AsyncNuaClient:
|
|
792
801
|
return await self._request(
|
793
802
|
"GET", activity_endpoint, output=ProcessRequestStatus
|
794
803
|
)
|
804
|
+
|
805
|
+
async def rerank(self, model: RerankModel) -> RerankResponse:
|
806
|
+
endpoint = f"{self.url}{RERANK}"
|
807
|
+
return await self._request(
|
808
|
+
"POST", endpoint, payload=model.model_dump(), output=RerankResponse
|
809
|
+
)
|
nuclia/lib/nua_responses.py
CHANGED
@@ -557,3 +557,15 @@ class QueryInfo(BaseModel):
|
|
557
557
|
max_context: int
|
558
558
|
entities: Optional[TokenSearch]
|
559
559
|
sentence: Optional[SentenceSearch]
|
560
|
+
|
561
|
+
|
562
|
+
class RerankModel(BaseModel):
|
563
|
+
question: str
|
564
|
+
user_id: str
|
565
|
+
context: dict[str, str] = {}
|
566
|
+
|
567
|
+
|
568
|
+
class RerankResponse(BaseModel):
|
569
|
+
context_scores: dict[str, float] = Field(
|
570
|
+
description="Scores for each context given by the reranker"
|
571
|
+
)
|
nuclia/sdk/predict.py
CHANGED
@@ -13,6 +13,8 @@ from nuclia.lib.nua_responses import (
|
|
13
13
|
ConfigSchema,
|
14
14
|
LearningConfigurationCreation,
|
15
15
|
QueryInfo,
|
16
|
+
RerankModel,
|
17
|
+
RerankResponse,
|
16
18
|
Sentence,
|
17
19
|
StoredLearningConfiguration,
|
18
20
|
SummarizedModel,
|
@@ -162,6 +164,17 @@ class NucliaPredict:
|
|
162
164
|
nc: NuaClient = kwargs["nc"]
|
163
165
|
return nc.remi(request)
|
164
166
|
|
167
|
+
@nua
|
168
|
+
def rerank(self, request: RerankModel, **kwargs) -> RerankResponse:
|
169
|
+
"""
|
170
|
+
Perform a reranking of the results based on the question and context provided.
|
171
|
+
|
172
|
+
:param request: RerankModel
|
173
|
+
:return: RerankResponse
|
174
|
+
"""
|
175
|
+
nc: NuaClient = kwargs["nc"]
|
176
|
+
return nc.rerank(request)
|
177
|
+
|
165
178
|
|
166
179
|
class AsyncNucliaPredict:
|
167
180
|
@property
|
@@ -299,3 +312,14 @@ class AsyncNucliaPredict:
|
|
299
312
|
|
300
313
|
nc: AsyncNuaClient = kwargs["nc"]
|
301
314
|
return await nc.remi(request)
|
315
|
+
|
316
|
+
@nua
|
317
|
+
async def rerank(self, request: RerankModel, **kwargs) -> RerankResponse:
|
318
|
+
"""
|
319
|
+
Perform a reranking of the results based on the question and context provided.
|
320
|
+
|
321
|
+
:param request: RerankModel
|
322
|
+
:return: RerankResponse
|
323
|
+
"""
|
324
|
+
nc: AsyncNuaClient = kwargs["nc"]
|
325
|
+
return await nc.rerank(request)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from nuclia_models.predict.generative_responses import TextGenerativeResponse
|
2
2
|
|
3
|
-
from nuclia.lib.nua_responses import ChatModel, UserPrompt
|
3
|
+
from nuclia.lib.nua_responses import ChatModel, RerankModel, UserPrompt
|
4
4
|
from nuclia.sdk.predict import AsyncNucliaPredict, NucliaPredict
|
5
5
|
import pytest
|
6
6
|
from nuclia_models.predict.remi import RemiRequest
|
@@ -170,3 +170,18 @@ async def test_nua_async_remi(testing_config):
|
|
170
170
|
|
171
171
|
assert results.context_relevance[1] < 2
|
172
172
|
assert results.groundedness[1] < 2
|
173
|
+
|
174
|
+
|
175
|
+
def test_nua_rerank(testing_config):
|
176
|
+
np = NucliaPredict()
|
177
|
+
results = np.rerank(
|
178
|
+
RerankModel(
|
179
|
+
user_id="Nuclia PY CLI",
|
180
|
+
question="What is the capital of France?",
|
181
|
+
context={
|
182
|
+
"1": "Paris is the capital of France.",
|
183
|
+
"2": "Berlin is the capital of Germany.",
|
184
|
+
},
|
185
|
+
)
|
186
|
+
)
|
187
|
+
assert results.context_scores["1"] > results.context_scores["2"]
|
@@ -11,9 +11,9 @@ nuclia/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
nuclia/lib/conversations.py,sha256=M6qhL9NPEKroYF767S-Q2XWokRrjX02kpYTzRvZKwUE,149
|
12
12
|
nuclia/lib/kb.py,sha256=vSLfmV6HqPvWJwPVw4iIzkDnm0M96_ImTH3QMZdkd_I,30985
|
13
13
|
nuclia/lib/models.py,sha256=ekEQrVIFU3aFvt60yQh-zpWkGNORBMSc7c5Hd_VzPzI,1564
|
14
|
-
nuclia/lib/nua.py,sha256=
|
14
|
+
nuclia/lib/nua.py,sha256=jVUV8I50iJpUQD6yn4V7wNDDSDP8oEWrRxnmrs33hnI,27591
|
15
15
|
nuclia/lib/nua_chat.py,sha256=ApL1Y1FWvAVUt-Y9a_8TUSJIhg8-UmBSy8TlDPn6tD8,3874
|
16
|
-
nuclia/lib/nua_responses.py,sha256=
|
16
|
+
nuclia/lib/nua_responses.py,sha256=FyCedSKHVGn4w380obV9B3D1JaqVyNTqda4-OCs4A9A,13637
|
17
17
|
nuclia/lib/utils.py,sha256=9l6DxBk-11WqUhXRn99cqeuVTUOJXj-1S6zckal7wOk,6312
|
18
18
|
nuclia/sdk/__init__.py,sha256=-nAw8i53XBdmbfTa1FJZ0FNRMNakimDVpD6W4OdES-c,1374
|
19
19
|
nuclia/sdk/accounts.py,sha256=7XQ3K9_jlSuk2Cez868FtazZ05xSGab6h3Mt1qMMwIE,647
|
@@ -28,7 +28,7 @@ nuclia/sdk/logger.py,sha256=UHB81eS6IGmLrsofKxLh8cmF2AsaTj_HXP0tGqMr_HM,57
|
|
28
28
|
nuclia/sdk/logs.py,sha256=3jfORpo8fzZiXFFSbGY0o3Bre1ZgJaKQCXgxP1keNHw,9614
|
29
29
|
nuclia/sdk/nua.py,sha256=6t0m0Sx-UhqNU2Hx9v6vTwy0m3a30K4T0KmP9G43MzY,293
|
30
30
|
nuclia/sdk/nucliadb.py,sha256=bOESIppPgY7IrNqrYY7T3ESoxwttbOSTm5zj1xUS1jI,1288
|
31
|
-
nuclia/sdk/predict.py,sha256=
|
31
|
+
nuclia/sdk/predict.py,sha256=Gom-BlISXawThNq-f7fb1te5tY4catlIdmO-pMJyqkk,9815
|
32
32
|
nuclia/sdk/process.py,sha256=WuNnqaWprp-EABWDC_z7O2woesGIlYWnDUKozh7Ibr4,2241
|
33
33
|
nuclia/sdk/remi.py,sha256=BEb3O9R2jOFlOda4vjFucKKGO1c2eTkqYZdFlIy3Zmo,4357
|
34
34
|
nuclia/sdk/resource.py,sha256=0lSvD4e1FpN5iM9W295dOKLJ8hsXfIe8HKdo0HsVg20,13976
|
@@ -57,15 +57,15 @@ nuclia/tests/test_manage/test_auth.py,sha256=I5ho9rKhrzCiP57cDVcs2zrXyy1uWlKxzfE
|
|
57
57
|
nuclia/tests/test_manage/test_kb.py,sha256=gqBMxmIuPUKC6kP2V_IapS9u72QR99V8CDQlJOa8sHU,1959
|
58
58
|
nuclia/tests/test_nua/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
nuclia/tests/test_nua/test_agent.py,sha256=iPA8lVw7CyONS-0fVG7rDzv8T-LnUlNPMNynTnP-2ic,334
|
60
|
-
nuclia/tests/test_nua/test_predict.py,sha256=
|
60
|
+
nuclia/tests/test_nua/test_predict.py,sha256=8by69GgXuOZKAgksjSjgmbNr98jCXanF1HOo29oNuMg,5798
|
61
61
|
nuclia/tests/test_nucliadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
62
62
|
nuclia/tests/test_nucliadb/test_crud.py,sha256=GuY76HRvt2DFaNgioKm5n0Aco1HnG7zzV_zKom5N8xc,1173
|
63
63
|
nuclia/tests/unit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
64
64
|
nuclia/tests/unit/test_export_import.py,sha256=xo_wVbjUnNlVV65ZGH7LtZ38qy39EkJp2hjOuTHC1nU,980
|
65
65
|
nuclia/tests/unit/test_nua_responses.py,sha256=t_hIdVztTi27RWvpfTJUYcCL0lpKdZFegZIwLdaPNh8,319
|
66
|
-
nuclia-4.9.
|
67
|
-
nuclia-4.9.
|
68
|
-
nuclia-4.9.
|
69
|
-
nuclia-4.9.
|
70
|
-
nuclia-4.9.
|
71
|
-
nuclia-4.9.
|
66
|
+
nuclia-4.9.3.dist-info/licenses/LICENSE,sha256=Ops2LTti_HJtpmWcanuUTdTY3vKDR1myJ0gmGBKC0FA,1063
|
67
|
+
nuclia-4.9.3.dist-info/METADATA,sha256=rxQGLEDjPHayY_NPuXkmXpfsp4_BfoBC2ZjLlOjhyoM,2337
|
68
|
+
nuclia-4.9.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
69
|
+
nuclia-4.9.3.dist-info/entry_points.txt,sha256=iZHOyXPNS54r3eQmdi5So20xO1gudI9K2oP4sQsCJRw,46
|
70
|
+
nuclia-4.9.3.dist-info/top_level.txt,sha256=cqn_EitXOoXOSUvZnd4q6QGrhm04pg8tLAZtem-Zfdo,7
|
71
|
+
nuclia-4.9.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|