nuclia 4.9.2__py3-none-any.whl → 4.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nuclia/lib/nua.py CHANGED
@@ -42,6 +42,8 @@ from nuclia.lib.nua_responses import (
42
42
  PushResponseV2,
43
43
  QueryInfo,
44
44
  RephraseModel,
45
+ RerankModel,
46
+ RerankResponse,
45
47
  RestrictedIDString,
46
48
  Sentence,
47
49
  Source,
@@ -77,6 +79,7 @@ PUSH_PROCESS = "/api/v2/processing/push"
77
79
  SCHEMA = "/api/v1/learning/configuration/schema"
78
80
  SCHEMA_KBID = "/api/v1/schema"
79
81
  CONFIG = "/api/v1/config"
82
+ RERANK = "/api/v1/predict/rerank"
80
83
 
81
84
  ConvertType = TypeVar("ConvertType", bound=BaseModel)
82
85
 
@@ -410,6 +413,12 @@ class NuaClient:
410
413
  activity_endpoint = f"{self.url}{STATUS_PROCESS}/{process_id}"
411
414
  return self._request("GET", activity_endpoint, ProcessRequestStatus)
412
415
 
416
+ def rerank(self, model: RerankModel) -> RerankResponse:
417
+ endpoint = f"{self.url}{RERANK}"
418
+ return self._request(
419
+ "POST", endpoint, payload=model.model_dump(), output=RerankResponse
420
+ )
421
+
413
422
 
414
423
  class AsyncNuaClient:
415
424
  def __init__(
@@ -792,3 +801,9 @@ class AsyncNuaClient:
792
801
  return await self._request(
793
802
  "GET", activity_endpoint, output=ProcessRequestStatus
794
803
  )
804
+
805
+ async def rerank(self, model: RerankModel) -> RerankResponse:
806
+ endpoint = f"{self.url}{RERANK}"
807
+ return await self._request(
808
+ "POST", endpoint, payload=model.model_dump(), output=RerankResponse
809
+ )
@@ -557,3 +557,15 @@ class QueryInfo(BaseModel):
557
557
  max_context: int
558
558
  entities: Optional[TokenSearch]
559
559
  sentence: Optional[SentenceSearch]
560
+
561
+
562
+ class RerankModel(BaseModel):
563
+ question: str
564
+ user_id: str
565
+ context: dict[str, str] = {}
566
+
567
+
568
+ class RerankResponse(BaseModel):
569
+ context_scores: dict[str, float] = Field(
570
+ description="Scores for each context given by the reranker"
571
+ )
nuclia/sdk/predict.py CHANGED
@@ -13,6 +13,8 @@ from nuclia.lib.nua_responses import (
13
13
  ConfigSchema,
14
14
  LearningConfigurationCreation,
15
15
  QueryInfo,
16
+ RerankModel,
17
+ RerankResponse,
16
18
  Sentence,
17
19
  StoredLearningConfiguration,
18
20
  SummarizedModel,
@@ -162,6 +164,17 @@ class NucliaPredict:
162
164
  nc: NuaClient = kwargs["nc"]
163
165
  return nc.remi(request)
164
166
 
167
+ @nua
168
+ def rerank(self, request: RerankModel, **kwargs) -> RerankResponse:
169
+ """
170
+ Perform a reranking of the results based on the question and context provided.
171
+
172
+ :param request: RerankModel
173
+ :return: RerankResponse
174
+ """
175
+ nc: NuaClient = kwargs["nc"]
176
+ return nc.rerank(request)
177
+
165
178
 
166
179
  class AsyncNucliaPredict:
167
180
  @property
@@ -299,3 +312,14 @@ class AsyncNucliaPredict:
299
312
 
300
313
  nc: AsyncNuaClient = kwargs["nc"]
301
314
  return await nc.remi(request)
315
+
316
+ @nua
317
+ async def rerank(self, request: RerankModel, **kwargs) -> RerankResponse:
318
+ """
319
+ Perform a reranking of the results based on the question and context provided.
320
+
321
+ :param request: RerankModel
322
+ :return: RerankResponse
323
+ """
324
+ nc: AsyncNuaClient = kwargs["nc"]
325
+ return await nc.rerank(request)
@@ -1,6 +1,6 @@
1
1
  from nuclia_models.predict.generative_responses import TextGenerativeResponse
2
2
 
3
- from nuclia.lib.nua_responses import ChatModel, UserPrompt
3
+ from nuclia.lib.nua_responses import ChatModel, RerankModel, UserPrompt
4
4
  from nuclia.sdk.predict import AsyncNucliaPredict, NucliaPredict
5
5
  import pytest
6
6
  from nuclia_models.predict.remi import RemiRequest
@@ -170,3 +170,18 @@ async def test_nua_async_remi(testing_config):
170
170
 
171
171
  assert results.context_relevance[1] < 2
172
172
  assert results.groundedness[1] < 2
173
+
174
+
175
+ def test_nua_rerank(testing_config):
176
+ np = NucliaPredict()
177
+ results = np.rerank(
178
+ RerankModel(
179
+ user_id="Nuclia PY CLI",
180
+ question="What is the capital of France?",
181
+ context={
182
+ "1": "Paris is the capital of France.",
183
+ "2": "Berlin is the capital of Germany.",
184
+ },
185
+ )
186
+ )
187
+ assert results.context_scores["1"] > results.context_scores["2"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nuclia
3
- Version: 4.9.2
3
+ Version: 4.9.3
4
4
  Summary: Nuclia Python SDK
5
5
  Author-email: Nuclia <info@nuclia.com>
6
6
  License-Expression: MIT
@@ -11,9 +11,9 @@ nuclia/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  nuclia/lib/conversations.py,sha256=M6qhL9NPEKroYF767S-Q2XWokRrjX02kpYTzRvZKwUE,149
12
12
  nuclia/lib/kb.py,sha256=vSLfmV6HqPvWJwPVw4iIzkDnm0M96_ImTH3QMZdkd_I,30985
13
13
  nuclia/lib/models.py,sha256=ekEQrVIFU3aFvt60yQh-zpWkGNORBMSc7c5Hd_VzPzI,1564
14
- nuclia/lib/nua.py,sha256=sUVFdCjvLigTqUUhILywdHpiC0qKCtKPABn5kUXfuxQ,27064
14
+ nuclia/lib/nua.py,sha256=jVUV8I50iJpUQD6yn4V7wNDDSDP8oEWrRxnmrs33hnI,27591
15
15
  nuclia/lib/nua_chat.py,sha256=ApL1Y1FWvAVUt-Y9a_8TUSJIhg8-UmBSy8TlDPn6tD8,3874
16
- nuclia/lib/nua_responses.py,sha256=RJB7bbgZEpYr29aOmaT2nb1X0g6wds4MOngq7N-ZcTg,13382
16
+ nuclia/lib/nua_responses.py,sha256=FyCedSKHVGn4w380obV9B3D1JaqVyNTqda4-OCs4A9A,13637
17
17
  nuclia/lib/utils.py,sha256=9l6DxBk-11WqUhXRn99cqeuVTUOJXj-1S6zckal7wOk,6312
18
18
  nuclia/sdk/__init__.py,sha256=-nAw8i53XBdmbfTa1FJZ0FNRMNakimDVpD6W4OdES-c,1374
19
19
  nuclia/sdk/accounts.py,sha256=7XQ3K9_jlSuk2Cez868FtazZ05xSGab6h3Mt1qMMwIE,647
@@ -28,7 +28,7 @@ nuclia/sdk/logger.py,sha256=UHB81eS6IGmLrsofKxLh8cmF2AsaTj_HXP0tGqMr_HM,57
28
28
  nuclia/sdk/logs.py,sha256=3jfORpo8fzZiXFFSbGY0o3Bre1ZgJaKQCXgxP1keNHw,9614
29
29
  nuclia/sdk/nua.py,sha256=6t0m0Sx-UhqNU2Hx9v6vTwy0m3a30K4T0KmP9G43MzY,293
30
30
  nuclia/sdk/nucliadb.py,sha256=bOESIppPgY7IrNqrYY7T3ESoxwttbOSTm5zj1xUS1jI,1288
31
- nuclia/sdk/predict.py,sha256=KF7iT2aasaB9DIEAwqktXbOl2H_Y_ne-6-SEErN7YOk,9095
31
+ nuclia/sdk/predict.py,sha256=Gom-BlISXawThNq-f7fb1te5tY4catlIdmO-pMJyqkk,9815
32
32
  nuclia/sdk/process.py,sha256=WuNnqaWprp-EABWDC_z7O2woesGIlYWnDUKozh7Ibr4,2241
33
33
  nuclia/sdk/remi.py,sha256=BEb3O9R2jOFlOda4vjFucKKGO1c2eTkqYZdFlIy3Zmo,4357
34
34
  nuclia/sdk/resource.py,sha256=0lSvD4e1FpN5iM9W295dOKLJ8hsXfIe8HKdo0HsVg20,13976
@@ -57,15 +57,15 @@ nuclia/tests/test_manage/test_auth.py,sha256=I5ho9rKhrzCiP57cDVcs2zrXyy1uWlKxzfE
57
57
  nuclia/tests/test_manage/test_kb.py,sha256=gqBMxmIuPUKC6kP2V_IapS9u72QR99V8CDQlJOa8sHU,1959
58
58
  nuclia/tests/test_nua/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  nuclia/tests/test_nua/test_agent.py,sha256=iPA8lVw7CyONS-0fVG7rDzv8T-LnUlNPMNynTnP-2ic,334
60
- nuclia/tests/test_nua/test_predict.py,sha256=SKASYohoZgDGHoVQsS2_jFwXVdMFU6REh5Suw8okIHQ,5347
60
+ nuclia/tests/test_nua/test_predict.py,sha256=8by69GgXuOZKAgksjSjgmbNr98jCXanF1HOo29oNuMg,5798
61
61
  nuclia/tests/test_nucliadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
62
62
  nuclia/tests/test_nucliadb/test_crud.py,sha256=GuY76HRvt2DFaNgioKm5n0Aco1HnG7zzV_zKom5N8xc,1173
63
63
  nuclia/tests/unit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
64
64
  nuclia/tests/unit/test_export_import.py,sha256=xo_wVbjUnNlVV65ZGH7LtZ38qy39EkJp2hjOuTHC1nU,980
65
65
  nuclia/tests/unit/test_nua_responses.py,sha256=t_hIdVztTi27RWvpfTJUYcCL0lpKdZFegZIwLdaPNh8,319
66
- nuclia-4.9.2.dist-info/licenses/LICENSE,sha256=Ops2LTti_HJtpmWcanuUTdTY3vKDR1myJ0gmGBKC0FA,1063
67
- nuclia-4.9.2.dist-info/METADATA,sha256=QqBc1WyVh98_k_OuLTJBoD5_Ig5VZuQ5ewQZIriu9p0,2337
68
- nuclia-4.9.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
69
- nuclia-4.9.2.dist-info/entry_points.txt,sha256=iZHOyXPNS54r3eQmdi5So20xO1gudI9K2oP4sQsCJRw,46
70
- nuclia-4.9.2.dist-info/top_level.txt,sha256=cqn_EitXOoXOSUvZnd4q6QGrhm04pg8tLAZtem-Zfdo,7
71
- nuclia-4.9.2.dist-info/RECORD,,
66
+ nuclia-4.9.3.dist-info/licenses/LICENSE,sha256=Ops2LTti_HJtpmWcanuUTdTY3vKDR1myJ0gmGBKC0FA,1063
67
+ nuclia-4.9.3.dist-info/METADATA,sha256=rxQGLEDjPHayY_NPuXkmXpfsp4_BfoBC2ZjLlOjhyoM,2337
68
+ nuclia-4.9.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
69
+ nuclia-4.9.3.dist-info/entry_points.txt,sha256=iZHOyXPNS54r3eQmdi5So20xO1gudI9K2oP4sQsCJRw,46
70
+ nuclia-4.9.3.dist-info/top_level.txt,sha256=cqn_EitXOoXOSUvZnd4q6QGrhm04pg8tLAZtem-Zfdo,7
71
+ nuclia-4.9.3.dist-info/RECORD,,
File without changes