nuclia 4.9.3__py3-none-any.whl → 4.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nuclia/lib/kb.py CHANGED
@@ -324,13 +324,18 @@ class NucliaDBClient(BaseNucliaDBClient):
         handle_http_sync_errors(response)
         return int(response.headers.get("Upload-Offset"))

-    def summarize(self, request: SummarizeRequest, timeout: int = 1000):
+    def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         handle_http_sync_errors(response)
         return response
@@ -569,12 +574,21 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return response

-    async def ask(self, request: AskRequest, timeout: int = 1000):
+    async def ask(
+        self,
+        request: AskRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.reader_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{ASK_URL}"
         req = self.reader_session.build_request(
-            "POST", url, json=request.model_dump(), timeout=timeout
+            "POST",
+            url,
+            json=request.model_dump(),
+            headers=extra_headers,
+            timeout=timeout,
         )
         response = await self.reader_session.send(req, stream=True)
         await handle_http_async_errors(response)
@@ -681,13 +695,18 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return int(response.headers.get("Upload-Offset"))

-    async def summarize(self, request: SummarizeRequest, timeout: int = 1000):
+    async def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = await self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         await handle_http_async_errors(response)
         return response
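
Net effect of the kb.py changes: summarize on the sync client and ask/summarize on the async client now take an optional extra_headers mapping that is forwarded verbatim to the HTTP call. A minimal sketch of the new call shape, assuming a configured client and the SummarizeRequest model from nucliadb_models (the URL and resource id are placeholders):

from nucliadb_models.search import SummarizeRequest
from nuclia.lib.kb import NucliaDBClient

# Hypothetical setup: in real use the SDK builds this client from your auth config.
client = NucliaDBClient(url="https://europe-1.nuclia.cloud/api/v1/kb/<kbid>")

# extra_headers is passed straight through to reader_session.post().
response = client.summarize(
    SummarizeRequest(resources=["<resource-id>"]),
    extra_headers={"X-Show-Consumption": "true"},
)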
nuclia/lib/nua.py CHANGED
@@ -18,6 +18,7 @@ from pydantic import BaseModel

 from nuclia import REGIONAL
 from nuclia.exceptions import NuaAPIException
+from nuclia_models.common.consumption import ConsumptionGenerative
 from nuclia_models.predict.generative_responses import (
     GenerativeChunk,
     GenerativeFullResponse,
@@ -28,6 +29,7 @@ from nuclia_models.predict.generative_responses import (
     StatusGenerativeResponse,
     ToolsGenerativeResponse,
 )
+from nuclia_models.common.consumption import Consumption
 from nuclia.lib.nua_responses import (
     ChatModel,
     ChatResponse,
@@ -128,9 +130,12 @@ class NuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = self.client.request(method, url, json=payload, timeout=timeout)
+        resp = self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -143,6 +148,7 @@
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> Iterator[GenerativeChunk]:
@@ -151,8 +157,9 @@
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("content-type") == "application/x-ndjson":
+            if response.headers.get("transfer-encoding") == "chunked":
                 for json_body in response.iter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -194,17 +201,31 @@
         endpoint = f"{self.url}{CONFIG}/{kbid}"
         return self._request("GET", endpoint, output=StoredLearningConfiguration)

-    def sentence_predict(self, text: str, model: Optional[str] = None) -> Sentence:
+    def sentence_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request("GET", endpoint, output=Sentence)
+        return self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )

-    def tokens_predict(self, text: str, model: Optional[str] = None) -> Tokens:
+    def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request("GET", endpoint, output=Tokens)
+        return self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )

     def query_predict(
         self,
@@ -212,6 +233,7 @@
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -220,10 +242,16 @@
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return self._request("GET", endpoint, output=QueryInfo)
+        return self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )

     def generate(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -235,6 +263,7 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -252,10 +281,19 @@
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )
         return result

     def generate_stream(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> Iterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -266,11 +304,16 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr

     def summarize(
-        self, documents: dict[str, str], model: Optional[str] = None, timeout: int = 300
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -288,6 +331,7 @@
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )

     def rephrase(
@@ -324,11 +368,13 @@
     def remi(
         self,
         request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return self._request(
             "POST",
             endpoint,
+            extra_headers=extra_headers,
             payload=request.model_dump(),
             output=RemiResponse,
         )
@@ -413,10 +459,18 @@
         activity_endpoint = f"{self.url}{STATUS_PROCESS}/{process_id}"
         return self._request("GET", activity_endpoint, ProcessRequestStatus)

-    def rerank(self, model: RerankModel) -> RerankResponse:
+    def rerank(
+        self,
+        model: RerankModel,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RerankResponse:
         endpoint = f"{self.url}{RERANK}"
         return self._request(
-            "POST", endpoint, payload=model.model_dump(), output=RerankResponse
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
         )


@@ -454,9 +508,12 @@ class AsyncNuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = await self.client.request(method, url, json=payload, timeout=timeout)
+        resp = await self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -469,6 +526,7 @@
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> AsyncIterator[GenerativeChunk]:
@@ -477,8 +535,9 @@
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("content-type") == "application/x-ndjson":
+            if response.headers.get("transfer-encoding") == "chunked":
                 async for json_body in response.aiter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -527,18 +586,30 @@
         return await self._request("GET", endpoint, output=StoredLearningConfiguration)

     async def sentence_predict(
-        self, text: str, model: Optional[str] = None
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request("GET", endpoint, output=Sentence)
+        return await self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )

-    async def tokens_predict(self, text: str, model: Optional[str] = None) -> Tokens:
+    async def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request("GET", endpoint, output=Tokens)
+        return await self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )

     async def query_predict(
         self,
@@ -546,6 +617,7 @@
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -554,7 +626,9 @@
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return await self._request("GET", endpoint, output=QueryInfo)
+        return await self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )

     @deprecated(version="2.1.0", reason="You should use generate function")
     async def generate_predict(
@@ -573,7 +647,11 @@
         )

     async def generate(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -585,6 +663,7 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -602,11 +681,20 @@
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )

         return result

     async def generate_stream(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> AsyncIterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -617,11 +705,16 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr

     async def summarize(
-        self, documents: dict[str, str], model: Optional[str] = None, timeout: int = 300
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -639,6 +732,7 @@
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )

     async def rephrase(
@@ -672,13 +766,18 @@
             output=RephraseModel,
         )

-    async def remi(self, request: RemiRequest) -> RemiResponse:
+    async def remi(
+        self,
+        request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return await self._request(
             "POST",
             endpoint,
             payload=request.model_dump(),
             output=RemiResponse,
+            extra_headers=extra_headers,
         )

     async def generate_retrieval(
@@ -802,8 +901,14 @@
             "GET", activity_endpoint, output=ProcessRequestStatus
         )

-    async def rerank(self, model: RerankModel) -> RerankResponse:
+    async def rerank(
+        self, model: RerankModel, extra_headers: Optional[dict[str, str]] = None
+    ) -> RerankResponse:
         endpoint = f"{self.url}{RERANK}"
         return await self._request(
-            "POST", endpoint, payload=model.model_dump(), output=RerankResponse
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
         )
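
Beyond the mechanical extra_headers threading, nua.py makes two behavioural changes: streaming responses are now detected via the chunked Transfer-Encoding header rather than the application/x-ndjson content type, and both generate() implementations fold a ConsumptionGenerative chunk into result.consumption. A sketch of a direct NuaClient call; the constructor arguments and ChatModel fields below are assumptions based on how the SDK uses this client, and region/account/token are placeholders:

from nuclia.lib.nua import NuaClient
from nuclia.lib.nua_responses import ChatModel, UserPrompt

# Hypothetical setup; in normal use the SDK's auth layer builds this client.
nc = NuaClient(region="europe-1", account="<account-id>", token="<nua-key>")

body = ChatModel(
    question="",
    retrieval=False,
    user_id="docs-example",
    user_prompt=UserPrompt(prompt="How much is 2 + 2?"),
)

# The backend only emits a consumption chunk when this header is sent.
result = nc.generate(body, extra_headers={"X-Show-Consumption": "true"})
print(result.answer, result.consumption)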
nuclia/lib/nua_responses.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union, cast
 import pydantic
 from pydantic import BaseModel, Field, RootModel, model_validator
 from typing_extensions import Annotated, Self
+from nuclia_models.common.consumption import Consumption


 class GenerativeOption(BaseModel):
@@ -51,6 +52,7 @@ class ConfigSchema(BaseModel):
 class Sentence(BaseModel):
     data: List[float]
     time: float
+    consumption: Optional[Consumption] = None


 class Author(str, Enum):
@@ -137,6 +139,7 @@ class Token(BaseModel):
 class Tokens(BaseModel):
     tokens: List[Token]
     time: float
+    consumption: Optional[Consumption] = None


 class SummarizeResource(BaseModel):
@@ -155,6 +158,7 @@ class SummarizedResource(BaseModel):
 class SummarizedModel(BaseModel):
     resources: Dict[str, SummarizedResource]
     summary: str = ""
+    consumption: Optional[Consumption] = None


 class RephraseModel(RootModel[str]):
@@ -535,6 +539,7 @@ class StoredLearningConfiguration(BaseModel):
 class SentenceSearch(BaseModel):
     data: List[float] = []
     time: float
+    consumption: Optional[Consumption] = None


 class Ner(BaseModel):
@@ -547,6 +552,7 @@ class Ner(BaseModel):
 class TokenSearch(BaseModel):
     tokens: List[Ner] = []
     time: float
+    consumption: Optional[Consumption] = None


 class QueryInfo(BaseModel):
@@ -569,3 +575,4 @@ class RerankResponse(BaseModel):
     context_scores: dict[str, float] = Field(
         description="Scores for each context given by the reranker"
     )
+    consumption: Optional[Consumption] = None
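
Each of these response models gains the same optional field, so the change is backward compatible: consumption defaults to None and is only populated when the backend is asked to report usage. A quick illustration (the values are made up):

from nuclia.lib.nua_responses import Sentence

# consumption is optional and defaults to None, so existing callers are unaffected.
s = Sentence(data=[0.1, 0.2], time=0.03)
assert s.consumption is None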
nuclia/sdk/kb.py CHANGED
@@ -297,9 +297,15 @@ class NucliaKB:
         )

     @kb
-    def summarize(self, *, resources: List[str], **kwargs):
+    def summarize(
+        self, *, resources: List[str], show_consumption: bool = False, **kwargs
+    ):
         ndb: NucliaDBClient = kwargs["ndb"]
-        return ndb.ndb.summarize(kbid=ndb.kbid, resources=resources)
+        return ndb.ndb.summarize(
+            kbid=ndb.kbid,
+            resources=resources,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @kb
     def notifications(self, **kwargs):
@@ -715,6 +721,7 @@ class AsyncNucliaKB:
         resources: List[str],
         generative_model: Optional[str] = None,
         summary_kind: Optional[str] = None,
+        show_consumption: bool = False,
         timeout: int = 1000,
         **kwargs,
     ) -> SummarizedModel:
@@ -725,6 +732,7 @@
                 generative_model=generative_model,
                 summary_kind=SummaryKind(summary_kind),
             ),
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
            timeout=timeout,
        )
        return SummarizedModel.model_validate(resp.json())
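
At the SDK layer the header is not passed by hand: both NucliaKB.summarize and AsyncNucliaKB.summarize expose a show_consumption flag and serialise it into the X-Show-Consumption header ("true"/"false"). A sketch, assuming authentication and a default KB have already been configured (the resource id is a placeholder):

from nuclia.sdk.kb import NucliaKB

kb = NucliaKB()
# The @kb decorator injects the NucliaDBClient; the flag becomes X-Show-Consumption: true.
summary = kb.summarize(resources=["<resource-id>"], show_consumption=True)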
nuclia/sdk/predict.py CHANGED
@@ -52,9 +52,19 @@ class NucliaPredict:
         nc.del_config_predict(kbid)

     @nua
-    def sentence(self, text: str, model: Optional[str] = None, **kwargs) -> Sentence:
+    def sentence(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Sentence:
         nc: NuaClient = kwargs["nc"]
-        return nc.sentence_predict(text, model)
+        return nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def query(
@@ -63,19 +73,25 @@
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: NuaClient = kwargs["nc"]
         return nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     def generate(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -88,11 +104,19 @@
         else:
             body = text

-        return nc.generate(body, model)
+        return nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def generate_stream(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Iterator[GenerativeChunk]:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -105,20 +129,42 @@
         else:
             body = text

-        for chunk in nc.generate_stream(body, model):
+        for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    def tokens(self, text: str, model: Optional[str] = None, **kwargs) -> Tokens:
+    def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: NuaClient = kwargs["nc"]
-        return nc.tokens_predict(text, model)
+        return nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def summarize(
-        self, texts: dict[str, str], model: Optional[str] = None, **kwargs
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: NuaClient = kwargs["nc"]
-        return nc.summarize(texts, model)
+        return nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def rephrase(
@@ -135,7 +181,12 @@

     @nua
     def rag(
-        self, question: str, context: list[str], model: Optional[str] = None, **kwargs
+        self,
+        question: str,
+        context: list[str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         body = ChatModel(
@@ -145,10 +196,19 @@
             query_context=context,
         )

-        return nc.generate(body, model)
+        return nc.generate(
+            body,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    def remi(self, request: Optional[RemiRequest] = None, **kwargs) -> RemiResponse:
+    def remi(
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience

@@ -162,10 +222,15 @@
         if request is None:
             request = RemiRequest(**kwargs)
         nc: NuaClient = kwargs["nc"]
-        return nc.remi(request)
+        return nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    def rerank(self, request: RerankModel, **kwargs) -> RerankResponse:
+    def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
         """
         Perform a reranking of the results based on the question and context provided.

@@ -173,7 +238,10 @@
         :return: RerankResponse
         """
         nc: NuaClient = kwargs["nc"]
-        return nc.rerank(request)
+        return nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )


 class AsyncNucliaPredict:
@@ -208,14 +276,26 @@

     @nua
     async def sentence(
-        self, text: str, model: Optional[str] = None, **kwargs
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Sentence:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.sentence_predict(text, model)
+        return await nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -227,11 +307,19 @@
             )
         else:
             body = text
-        return await nc.generate(body, model)
+        return await nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate_stream(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> AsyncIterator[GenerativeChunk]:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -244,13 +332,27 @@
         else:
             body = text

-        async for chunk in nc.generate_stream(body, model):
+        async for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    async def tokens(self, text: str, model: Optional[str] = None, **kwargs) -> Tokens:
+    async def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.tokens_predict(text, model)
+        return await nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def query(
@@ -259,22 +361,32 @@
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: AsyncNuaClient = kwargs["nc"]
         return await nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     async def summarize(
-        self, texts: dict[str, str], model: Optional[str] = None, **kwargs
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.summarize(texts, model)
+        return await nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def rephrase(
@@ -298,7 +410,10 @@

     @nua
     async def remi(
-        self, request: Optional[RemiRequest] = None, **kwargs
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience
@@ -311,10 +426,15 @@
             request = RemiRequest(**kwargs)

         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.remi(request)
+        return await nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    async def rerank(self, request: RerankModel, **kwargs) -> RerankResponse:
+    async def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
         """
         Perform a reranking of the results based on the question and context provided.

@@ -322,4 +442,7 @@
         :return: RerankResponse
         """
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.rerank(request)
+        return await nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
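
Every NucliaPredict and AsyncNucliaPredict entry point (sentence, tokens, query, generate, generate_stream, summarize, rag, remi, rerank) now follows the same pattern: a show_consumption flag, False by default, serialised into the X-Show-Consumption header. When it is True, the returned model carries a populated consumption attribute, as the tests below exercise. For example, assuming a NUA key is already configured:

from nuclia.sdk.predict import NucliaPredict

np = NucliaPredict()
generated = np.generate(
    text="How much is 2 + 2?",
    model="chatgpt-azure-4o-mini",
    show_consumption=True,
)
print(generated.answer)
print(generated.consumption)  # stays None when show_consumption is False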
nuclia/sdk/search.py CHANGED
@@ -30,6 +30,7 @@ from nuclia.lib.kb import AsyncNucliaDBClient, NucliaDBClient
 from nuclia.sdk.logger import logger
 from nuclia.sdk.auth import AsyncNucliaAuth, NucliaAuth
 from nuclia.sdk.resource import RagImagesStrategiesParse, RagStrategiesParse
+from nuclia_models.common.consumption import Consumption, TokensDetail


 @dataclass
@@ -49,6 +50,7 @@ class AskAnswer:
     relations: Optional[Relations]
     predict_request: Optional[ChatModel]
     error_details: Optional[str]
+    consumption: Optional[Consumption]

     def __str__(self):
         if self.answer:
@@ -184,8 +186,9 @@ class NucliaSearch:
         filters: Optional[Union[List[str], List[Filter]]] = None,
         rag_strategies: Optional[list[RagStrategies]] = None,
         rag_images_strategies: Optional[list[RagImagesStrategies]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -217,7 +220,11 @@
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")

-        ask_response: SyncAskResponse = ndb.ndb.ask(kbid=ndb.kbid, content=req)
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -239,6 +246,7 @@
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )

         if ask_response.prompt_context:
@@ -257,8 +265,9 @@
         schema: Union[str, Dict[str, Any]],
         query: Union[str, dict, AskRequest, None] = None,
         filters: Optional[Union[List[str], List[Filter]]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[AskAnswer]:
         """
         Answer a question.

@@ -272,10 +281,10 @@
                     schema_json = json.load(json_file_handler)
             except Exception:
                 logger.exception("File format is not JSON")
-                return
+                return None
         else:
             logger.exception("File not found")
-            return
+            return None
     else:
         schema_json = schema

@@ -303,7 +312,11 @@
             req.filters = filters
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_response: SyncAskResponse = ndb.ndb.ask(kbid=ndb.kbid, content=req)
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -325,6 +338,7 @@
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
@@ -483,9 +497,10 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -509,7 +524,11 @@
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(req, timeout=timeout)
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -526,6 +545,7 @@
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -548,6 +568,19 @@
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
@@ -569,6 +602,7 @@
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
     ) -> AsyncIterator[AskResponseItem]:
@@ -593,7 +627,11 @@
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(req, timeout=timeout)
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         async for line in ask_stream_response.aiter_lines():
             try:
                 ask_response_item = AskResponseItem.model_validate_json(line)
@@ -609,9 +647,10 @@
         query: Union[str, dict, AskRequest],
         schema: Dict[str, Any],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -635,7 +674,11 @@
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(req, timeout=timeout)
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -652,6 +695,7 @@
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -674,6 +718,19 @@
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
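
On the search side, ask, ask_json and ask_stream accept the same show_consumption flag, the return types are now declared explicitly (AskAnswer, or Optional[AskAnswer] for the sync ask_json, which returns None on a bad schema file), and a "consumption" item in the /ask stream is folded into AskAnswer.consumption. A sketch with the async client, assuming a default KB is configured (the query text is illustrative):

import asyncio
from nuclia.sdk.search import AsyncNucliaSearch

async def main():
    search = AsyncNucliaSearch()
    answer = await search.ask(query="Who is Hedy Lamarr?", show_consumption=True)
    # consumption is only set if the stream emitted a "consumption" item.
    print(answer.answer.decode(), answer.consumption)

asyncio.run(main())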
nuclia/tests/test_kb/test_search.py CHANGED
@@ -132,7 +132,11 @@ def test_ask_json(testing_config):
 async def test_ask_json_async(testing_config):
     search = AsyncNucliaSearch()
     results = await search.ask_json(
-        query="Who is hedy Lamarr?", filters=["/icon/application/pdf"], schema=SCHEMA
+        query="Who is hedy Lamarr?",
+        filters=["/icon/application/pdf"],
+        schema=SCHEMA,
+        show_consumption=True,
     )

     assert "TECHNOLOGY" in results.object["document_type"]
+    assert results.consumption is not None
nuclia/tests/test_nua/test_predict.py CHANGED
@@ -1,4 +1,7 @@
-from nuclia_models.predict.generative_responses import TextGenerativeResponse
+from nuclia_models.predict.generative_responses import (
+    TextGenerativeResponse,
+    ConsumptionGenerative,
+)

 from nuclia.lib.nua_responses import ChatModel, RerankModel, UserPrompt
 from nuclia.sdk.predict import AsyncNucliaPredict, NucliaPredict
@@ -8,9 +11,12 @@ from nuclia_models.predict.remi import RemiRequest

 def test_predict(testing_config):
     np = NucliaPredict()
-    embed = np.sentence(text="This is my text", model="multilingual-2024-05-06")
+    embed = np.sentence(
+        text="This is my text", model="multilingual-2024-05-06", show_consumption=True
+    )
     assert embed.time > 0
     assert len(embed.data) == 1024
+    assert embed.consumption is not None


 def test_predict_query(testing_config):
@@ -20,11 +26,14 @@ def test_predict_query(testing_config):
         semantic_model="multilingual-2024-05-06",
         token_model="multilingual",
         generative_model="chatgpt-azure-4o-mini",
+        show_consumption=True,
     )
     assert query.language == "en"
     assert query.visual_llm is True
     assert query.entities and query.entities.tokens[0].text == "Ramon"
     assert query.sentence and len(query.sentence.data) == 1024
+    assert query.entities.consumption is not None
+    assert query.sentence.consumption is not None


 def test_rag(testing_config):
@@ -36,14 +45,34 @@
             "Eudald Camprubí is CEO at the same company as Ramon Navarro",
         ],
         model="chatgpt-azure-4o-mini",
+        show_consumption=True,
     )
     assert "Eudald" in generated.answer
+    assert generated.consumption is not None


 def test_generative(testing_config):
     np = NucliaPredict()
     generated = np.generate(text="How much is 2 + 2?", model="chatgpt-azure-4o-mini")
     assert "4" in generated.answer
+    assert generated.consumption is None
+
+
+@pytest.mark.asyncio
+async def test_generative_with_consumption(testing_config):
+    np = NucliaPredict()
+    generated = np.generate(
+        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini", show_consumption=True
+    )
+    assert "4" in generated.answer
+    assert generated.consumption is not None
+
+    anp = AsyncNucliaPredict()
+    async_generated = await anp.generate(
+        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini", show_consumption=True
+    )
+    assert "4" in async_generated.answer
+    assert async_generated.consumption is not None


 @pytest.mark.asyncio
@@ -70,13 +99,18 @@ def test_stream_generative(testing_config):
 @pytest.mark.asyncio
 async def test_async_stream_generative(testing_config):
     np = AsyncNucliaPredict()
+    consumption_found = False
+    found = False
     async for stream in np.generate_stream(
-        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini"
+        text="How much is 2 + 2?", model="chatgpt-azure-4o-mini", show_consumption=True
    ):
        if isinstance(stream.chunk, TextGenerativeResponse) and stream.chunk.text:
            if "4" in stream.chunk.text:
                found = True
+        elif isinstance(stream.chunk, ConsumptionGenerative):
+            consumption_found = True
     assert found
+    assert consumption_found


 SCHEMA = {
@@ -148,6 +182,8 @@ def test_nua_remi(testing_config):
     assert results.context_relevance[1] < 2
     assert results.groundedness[1] < 2

+    assert results.consumption is None
+

 @pytest.mark.asyncio
 async def test_nua_async_remi(testing_config):
@@ -161,7 +197,8 @@ async def test_nua_async_remi(testing_config):
                 "Paris is the capital of France.",
                 "Berlin is the capital of Germany.",
             ],
-        )
+        ),
+        show_consumption=True,
     )
     assert results.answer_relevance.score >= 4
@@ -171,6 +208,8 @@
     assert results.context_relevance[1] < 2
     assert results.groundedness[1] < 2

+    assert results.consumption is not None
+

 def test_nua_rerank(testing_config):
     np = NucliaPredict()
@@ -185,3 +224,37 @@
         )
     )
     assert results.context_scores["1"] > results.context_scores["2"]
+    assert results.consumption is None
+
+
+@pytest.mark.asyncio
+async def test_nua_rerank_with_consumption(testing_config):
+    np = NucliaPredict()
+    results = np.rerank(
+        RerankModel(
+            user_id="Nuclia PY CLI",
+            question="What is the capital of France?",
+            context={
+                "1": "Paris is the capital of France.",
+                "2": "Berlin is the capital of Germany.",
+            },
+        ),
+        show_consumption=True,
+    )
+    assert results.context_scores["1"] > results.context_scores["2"]
+    assert results.consumption is not None
+
+    anp = AsyncNucliaPredict()
+    async_results = await anp.rerank(
+        RerankModel(
+            user_id="Nuclia PY CLI",
+            question="What is the capital of France?",
+            context={
+                "1": "Paris is the capital of France.",
+                "2": "Berlin is the capital of Germany.",
+            },
+        ),
+        show_consumption=True,
+    )
+    assert async_results.context_scores["1"] > async_results.context_scores["2"]
+    assert async_results.consumption is not None
nuclia-4.9.3.dist-info/METADATA → nuclia-4.9.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nuclia
-Version: 4.9.3
+Version: 4.9.4
 Summary: Nuclia Python SDK
 Author-email: Nuclia <info@nuclia.com>
 License-Expression: MIT
@@ -24,9 +24,9 @@ Requires-Dist: requests
 Requires-Dist: httpx
 Requires-Dist: httpcore>=1.0.0
 Requires-Dist: prompt_toolkit
-Requires-Dist: nucliadb_sdk<7,>=6.5
-Requires-Dist: nucliadb_models<7,>=6.5
-Requires-Dist: nuclia-models>=0.41.1
+Requires-Dist: nucliadb_sdk<7,>=6.6.1
+Requires-Dist: nucliadb_models<7,>=6.6.1
+Requires-Dist: nuclia-models>=0.45.0
 Requires-Dist: tqdm
 Requires-Dist: aiofiles
 Requires-Dist: backoff
nuclia-4.9.3.dist-info/RECORD → nuclia-4.9.4.dist-info/RECORD RENAMED
@@ -9,11 +9,11 @@ nuclia/cli/run.py,sha256=B1hP0upSbSCqqT89WAwsd93ZxkAoF6ajVyLOdYmo8fU,1560
 nuclia/cli/utils.py,sha256=iZ3P8juBdAGvaRUd2BGz7bpUXNDHdPrC5p876yyZ2Cs,1223
 nuclia/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/lib/conversations.py,sha256=M6qhL9NPEKroYF767S-Q2XWokRrjX02kpYTzRvZKwUE,149
-nuclia/lib/kb.py,sha256=vSLfmV6HqPvWJwPVw4iIzkDnm0M96_ImTH3QMZdkd_I,30985
+nuclia/lib/kb.py,sha256=hbWg-0pX1OCt9ezUXYcOfyWHahgmZLUQe2nlSycHetU,31364
 nuclia/lib/models.py,sha256=ekEQrVIFU3aFvt60yQh-zpWkGNORBMSc7c5Hd_VzPzI,1564
-nuclia/lib/nua.py,sha256=jVUV8I50iJpUQD6yn4V7wNDDSDP8oEWrRxnmrs33hnI,27591
+nuclia/lib/nua.py,sha256=qP6Gv9ck117kdqNuKVl_qHodrDE8CS8xdh_bdzeJMiQ,30631
 nuclia/lib/nua_chat.py,sha256=ApL1Y1FWvAVUt-Y9a_8TUSJIhg8-UmBSy8TlDPn6tD8,3874
-nuclia/lib/nua_responses.py,sha256=FyCedSKHVGn4w380obV9B3D1JaqVyNTqda4-OCs4A9A,13637
+nuclia/lib/nua_responses.py,sha256=KHAjoW3fu2QT7u4F4WfL4DQIM8lCE5W_14vuRtCmtZg,13970
 nuclia/lib/utils.py,sha256=9l6DxBk-11WqUhXRn99cqeuVTUOJXj-1S6zckal7wOk,6312
 nuclia/sdk/__init__.py,sha256=-nAw8i53XBdmbfTa1FJZ0FNRMNakimDVpD6W4OdES-c,1374
 nuclia/sdk/accounts.py,sha256=7XQ3K9_jlSuk2Cez868FtazZ05xSGab6h3Mt1qMMwIE,647
@@ -22,17 +22,17 @@ nuclia/sdk/auth.py,sha256=o9CiWt3P_UHQW_L5Jq_9IOB-KpSCXFsVTouQcZp8BJM,25072
 nuclia/sdk/backup.py,sha256=adbPcNEbHGZW698o028toXKfDkDrmk5QRIDSiN6SPys,6529
 nuclia/sdk/export_import.py,sha256=y5cTOxhILwRPIvR2Ya12bk-ReGbeDzA3C9TPxgnOHD4,9756
 nuclia/sdk/extract_strategy.py,sha256=NZBLLThdLyQYw8z1mT9iRhFjkE5sQP86-3QhsTiyV9o,2540
-nuclia/sdk/kb.py,sha256=2-H9FOvPgsG-ZYNUA8D4FYAJhX3K8m2VwbwSQy1JV7c,27044
+nuclia/sdk/kb.py,sha256=rY1yVx_9S01twMXgrkpKSNpPChTupwYEw7EAIOwzLYA,27321
 nuclia/sdk/kbs.py,sha256=nXEvg5ddZYdDS8Kie7TrN-s1meU9ecYLf9FlT5xr-ro,9131
 nuclia/sdk/logger.py,sha256=UHB81eS6IGmLrsofKxLh8cmF2AsaTj_HXP0tGqMr_HM,57
 nuclia/sdk/logs.py,sha256=3jfORpo8fzZiXFFSbGY0o3Bre1ZgJaKQCXgxP1keNHw,9614
 nuclia/sdk/nua.py,sha256=6t0m0Sx-UhqNU2Hx9v6vTwy0m3a30K4T0KmP9G43MzY,293
 nuclia/sdk/nucliadb.py,sha256=bOESIppPgY7IrNqrYY7T3ESoxwttbOSTm5zj1xUS1jI,1288
-nuclia/sdk/predict.py,sha256=Gom-BlISXawThNq-f7fb1te5tY4catlIdmO-pMJyqkk,9815
+nuclia/sdk/predict.py,sha256=ZQMtoeQBXaPuaJRxOmO23nnKkVKRpNCMi0JgXWSJ9lM,12836
 nuclia/sdk/process.py,sha256=WuNnqaWprp-EABWDC_z7O2woesGIlYWnDUKozh7Ibr4,2241
 nuclia/sdk/remi.py,sha256=BEb3O9R2jOFlOda4vjFucKKGO1c2eTkqYZdFlIy3Zmo,4357
 nuclia/sdk/resource.py,sha256=0lSvD4e1FpN5iM9W295dOKLJ8hsXfIe8HKdo0HsVg20,13976
-nuclia/sdk/search.py,sha256=1mLJzDO-W3ObReDL1xK8zgjIxkdpuIhyf87-L4bISKQ,25180
+nuclia/sdk/search.py,sha256=tHpx0gwVwSepa3dn7AX2iWs7JO0pQkjJk7v8TEoL7Gg,27742
 nuclia/sdk/task.py,sha256=UawH-7IneRIGVOiLdDQ2vDBwe5eI51UXRbLRRVeu3C4,6095
 nuclia/sdk/upload.py,sha256=ZBzYROF3yP-77HcaR06OBsFjJAbTOCvF-nlxaqQZsT4,22720
 nuclia/sdk/zones.py,sha256=1ARWrTsTuzj8zguanpX3OaIw-3Qq_ULS_g4GG2mHxOA,342
@@ -49,7 +49,7 @@ nuclia/tests/test_kb/test_labels.py,sha256=IUdTq4mzv0OrOkwBWWy4UwKGKyJybtoHrgvXr
 nuclia/tests/test_kb/test_logs.py,sha256=Z9ELtiiU9NniITJzeWt92GCcERKYy9Nwc_fUVPboRU0,3121
 nuclia/tests/test_kb/test_remi.py,sha256=OX5N-MHbgcwpLg6fBjrAK_KhqkMspJo_VKQHCBCayZ8,2080
 nuclia/tests/test_kb/test_resource.py,sha256=05Xgmg5fwcPW2PZKnUSSjr6MPXp5w8XDgx8plfNcR68,1102
-nuclia/tests/test_kb/test_search.py,sha256=bsQhfB6-NYFwY3gqkrVJJnru153UgZqEhV22ho4VeWM,4660
+nuclia/tests/test_kb/test_search.py,sha256=8v3u-VcczBCTc-rHlWVM5mUtLJKNBZSM0V4v2I1uE3E,4751
 nuclia/tests/test_kb/test_tasks.py,sha256=tfJl1js2o0_dKyXdOxiwdj929kbzMkfl5XjO8HVMInQ,3456
 nuclia/tests/test_manage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/test_manage/test_account.py,sha256=u9NhRK8gJLS7BEY618aGoYoV2rgDLHZUeSsWWYkDhF8,289
@@ -57,15 +57,15 @@ nuclia/tests/test_manage/test_auth.py,sha256=I5ho9rKhrzCiP57cDVcs2zrXyy1uWlKxzfE
 nuclia/tests/test_manage/test_kb.py,sha256=gqBMxmIuPUKC6kP2V_IapS9u72QR99V8CDQlJOa8sHU,1959
 nuclia/tests/test_nua/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/test_nua/test_agent.py,sha256=iPA8lVw7CyONS-0fVG7rDzv8T-LnUlNPMNynTnP-2ic,334
-nuclia/tests/test_nua/test_predict.py,sha256=8by69GgXuOZKAgksjSjgmbNr98jCXanF1HOo29oNuMg,5798
+nuclia/tests/test_nua/test_predict.py,sha256=EKgpvHZHA4pf3ENE17j0B7CKewqHeJryR7D8JSTGJMo,8154
 nuclia/tests/test_nucliadb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/test_nucliadb/test_crud.py,sha256=GuY76HRvt2DFaNgioKm5n0Aco1HnG7zzV_zKom5N8xc,1173
 nuclia/tests/unit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nuclia/tests/unit/test_export_import.py,sha256=xo_wVbjUnNlVV65ZGH7LtZ38qy39EkJp2hjOuTHC1nU,980
 nuclia/tests/unit/test_nua_responses.py,sha256=t_hIdVztTi27RWvpfTJUYcCL0lpKdZFegZIwLdaPNh8,319
-nuclia-4.9.3.dist-info/licenses/LICENSE,sha256=Ops2LTti_HJtpmWcanuUTdTY3vKDR1myJ0gmGBKC0FA,1063
-nuclia-4.9.3.dist-info/METADATA,sha256=rxQGLEDjPHayY_NPuXkmXpfsp4_BfoBC2ZjLlOjhyoM,2337
-nuclia-4.9.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-nuclia-4.9.3.dist-info/entry_points.txt,sha256=iZHOyXPNS54r3eQmdi5So20xO1gudI9K2oP4sQsCJRw,46
-nuclia-4.9.3.dist-info/top_level.txt,sha256=cqn_EitXOoXOSUvZnd4q6QGrhm04pg8tLAZtem-Zfdo,7
-nuclia-4.9.3.dist-info/RECORD,,
+nuclia-4.9.4.dist-info/licenses/LICENSE,sha256=Ops2LTti_HJtpmWcanuUTdTY3vKDR1myJ0gmGBKC0FA,1063
+nuclia-4.9.4.dist-info/METADATA,sha256=2fk95GmG2Qa-GGAOPxWXWgA-gAGv-vEIHMRftKOsgA0,2341
+nuclia-4.9.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+nuclia-4.9.4.dist-info/entry_points.txt,sha256=iZHOyXPNS54r3eQmdi5So20xO1gudI9K2oP4sQsCJRw,46
+nuclia-4.9.4.dist-info/top_level.txt,sha256=cqn_EitXOoXOSUvZnd4q6QGrhm04pg8tLAZtem-Zfdo,7
+nuclia-4.9.4.dist-info/RECORD,,