nuclia 4.9.3__py3-none-any.whl → 4.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nuclia/lib/kb.py +25 -6
- nuclia/lib/nua.py +130 -25
- nuclia/lib/nua_responses.py +7 -0
- nuclia/sdk/kb.py +10 -2
- nuclia/sdk/predict.py +155 -32
- nuclia/sdk/search.py +68 -11
- nuclia/tests/test_kb/test_search.py +5 -1
- nuclia/tests/test_nua/test_predict.py +77 -4
- {nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/METADATA +4 -4
- {nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/RECORD +14 -14
- {nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/WHEEL +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/entry_points.txt +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/licenses/LICENSE +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/top_level.txt +0 -0
nuclia/lib/kb.py
CHANGED
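The hunks below thread an optional extra_headers argument through the NucliaDB client's summarize and ask calls, which is how the SDK layer passes the X-Show-Consumption header. A minimal usage sketch, assuming an already configured NucliaDBClient and assuming SummarizeRequest is importable from nucliadb_models.search (neither shown in this diff):

from nucliadb_models.search import SummarizeRequest  # assumed import path

def summarize_with_usage(client, resource_ids):
    # client: a configured nuclia.lib.kb.NucliaDBClient (setup not shown here)
    request = SummarizeRequest(resources=resource_ids)
    response = client.summarize(
        request,
        extra_headers={"X-Show-Consumption": "true"},  # opt in to token reporting
    )
    return response.json()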
@@ -324,13 +324,18 @@ class NucliaDBClient(BaseNucliaDBClient):
         handle_http_sync_errors(response)
         return int(response.headers.get("Upload-Offset"))

-    def summarize(
+    def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         handle_http_sync_errors(response)
         return response
@@ -569,12 +574,21 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return response

-    async def ask(
+    async def ask(
+        self,
+        request: AskRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.reader_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{ASK_URL}"
         req = self.reader_session.build_request(
-            "POST",
+            "POST",
+            url,
+            json=request.model_dump(),
+            headers=extra_headers,
+            timeout=timeout,
         )
         response = await self.reader_session.send(req, stream=True)
         await handle_http_async_errors(response)
@@ -681,13 +695,18 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return int(response.headers.get("Upload-Offset"))

-    async def summarize(
+    async def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = await self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         await handle_http_async_errors(response)
         return response
nuclia/lib/nua.py
CHANGED
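Two themes run through the hunks below: every NUA request gains an optional extra_headers parameter, and generate()/generate_stream() learn to fold a ConsumptionGenerative chunk into the new consumption field of the full response. A sketch of reading that chunk from a stream, assuming a configured NuaClient; the ChatModel field values here are an assumption that mirrors how the SDK builds the body elsewhere:

from nuclia.lib.nua_responses import ChatModel, UserPrompt
from nuclia_models.common.consumption import ConsumptionGenerative

def stream_with_usage(nc, question: str):
    # nc: a configured nuclia.lib.nua.NuaClient (setup not shown here)
    body = ChatModel(
        question=question,
        retrieval=False,               # assumed field values, mirroring the
        user_id="docs-example",        # SDK's own ChatModel construction
        user_prompt=UserPrompt(prompt=question),
    )
    for chunk in nc.generate_stream(
        body=body,
        extra_headers={"X-Show-Consumption": "true"},
    ):
        if isinstance(chunk.chunk, ConsumptionGenerative):
            # Emitted only when the header above is set to "true".
            print(chunk.chunk.normalized_tokens, chunk.chunk.customer_key_tokens)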
@@ -18,6 +18,7 @@ from pydantic import BaseModel

 from nuclia import REGIONAL
 from nuclia.exceptions import NuaAPIException
+from nuclia_models.common.consumption import ConsumptionGenerative
 from nuclia_models.predict.generative_responses import (
     GenerativeChunk,
     GenerativeFullResponse,
@@ -28,6 +29,7 @@ from nuclia_models.predict.generative_responses import (
     StatusGenerativeResponse,
     ToolsGenerativeResponse,
 )
+from nuclia_models.common.consumption import Consumption
 from nuclia.lib.nua_responses import (
     ChatModel,
     ChatResponse,
@@ -128,9 +130,12 @@ class NuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = self.client.request(
+        resp = self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -143,6 +148,7 @@ class NuaClient:
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> Iterator[GenerativeChunk]:
@@ -151,8 +157,9 @@ class NuaClient:
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("
+            if response.headers.get("transfer-encoding") == "chunked":
                 for json_body in response.iter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -194,17 +201,31 @@ class NuaClient:
         endpoint = f"{self.url}{CONFIG}/{kbid}"
         return self._request("GET", endpoint, output=StoredLearningConfiguration)

-    def sentence_predict(
+    def sentence_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request(
+        return self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )

-    def tokens_predict(
+    def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request(
+        return self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )

     def query_predict(
         self,
@@ -212,6 +233,7 @@ class NuaClient:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -220,10 +242,16 @@ class NuaClient:
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return self._request(
+        return self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )

     def generate(
-        self,
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -235,6 +263,7 @@ class NuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -252,10 +281,19 @@ class NuaClient:
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )
         return result

     def generate_stream(
-        self,
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> Iterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -266,11 +304,16 @@ class NuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr

     def summarize(
-        self,
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -288,6 +331,7 @@ class NuaClient:
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )

     def rephrase(
@@ -324,11 +368,13 @@ class NuaClient:
     def remi(
         self,
         request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return self._request(
             "POST",
             endpoint,
+            extra_headers=extra_headers,
             payload=request.model_dump(),
             output=RemiResponse,
         )
@@ -413,10 +459,18 @@ class NuaClient:
         activity_endpoint = f"{self.url}{STATUS_PROCESS}/{process_id}"
         return self._request("GET", activity_endpoint, ProcessRequestStatus)

-    def rerank(
+    def rerank(
+        self,
+        model: RerankModel,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RerankResponse:
         endpoint = f"{self.url}{RERANK}"
         return self._request(
-            "POST",
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
         )

@@ -454,9 +508,12 @@ class AsyncNuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = await self.client.request(
+        resp = await self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -469,6 +526,7 @@ class AsyncNuaClient:
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> AsyncIterator[GenerativeChunk]:
@@ -477,8 +535,9 @@ class AsyncNuaClient:
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("
+            if response.headers.get("transfer-encoding") == "chunked":
                 async for json_body in response.aiter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -527,18 +586,30 @@ class AsyncNuaClient:
         return await self._request("GET", endpoint, output=StoredLearningConfiguration)

     async def sentence_predict(
-        self,
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request(
+        return await self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )

-    async def tokens_predict(
+    async def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request(
+        return await self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )

     async def query_predict(
         self,
@@ -546,6 +617,7 @@ class AsyncNuaClient:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -554,7 +626,9 @@ class AsyncNuaClient:
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return await self._request(
+        return await self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )

     @deprecated(version="2.1.0", reason="You should use generate function")
     async def generate_predict(
@@ -573,7 +647,11 @@ class AsyncNuaClient:
         )

     async def generate(
-        self,
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -585,6 +663,7 @@ class AsyncNuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -602,11 +681,20 @@ class AsyncNuaClient:
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )

         return result

     async def generate_stream(
-        self,
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> AsyncIterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -617,11 +705,16 @@ class AsyncNuaClient:
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr

     async def summarize(
-        self,
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -639,6 +732,7 @@ class AsyncNuaClient:
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )

     async def rephrase(
@@ -672,13 +766,18 @@ class AsyncNuaClient:
             output=RephraseModel,
         )

-    async def remi(
+    async def remi(
+        self,
+        request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return await self._request(
             "POST",
             endpoint,
             payload=request.model_dump(),
             output=RemiResponse,
+            extra_headers=extra_headers,
         )

     async def generate_retrieval(
@@ -802,8 +901,14 @@ class AsyncNuaClient:
             "GET", activity_endpoint, output=ProcessRequestStatus
         )

-    async def rerank(
+    async def rerank(
+        self, model: RerankModel, extra_headers: Optional[dict[str, str]] = None
+    ) -> RerankResponse:
         endpoint = f"{self.url}{RERANK}"
         return await self._request(
-            "POST",
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
        )
nuclia/lib/nua_responses.py
CHANGED
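The response models below each gain an optional consumption field defaulting to None, so payloads from servers that omit it keep validating unchanged. A minimal check derived directly from the new field definitions:

from nuclia.lib.nua_responses import Sentence

# An old-style payload without the new field still validates.
sentence = Sentence.model_validate({"data": [0.1, 0.2], "time": 0.03})
assert sentence.consumption is None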
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union, cast
 import pydantic
 from pydantic import BaseModel, Field, RootModel, model_validator
 from typing_extensions import Annotated, Self
+from nuclia_models.common.consumption import Consumption

 class GenerativeOption(BaseModel):
@@ -51,6 +52,7 @@ class ConfigSchema(BaseModel):
 class Sentence(BaseModel):
     data: List[float]
     time: float
+    consumption: Optional[Consumption] = None

 class Author(str, Enum):
@@ -137,6 +139,7 @@ class Token(BaseModel):
 class Tokens(BaseModel):
     tokens: List[Token]
     time: float
+    consumption: Optional[Consumption] = None

 class SummarizeResource(BaseModel):
@@ -155,6 +158,7 @@ class SummarizedResource(BaseModel):
 class SummarizedModel(BaseModel):
     resources: Dict[str, SummarizedResource]
     summary: str = ""
+    consumption: Optional[Consumption] = None

 class RephraseModel(RootModel[str]):
@@ -535,6 +539,7 @@ class StoredLearningConfiguration(BaseModel):
 class SentenceSearch(BaseModel):
     data: List[float] = []
     time: float
+    consumption: Optional[Consumption] = None

 class Ner(BaseModel):
@@ -547,6 +552,7 @@ class Ner(BaseModel):
 class TokenSearch(BaseModel):
     tokens: List[Ner] = []
     time: float
+    consumption: Optional[Consumption] = None

 class QueryInfo(BaseModel):
@@ -569,3 +575,4 @@ class RerankResponse(BaseModel):
     context_scores: dict[str, float] = Field(
         description="Scores for each context given by the reranker"
     )
+    consumption: Optional[Consumption] = None
nuclia/sdk/kb.py
CHANGED
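The SDK facade below exposes the feature as a show_consumption flag, which it translates into the X-Show-Consumption header. A usage sketch, assuming authentication and a default Knowledge Box are already configured:

from nuclia.sdk.kb import NucliaKB

kb = NucliaKB()
summary = kb.summarize(
    resources=["my-resource-id"],  # hypothetical resource id
    show_consumption=True,         # sent as X-Show-Consumption: true
)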
@@ -297,9 +297,15 @@ class NucliaKB:
         )

     @kb
-    def summarize(
+    def summarize(
+        self, *, resources: List[str], show_consumption: bool = False, **kwargs
+    ):
         ndb: NucliaDBClient = kwargs["ndb"]
-        return ndb.ndb.summarize(
+        return ndb.ndb.summarize(
+            kbid=ndb.kbid,
+            resources=resources,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @kb
     def notifications(self, **kwargs):
@@ -715,6 +721,7 @@ class AsyncNucliaKB:
         resources: List[str],
         generative_model: Optional[str] = None,
         summary_kind: Optional[str] = None,
+        show_consumption: bool = False,
         timeout: int = 1000,
         **kwargs,
     ) -> SummarizedModel:
@@ -725,6 +732,7 @@ class AsyncNucliaKB:
                 generative_model=generative_model,
                 summary_kind=SummaryKind(summary_kind),
             ),
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
             timeout=timeout,
         )
         return SummarizedModel.model_validate(resp.json())
nuclia/sdk/predict.py
CHANGED
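Every predict endpoint below gains the same show_consumption flag. A sketch mirroring the new tests, assuming NUA credentials are already configured for the SDK:

from nuclia.sdk.predict import NucliaPredict

np = NucliaPredict()
generated = np.generate(
    text="How much is 2 + 2?",
    model="chatgpt-azure-4o-mini",
    show_consumption=True,   # adds X-Show-Consumption: true to the request
)
print(generated.answer)
if generated.consumption is not None:
    print(generated.consumption.normalized_tokens)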
@@ -52,9 +52,19 @@ class NucliaPredict:
         nc.del_config_predict(kbid)

     @nua
-    def sentence(
+    def sentence(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Sentence:
         nc: NuaClient = kwargs["nc"]
-        return nc.sentence_predict(
+        return nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def query(
@@ -63,19 +73,25 @@ class NucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: NuaClient = kwargs["nc"]
         return nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     def generate(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -88,11 +104,19 @@ class NucliaPredict:
         else:
             body = text

-        return nc.generate(
+        return nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def generate_stream(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Iterator[GenerativeChunk]:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -105,20 +129,42 @@ class NucliaPredict:
         else:
             body = text

-        for chunk in nc.generate_stream(
+        for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    def tokens(
+    def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: NuaClient = kwargs["nc"]
-        return nc.tokens_predict(
+        return nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def summarize(
-        self,
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: NuaClient = kwargs["nc"]
-        return nc.summarize(
+        return nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def rephrase(
@@ -135,7 +181,12 @@ class NucliaPredict:

     @nua
     def rag(
-        self,
+        self,
+        question: str,
+        context: list[str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         body = ChatModel(
@@ -145,10 +196,19 @@ class NucliaPredict:
             query_context=context,
         )

-        return nc.generate(
+        return nc.generate(
+            body,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    def remi(
+    def remi(
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience

@@ -162,10 +222,15 @@ class NucliaPredict:
         if request is None:
             request = RemiRequest(**kwargs)
         nc: NuaClient = kwargs["nc"]
-        return nc.remi(
+        return nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    def rerank(
+    def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
         """
         Perform a reranking of the results based on the question and context provided.

@@ -173,7 +238,10 @@ class NucliaPredict:
         :return: RerankResponse
         """
         nc: NuaClient = kwargs["nc"]
-        return nc.rerank(
+        return nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

 class AsyncNucliaPredict:
@@ -208,14 +276,26 @@ class AsyncNucliaPredict:

     @nua
     async def sentence(
-        self,
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Sentence:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.sentence_predict(
+        return await nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -227,11 +307,19 @@ class AsyncNucliaPredict:
             )
         else:
             body = text
-        return await nc.generate(
+        return await nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate_stream(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> AsyncIterator[GenerativeChunk]:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -244,13 +332,27 @@ class AsyncNucliaPredict:
         else:
             body = text

-        async for chunk in nc.generate_stream(
+        async for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    async def tokens(
+    async def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.tokens_predict(
+        return await nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def query(
@@ -259,22 +361,32 @@ class AsyncNucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: AsyncNuaClient = kwargs["nc"]
         return await nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     async def summarize(
-        self,
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.summarize(
+        return await nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def rephrase(
@@ -298,7 +410,10 @@ class AsyncNucliaPredict:

     @nua
     async def remi(
-        self,
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience
@@ -311,10 +426,15 @@ class AsyncNucliaPredict:
             request = RemiRequest(**kwargs)

         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.remi(
+        return await nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    async def rerank(
+    async def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
         """
         Perform a reranking of the results based on the question and context provided.

@@ -322,4 +442,7 @@ class AsyncNucliaPredict:
         :return: RerankResponse
         """
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.rerank(
+        return await nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
nuclia/sdk/search.py
CHANGED
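ask and its variants below accept show_consumption and surface the result on the new AskAnswer.consumption field. A sketch, assuming a configured Knowledge Box:

from nuclia.sdk.search import NucliaSearch

search = NucliaSearch()
answer = search.ask(query="Who is Hedy Lamarr?", show_consumption=True)
print(answer.answer.decode())
if answer.consumption is not None:   # populated only when requested
    print(answer.consumption)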
@@ -30,6 +30,7 @@ from nuclia.lib.kb import AsyncNucliaDBClient, NucliaDBClient
 from nuclia.sdk.logger import logger
 from nuclia.sdk.auth import AsyncNucliaAuth, NucliaAuth
 from nuclia.sdk.resource import RagImagesStrategiesParse, RagStrategiesParse
+from nuclia_models.common.consumption import Consumption, TokensDetail

 @dataclass
@@ -49,6 +50,7 @@ class AskAnswer:
     relations: Optional[Relations]
     predict_request: Optional[ChatModel]
     error_details: Optional[str]
+    consumption: Optional[Consumption]

     def __str__(self):
         if self.answer:
@@ -184,8 +186,9 @@ class NucliaSearch:
         filters: Optional[Union[List[str], List[Filter]]] = None,
         rag_strategies: Optional[list[RagStrategies]] = None,
         rag_images_strategies: Optional[list[RagImagesStrategies]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -217,7 +220,11 @@ class NucliaSearch:
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")

-        ask_response: SyncAskResponse = ndb.ndb.ask(
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -239,6 +246,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )

         if ask_response.prompt_context:
@@ -257,8 +265,9 @@ class NucliaSearch:
         schema: Union[str, Dict[str, Any]],
         query: Union[str, dict, AskRequest, None] = None,
         filters: Optional[Union[List[str], List[Filter]]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[AskAnswer]:
         """
         Answer a question.

@@ -272,10 +281,10 @@ class NucliaSearch:
                 schema_json = json.load(json_file_handler)
             except Exception:
                 logger.exception("File format is not JSON")
-                return
+                return None
         else:
             logger.exception("File not found")
-            return
+            return None
     else:
         schema_json = schema

@@ -303,7 +312,11 @@ class NucliaSearch:
             req.filters = filters
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_response: SyncAskResponse = ndb.ndb.ask(
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -325,6 +338,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
@@ -483,9 +497,10 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -509,7 +524,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -526,6 +545,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -548,6 +568,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
@@ -569,6 +602,7 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
     ) -> AsyncIterator[AskResponseItem]:
@@ -593,7 +627,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         async for line in ask_stream_response.aiter_lines():
             try:
                 ask_response_item = AskResponseItem.model_validate_json(line)
@@ -609,9 +647,10 @@ class AsyncNucliaSearch:
         query: Union[str, dict, AskRequest],
         schema: Dict[str, Any],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -635,7 +674,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -652,6 +695,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -674,6 +718,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
nuclia/tests/test_kb/test_search.py
CHANGED
@@ -132,7 +132,11 @@ def test_ask_json(testing_config):
 async def test_ask_json_async(testing_config):
     search = AsyncNucliaSearch()
     results = await search.ask_json(
-        query="Who is hedy Lamarr?",
+        query="Who is hedy Lamarr?",
+        filters=["/icon/application/pdf"],
+        schema=SCHEMA,
+        show_consumption=True,
     )

     assert "TECHNOLOGY" in results.object["document_type"]
+    assert results.consumption is not None
nuclia/tests/test_nua/test_predict.py
CHANGED
@@ -1,4 +1,7 @@
-from nuclia_models.predict.generative_responses import
+from nuclia_models.predict.generative_responses import (
+    TextGenerativeResponse,
+    ConsumptionGenerative,
+)

 from nuclia.lib.nua_responses import ChatModel, RerankModel, UserPrompt
 from nuclia.sdk.predict import AsyncNucliaPredict, NucliaPredict
@@ -8,9 +11,12 @@ from nuclia_models.predict.remi import RemiRequest

 def test_predict(testing_config):
     np = NucliaPredict()
-    embed = np.sentence(
+    embed = np.sentence(
+        text="This is my text", model="multilingual-2024-05-06", show_consumption=True
+    )
     assert embed.time > 0
     assert len(embed.data) == 1024
+    assert embed.consumption is not None

 def test_predict_query(testing_config):
@@ -20,11 +26,14 @@ def test_predict_query(testing_config):
         semantic_model="multilingual-2024-05-06",
         token_model="multilingual",
         generative_model="chatgpt-azure-4o-mini",
+        show_consumption=True,
     )
     assert query.language == "en"
     assert query.visual_llm is True
     assert query.entities and query.entities.tokens[0].text == "Ramon"
     assert query.sentence and len(query.sentence.data) == 1024
+    assert query.entities.consumption is not None
+    assert query.sentence.consumption is not None

 def test_rag(testing_config):
@@ -36,14 +45,34 @@ def test_rag(testing_config):
             "Eudald Camprubí is CEO at the same company as Ramon Navarro",
         ],
         model="chatgpt-azure-4o-mini",
+        show_consumption=True,
     )
     assert "Eudald" in generated.answer
+    assert generated.consumption is not None

 def test_generative(testing_config):
     np = NucliaPredict()
     generated = np.generate(text="How much is 2 + 2?", model="chatgpt-azure-4o-mini")
     assert "4" in generated.answer
+    assert generated.consumption is None
+
+
+@pytest.mark.asyncio
+async def test_generative_with_consumption(testing_config):
+    np = NucliaPredict()
+    generated = np.generate(
+        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini", show_consumption=True
+    )
+    assert "4" in generated.answer
+    assert generated.consumption is not None
+
+    anp = AsyncNucliaPredict()
+    async_generated = await anp.generate(
+        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini", show_consumption=True
+    )
+    assert "4" in async_generated.answer
+    assert async_generated.consumption is not None

 @pytest.mark.asyncio
@@ -70,13 +99,18 @@ def test_stream_generative(testing_config):
 @pytest.mark.asyncio
 async def test_async_stream_generative(testing_config):
     np = AsyncNucliaPredict()
+    consumption_found = False
+    found = False
     async for stream in np.generate_stream(
-        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini"
+        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini", show_consumption=True
     ):
         if isinstance(stream.chunk, TextGenerativeResponse) and stream.chunk.text:
             if "4" in stream.chunk.text:
                 found = True
+        elif isinstance(stream.chunk, ConsumptionGenerative):
+            consumption_found = True
     assert found
+    assert consumption_found

 SCHEMA = {
@@ -148,6 +182,8 @@ def test_nua_remi(testing_config):
     assert results.context_relevance[1] < 2
     assert results.groundedness[1] < 2

+    assert results.consumption is None
+

 @pytest.mark.asyncio
 async def test_nua_async_remi(testing_config):
@@ -161,7 +197,8 @@ async def test_nua_async_remi(testing_config):
             "Paris is the capital of France.",
             "Berlin is the capital of Germany.",
         ],
-    )
+        ),
+        show_consumption=True,
     )
     assert results.answer_relevance.score >= 4

@@ -171,6 +208,8 @@ async def test_nua_async_remi(testing_config):
     assert results.context_relevance[1] < 2
     assert results.groundedness[1] < 2

+    assert results.consumption is not None
+

 def test_nua_rerank(testing_config):
     np = NucliaPredict()
@@ -185,3 +224,37 @@ def test_nua_rerank(testing_config):
         )
     )
     assert results.context_scores["1"] > results.context_scores["2"]
+    assert results.consumption is None
+
+
+@pytest.mark.asyncio
+async def test_nua_rerank_with_consumption(testing_config):
+    np = NucliaPredict()
+    results = np.rerank(
+        RerankModel(
+            user_id="Nuclia PY CLI",
+            question="What is the capital of France?",
+            context={
+                "1": "Paris is the capital of France.",
+                "2": "Berlin is the capital of Germany.",
+            },
+        ),
+        show_consumption=True,
+    )
+    assert results.context_scores["1"] > results.context_scores["2"]
+    assert results.consumption is not None
+
+    anp = AsyncNucliaPredict()
+    async_results = await anp.rerank(
+        RerankModel(
+            user_id="Nuclia PY CLI",
+            question="What is the capital of France?",
+            context={
+                "1": "Paris is the capital of France.",
+                "2": "Berlin is the capital of Germany.",
+            },
+        ),
+        show_consumption=True,
+    )
+    assert async_results.context_scores["1"] > async_results.context_scores["2"]
+    assert async_results.consumption is not None
{nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nuclia
-Version: 4.9.3
+Version: 4.9.4
 Summary: Nuclia Python SDK
 Author-email: Nuclia <info@nuclia.com>
 License-Expression: MIT
@@ -24,9 +24,9 @@ Requires-Dist: requests
 Requires-Dist: httpx
 Requires-Dist: httpcore>=1.0.0
 Requires-Dist: prompt_toolkit
-Requires-Dist: nucliadb_sdk<7,>=6.
-Requires-Dist: nucliadb_models<7,>=6.
-Requires-Dist: nuclia-models>=0.
+Requires-Dist: nucliadb_sdk<7,>=6.6.1
+Requires-Dist: nucliadb_models<7,>=6.6.1
+Requires-Dist: nuclia-models>=0.45.0
 Requires-Dist: tqdm
 Requires-Dist: aiofiles
 Requires-Dist: backoff
{nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/RECORD
CHANGED
@@ -9,11 +9,11 @@ nuclia/cli/run.py,sha256=B1hP0upSbSCqqT89WAwsd93ZxkAoF6ajVyLOdYmo8fU,1560
 nuclia/cli/utils.py,sha256=iZ3P8juBdAGvaRUd2BGz7bpUXNDHdPrC5p876yyZ2Cs,1223
 nuclia/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/lib/conversations.py,sha256=M6qhL9NPEKroYF767S-Q2XWokRrjX02kpYTzRvZKwUE,149
-nuclia/lib/kb.py,sha256=
+nuclia/lib/kb.py,sha256=hbWg-0pX1OCt9ezUXYcOfyWHahgmZLUQe2nlSycHetU,31364
 nuclia/lib/models.py,sha256=ekEQrVIFU3aFvt60yQh-zpWkGNORBMSc7c5Hd_VzPzI,1564
-nuclia/lib/nua.py,sha256=
+nuclia/lib/nua.py,sha256=qP6Gv9ck117kdqNuKVl_qHodrDE8CS8xdh_bdzeJMiQ,30631
 nuclia/lib/nua_chat.py,sha256=ApL1Y1FWvAVUt-Y9a_8TUSJIhg8-UmBSy8TlDPn6tD8,3874
-nuclia/lib/nua_responses.py,sha256=
+nuclia/lib/nua_responses.py,sha256=KHAjoW3fu2QT7u4F4WfL4DQIM8lCE5W_14vuRtCmtZg,13970
 nuclia/lib/utils.py,sha256=9l6DxBk-11WqUhXRn99cqeuVTUOJXj-1S6zckal7wOk,6312
 nuclia/sdk/__init__.py,sha256=-nAw8i53XBdmbfTa1FJZ0FNRMNakimDVpD6W4OdES-c,1374
 nuclia/sdk/accounts.py,sha256=7XQ3K9_jlSuk2Cez868FtazZ05xSGab6h3Mt1qMMwIE,647
@@ -22,17 +22,17 @@ nuclia/sdk/auth.py,sha256=o9CiWt3P_UHQW_L5Jq_9IOB-KpSCXFsVTouQcZp8BJM,25072
 nuclia/sdk/backup.py,sha256=adbPcNEbHGZW698o028toXKfDkDrmk5QRIDSiN6SPys,6529
 nuclia/sdk/export_import.py,sha256=y5cTOxhILwRPIvR2Ya12bk-ReGbeDzA3C9TPxgnOHD4,9756
 nuclia/sdk/extract_strategy.py,sha256=NZBLLThdLyQYw8z1mT9iRhFjkE5sQP86-3QhsTiyV9o,2540
-nuclia/sdk/kb.py,sha256=
+nuclia/sdk/kb.py,sha256=rY1yVx_9S01twMXgrkpKSNpPChTupwYEw7EAIOwzLYA,27321
 nuclia/sdk/kbs.py,sha256=nXEvg5ddZYdDS8Kie7TrN-s1meU9ecYLf9FlT5xr-ro,9131
 nuclia/sdk/logger.py,sha256=UHB81eS6IGmLrsofKxLh8cmF2AsaTj_HXP0tGqMr_HM,57
 nuclia/sdk/logs.py,sha256=3jfORpo8fzZiXFFSbGY0o3Bre1ZgJaKQCXgxP1keNHw,9614
 nuclia/sdk/nua.py,sha256=6t0m0Sx-UhqNU2Hx9v6vTwy0m3a30K4T0KmP9G43MzY,293
 nuclia/sdk/nucliadb.py,sha256=bOESIppPgY7IrNqrYY7T3ESoxwttbOSTm5zj1xUS1jI,1288
-nuclia/sdk/predict.py,sha256=
+nuclia/sdk/predict.py,sha256=ZQMtoeQBXaPuaJRxOmO23nnKkVKRpNCMi0JgXWSJ9lM,12836
 nuclia/sdk/process.py,sha256=WuNnqaWprp-EABWDC_z7O2woesGIlYWnDUKozh7Ibr4,2241
 nuclia/sdk/remi.py,sha256=BEb3O9R2jOFlOda4vjFucKKGO1c2eTkqYZdFlIy3Zmo,4357
 nuclia/sdk/resource.py,sha256=0lSvD4e1FpN5iM9W295dOKLJ8hsXfIe8HKdo0HsVg20,13976
-nuclia/sdk/search.py,sha256=
+nuclia/sdk/search.py,sha256=tHpx0gwVwSepa3dn7AX2iWs7JO0pQkjJk7v8TEoL7Gg,27742
 nuclia/sdk/task.py,sha256=UawH-7IneRIGVOiLdDQ2vDBwe5eI51UXRbLRRVeu3C4,6095
 nuclia/sdk/upload.py,sha256=ZBzYROF3yP-77HcaR06OBsFjJAbTOCvF-nlxaqQZsT4,22720
 nuclia/sdk/zones.py,sha256=1ARWrTsTuzj8zguanpX3OaIw-3Qq_ULS_g4GG2mHxOA,342
@@ -49,7 +49,7 @@ nuclia/tests/test_kb/test_labels.py,sha256=IUdTq4mzv0OrOkwBWWy4UwKGKyJybtoHrgvXr
 nuclia/tests/test_kb/test_logs.py,sha256=Z9ELtiiU9NniITJzeWt92GCcERKYy9Nwc_fUVPboRU0,3121
 nuclia/tests/test_kb/test_remi.py,sha256=OX5N-MHbgcwpLg6fBjrAK_KhqkMspJo_VKQHCBCayZ8,2080
 nuclia/tests/test_kb/test_resource.py,sha256=05Xgmg5fwcPW2PZKnUSSjr6MPXp5w8XDgx8plfNcR68,1102
-nuclia/tests/test_kb/test_search.py,sha256=
+nuclia/tests/test_kb/test_search.py,sha256=8v3u-VcczBCTc-rHlWVM5mUtLJKNBZSM0V4v2I1uE3E,4751
 nuclia/tests/test_kb/test_tasks.py,sha256=tfJl1js2o0_dKyXdOxiwdj929kbzMkfl5XjO8HVMInQ,3456
 nuclia/tests/test_manage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/test_manage/test_account.py,sha256=u9NhRK8gJLS7BEY618aGoYoV2rgDLHZUeSsWWYkDhF8,289
@@ -57,15 +57,15 @@ nuclia/tests/test_manage/test_auth.py,sha256=I5ho9rKhrzCiP57cDVcs2zrXyy1uWlKxzfE
 nuclia/tests/test_manage/test_kb.py,sha256=gqBMxmIuPUKC6kP2V_IapS9u72QR99V8CDQlJOa8sHU,1959
 nuclia/tests/test_nua/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/test_nua/test_agent.py,sha256=iPA8lVw7CyONS-0fVG7rDzv8T-LnUlNPMNynTnP-2ic,334
-nuclia/tests/test_nua/test_predict.py,sha256=
+nuclia/tests/test_nua/test_predict.py,sha256=EKgpvHZHA4pf3ENE17j0B7CKewqHeJryR7D8JSTGJMo,8154
 nuclia/tests/test_nucliadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/test_nucliadb/test_crud.py,sha256=GuY76HRvt2DFaNgioKm5n0Aco1HnG7zzV_zKom5N8xc,1173
 nuclia/tests/unit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/unit/test_export_import.py,sha256=xo_wVbjUnNlVV65ZGH7LtZ38qy39EkJp2hjOuTHC1nU,980
 nuclia/tests/unit/test_nua_responses.py,sha256=t_hIdVztTi27RWvpfTJUYcCL0lpKdZFegZIwLdaPNh8,319
-nuclia-4.9.3.dist-info/licenses/LICENSE,sha256=
-nuclia-4.9.3.dist-info/METADATA,sha256=
-nuclia-4.9.3.dist-info/WHEEL,sha256=
-nuclia-4.9.3.dist-info/entry_points.txt,sha256=
-nuclia-4.9.3.dist-info/top_level.txt,sha256=
-nuclia-4.9.3.dist-info/RECORD,,
+nuclia-4.9.4.dist-info/licenses/LICENSE,sha256=Ops2LTti_HJtpmWcanuUTdTY3vKDR1myJ0gmGBKC0FA,1063
+nuclia-4.9.4.dist-info/METADATA,sha256=2fk95GmG2Qa-GGAOPxWXWgA-gAGv-vEIHMRftKOsgA0,2341
+nuclia-4.9.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+nuclia-4.9.4.dist-info/entry_points.txt,sha256=iZHOyXPNS54r3eQmdi5So20xO1gudI9K2oP4sQsCJRw,46
+nuclia-4.9.4.dist-info/top_level.txt,sha256=cqn_EitXOoXOSUvZnd4q6QGrhm04pg8tLAZtem-Zfdo,7
+nuclia-4.9.4.dist-info/RECORD,,
{nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/WHEEL
File without changes
{nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/entry_points.txt
File without changes
{nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/licenses/LICENSE
File without changes
{nuclia-4.9.3.dist-info → nuclia-4.9.4.dist-info}/top_level.txt
File without changes