nuclia 4.9.3__py3-none-any.whl → 4.9.5__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- nuclia/lib/kb.py +37 -10
- nuclia/lib/nua.py +130 -25
- nuclia/lib/nua_responses.py +7 -0
- nuclia/sdk/kb.py +10 -2
- nuclia/sdk/predict.py +155 -32
- nuclia/sdk/search.py +68 -11
- nuclia/sdk/task.py +5 -5
- nuclia/tests/test_kb/test_search.py +5 -1
- nuclia/tests/test_nua/test_predict.py +77 -4
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/METADATA +4 -4
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/RECORD +15 -15
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/WHEEL +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/entry_points.txt +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/licenses/LICENSE +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/top_level.txt +0 -0
nuclia/lib/kb.py
CHANGED
@@ -324,13 +324,18 @@ class NucliaDBClient(BaseNucliaDBClient):
         handle_http_sync_errors(response)
         return int(response.headers.get("Upload-Offset"))
 
-    def summarize(self, request: SummarizeRequest, timeout: int = 1000):
+    def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         handle_http_sync_errors(response)
         return response
@@ -450,12 +455,16 @@ class NucliaDBClient(BaseNucliaDBClient):
         handle_http_sync_errors(response)
         return response
 
-    def delete_task(self, task_id: str) -> httpx.Response:
+    def delete_task(self, task_id: str, cleanup: bool = False) -> httpx.Response:
         if self.writer_session is None:
             raise Exception("KB not configured")
 
+        params = None
+        if cleanup:
+            params = {"cleanup": "true"}
+
         response: httpx.Response = self.writer_session.delete(
-            f"{self.url}{DELETE_TASK.format(task_id=task_id)}",
+            f"{self.url}{DELETE_TASK.format(task_id=task_id)}", params=params
         )
         handle_http_sync_errors(response)
         return response
@@ -569,12 +578,21 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return response
 
-    async def ask(self, request: AskRequest, timeout: int = 1000):
+    async def ask(
+        self,
+        request: AskRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.reader_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{ASK_URL}"
         req = self.reader_session.build_request(
-            "POST", url, json=request.model_dump(), timeout=timeout
+            "POST",
+            url,
+            json=request.model_dump(),
+            headers=extra_headers,
+            timeout=timeout,
         )
         response = await self.reader_session.send(req, stream=True)
         await handle_http_async_errors(response)
@@ -681,13 +699,18 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return int(response.headers.get("Upload-Offset"))
 
-    async def summarize(self, request: SummarizeRequest, timeout: int = 1000):
+    async def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = await self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         await handle_http_async_errors(response)
         return response
@@ -808,12 +831,16 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return response
 
-    async def delete_task(self, task_id: str) -> httpx.Response:
+    async def delete_task(self, task_id: str, cleanup: bool = False) -> httpx.Response:
         if self.writer_session is None:
             raise Exception("KB not configured")
 
+        params = None
+        if cleanup:
+            params = {"cleanup": "true"}
+
         response: httpx.Response = await self.writer_session.delete(
-            f"{self.url}{DELETE_TASK.format(task_id=task_id)}",
+            f"{self.url}{DELETE_TASK.format(task_id=task_id)}", params=params
        )
         await handle_http_async_errors(response)
         return response
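
Note: in nuclia/lib/kb.py, summarize and ask (sync and async) now forward an optional extra_headers mapping to the underlying HTTP call, and delete_task gains a cleanup flag that is sent as the cleanup=true query parameter. A minimal usage sketch, assuming an already-configured NucliaDBClient named client and a SummarizeRequest instance named req (neither construction appears in this diff; the task id is a placeholder):

    # Forward per-request headers to the summarize endpoint (new in 4.9.5).
    response = client.summarize(
        request=req,
        extra_headers={"X-Show-Consumption": "true"},
    )

    # Ask the server to also clean up the task's data: cleanup=True -> ?cleanup=true
    client.delete_task("task-id", cleanup=True)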
nuclia/lib/nua.py
CHANGED
@@ -18,6 +18,7 @@ from pydantic import BaseModel
 
 from nuclia import REGIONAL
 from nuclia.exceptions import NuaAPIException
+from nuclia_models.common.consumption import ConsumptionGenerative
 from nuclia_models.predict.generative_responses import (
     GenerativeChunk,
     GenerativeFullResponse,
@@ -28,6 +29,7 @@ from nuclia_models.predict.generative_responses import (
     StatusGenerativeResponse,
     ToolsGenerativeResponse,
 )
+from nuclia_models.common.consumption import Consumption
 from nuclia.lib.nua_responses import (
     ChatModel,
     ChatResponse,
@@ -128,9 +130,12 @@ class NuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = self.client.request(method, url, json=payload, timeout=timeout)
+        resp = self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -143,6 +148,7 @@ class NuaClient:
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> Iterator[GenerativeChunk]:
@@ -151,8 +157,9 @@ class NuaClient:
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("
+            if response.headers.get("transfer-encoding") == "chunked":
                 for json_body in response.iter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -194,17 +201,31 @@ class NuaClient:
         endpoint = f"{self.url}{CONFIG}/{kbid}"
         return self._request("GET", endpoint, output=StoredLearningConfiguration)
 
-    def sentence_predict(self, text: str, model: Optional[str] = None) -> Sentence:
+    def sentence_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request("GET", endpoint, output=Sentence)
+        return self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )
 
-    def tokens_predict(self, text: str, model: Optional[str] = None) -> Tokens:
+    def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request("GET", endpoint, output=Tokens)
+        return self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )
 
     def query_predict(
         self,
@@ -212,6 +233,7 @@ class NuaClient:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -220,10 +242,16 @@ class NuaClient:
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return self._request("GET", endpoint, output=QueryInfo)
+        return self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )
 
     def generate(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -235,6 +263,7 @@ class NuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -252,10 +281,19 @@ class NuaClient:
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )
         return result
 
     def generate_stream(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> Iterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -266,11 +304,16 @@ class NuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr
 
     def summarize(
-        self, documents: dict[str, str], model: Optional[str] = None, timeout: int = 300
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -288,6 +331,7 @@ class NuaClient:
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )
 
     def rephrase(
@@ -324,11 +368,13 @@ class NuaClient:
     def remi(
         self,
         request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return self._request(
             "POST",
             endpoint,
+            extra_headers=extra_headers,
             payload=request.model_dump(),
             output=RemiResponse,
         )
@@ -413,10 +459,18 @@ class NuaClient:
         activity_endpoint = f"{self.url}{STATUS_PROCESS}/{process_id}"
         return self._request("GET", activity_endpoint, ProcessRequestStatus)
 
-    def rerank(self, model: RerankModel) -> RerankResponse:
+    def rerank(
+        self,
+        model: RerankModel,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RerankResponse:
         endpoint = f"{self.url}{RERANK}"
         return self._request(
-            "POST",
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
         )
 
 
@@ -454,9 +508,12 @@ class AsyncNuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = await self.client.request(method, url, json=payload, timeout=timeout)
+        resp = await self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -469,6 +526,7 @@ class AsyncNuaClient:
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> AsyncIterator[GenerativeChunk]:
@@ -477,8 +535,9 @@ class AsyncNuaClient:
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("
+            if response.headers.get("transfer-encoding") == "chunked":
                 async for json_body in response.aiter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -527,18 +586,30 @@ class AsyncNuaClient:
         return await self._request("GET", endpoint, output=StoredLearningConfiguration)
 
     async def sentence_predict(
-        self, text: str, model: Optional[str] = None
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request("GET", endpoint, output=Sentence)
+        return await self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )
 
-    async def tokens_predict(self, text: str, model: Optional[str] = None) -> Tokens:
+    async def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request("GET", endpoint, output=Tokens)
+        return await self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )
 
     async def query_predict(
         self,
@@ -546,6 +617,7 @@ class AsyncNuaClient:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -554,7 +626,9 @@ class AsyncNuaClient:
            endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return await self._request("GET", endpoint, output=QueryInfo)
+        return await self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )
 
     @deprecated(version="2.1.0", reason="You should use generate function")
     async def generate_predict(
@@ -573,7 +647,11 @@ class AsyncNuaClient:
         )
 
     async def generate(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -585,6 +663,7 @@ class AsyncNuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -602,11 +681,20 @@ class AsyncNuaClient:
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )
 
         return result
 
     async def generate_stream(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> AsyncIterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -617,11 +705,16 @@ class AsyncNuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr
 
     async def summarize(
-        self, documents: dict[str, str], model: Optional[str] = None, timeout: int = 300
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -639,6 +732,7 @@ class AsyncNuaClient:
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )
 
     async def rephrase(
@@ -672,13 +766,18 @@ class AsyncNuaClient:
             output=RephraseModel,
         )
 
-    async def remi(self, request: RemiRequest) -> RemiResponse:
+    async def remi(
+        self,
+        request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return await self._request(
             "POST",
             endpoint,
             payload=request.model_dump(),
             output=RemiResponse,
+            extra_headers=extra_headers,
         )
 
     async def generate_retrieval(
@@ -802,8 +901,14 @@ class AsyncNuaClient:
             "GET", activity_endpoint, output=ProcessRequestStatus
         )
 
-    async def rerank(self, model: RerankModel) -> RerankResponse:
+    async def rerank(
+        self, model: RerankModel, extra_headers: Optional[dict[str, str]] = None
+    ) -> RerankResponse:
         endpoint = f"{self.url}{RERANK}"
         return await self._request(
-            "POST", endpoint, payload=model.model_dump(), output=RerankResponse
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
         )
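
Note: across NuaClient and AsyncNuaClient, the predict-style methods (sentence_predict, tokens_predict, query_predict, generate, generate_stream, summarize, remi, rerank) now accept extra_headers, and generate folds a ConsumptionGenerative chunk from the stream into result.consumption. A sketch of reading the new field, assuming a configured NuaClient named nua and a ChatModel instance named chat_body (neither construction appears in this diff):

    result = nua.generate(
        body=chat_body,
        # Header name taken from the nuclia/sdk/kb.py changes below
        extra_headers={"X-Show-Consumption": "true"},
    )
    print(result.answer)
    if result.consumption is not None:
        # Populated only when the stream includes a ConsumptionGenerative chunk
        print(result.consumption.normalized_tokens)
        print(result.consumption.customer_key_tokens)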
nuclia/lib/nua_responses.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union, cast
 import pydantic
 from pydantic import BaseModel, Field, RootModel, model_validator
 from typing_extensions import Annotated, Self
+from nuclia_models.common.consumption import Consumption
 
 
 class GenerativeOption(BaseModel):
@@ -51,6 +52,7 @@ class ConfigSchema(BaseModel):
 class Sentence(BaseModel):
     data: List[float]
     time: float
+    consumption: Optional[Consumption] = None
 
 
 class Author(str, Enum):
@@ -137,6 +139,7 @@ class Token(BaseModel):
 class Tokens(BaseModel):
     tokens: List[Token]
     time: float
+    consumption: Optional[Consumption] = None
 
 
 class SummarizeResource(BaseModel):
@@ -155,6 +158,7 @@ class SummarizedResource(BaseModel):
 class SummarizedModel(BaseModel):
     resources: Dict[str, SummarizedResource]
     summary: str = ""
+    consumption: Optional[Consumption] = None
 
 
 class RephraseModel(RootModel[str]):
@@ -535,6 +539,7 @@ class StoredLearningConfiguration(BaseModel):
 class SentenceSearch(BaseModel):
     data: List[float] = []
     time: float
+    consumption: Optional[Consumption] = None
 
 
 class Ner(BaseModel):
@@ -547,6 +552,7 @@ class Ner(BaseModel):
 class TokenSearch(BaseModel):
     tokens: List[Ner] = []
     time: float
+    consumption: Optional[Consumption] = None
 
 
 class QueryInfo(BaseModel):
@@ -569,3 +575,4 @@ class RerankResponse(BaseModel):
     context_scores: dict[str, float] = Field(
         description="Scores for each context given by the reranker"
     )
+    consumption: Optional[Consumption] = None
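
Note: each affected response model gains an optional consumption field defaulting to None, so payloads from servers that do not report consumption still validate. For example:

    from nuclia.lib.nua_responses import Sentence

    # A pre-4.9.5-style payload without "consumption" still parses
    s = Sentence.model_validate({"data": [0.1, 0.2], "time": 0.03})
    assert s.consumption is None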
nuclia/sdk/kb.py
CHANGED
@@ -297,9 +297,15 @@ class NucliaKB:
         )
 
     @kb
-    def summarize(self, *, resources: List[str], **kwargs):
+    def summarize(
+        self, *, resources: List[str], show_consumption: bool = False, **kwargs
+    ):
         ndb: NucliaDBClient = kwargs["ndb"]
-        return ndb.ndb.summarize(kbid=ndb.kbid, resources=resources)
+        return ndb.ndb.summarize(
+            kbid=ndb.kbid,
+            resources=resources,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @kb
     def notifications(self, **kwargs):
@@ -715,6 +721,7 @@ class AsyncNucliaKB:
         resources: List[str],
         generative_model: Optional[str] = None,
         summary_kind: Optional[str] = None,
+        show_consumption: bool = False,
         timeout: int = 1000,
         **kwargs,
     ) -> SummarizedModel:
@@ -725,6 +732,7 @@ class AsyncNucliaKB:
                 generative_model=generative_model,
                 summary_kind=SummaryKind(summary_kind),
             ),
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
            timeout=timeout,
        )
        return SummarizedModel.model_validate(resp.json())