nuclia 4.9.2__py3-none-any.whl → 4.9.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
nuclia/sdk/predict.py CHANGED
@@ -13,6 +13,8 @@ from nuclia.lib.nua_responses import (
     ConfigSchema,
     LearningConfigurationCreation,
     QueryInfo,
+    RerankModel,
+    RerankResponse,
     Sentence,
     StoredLearningConfiguration,
     SummarizedModel,
@@ -50,9 +52,19 @@ class NucliaPredict:
         nc.del_config_predict(kbid)

     @nua
-    def sentence(self, text: str, model: Optional[str] = None, **kwargs) -> Sentence:
+    def sentence(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Sentence:
         nc: NuaClient = kwargs["nc"]
-        return nc.sentence_predict(text, model)
+        return nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def query(
@@ -61,19 +73,25 @@ class NucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: NuaClient = kwargs["nc"]
         return nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     def generate(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -86,11 +104,19 @@ class NucliaPredict:
         else:
             body = text

-        return nc.generate(body, model)
+        return nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def generate_stream(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Iterator[GenerativeChunk]:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -103,20 +129,42 @@ class NucliaPredict:
         else:
             body = text

-        for chunk in nc.generate_stream(body, model):
+        for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    def tokens(self, text: str, model: Optional[str] = None, **kwargs) -> Tokens:
+    def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: NuaClient = kwargs["nc"]
-        return nc.tokens_predict(text, model)
+        return nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def summarize(
-        self, texts: dict[str, str], model: Optional[str] = None, **kwargs
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: NuaClient = kwargs["nc"]
-        return nc.summarize(texts, model)
+        return nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def rephrase(
@@ -133,7 +181,12 @@ class NucliaPredict:

     @nua
     def rag(
-        self, question: str, context: list[str], model: Optional[str] = None, **kwargs
+        self,
+        question: str,
+        context: list[str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         body = ChatModel(
@@ -143,10 +196,19 @@ class NucliaPredict:
             query_context=context,
         )

-        return nc.generate(body, model)
+        return nc.generate(
+            body,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    def remi(self, request: Optional[RemiRequest] = None, **kwargs) -> RemiResponse:
+    def remi(
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience

@@ -160,7 +222,26 @@ class NucliaPredict:
         if request is None:
             request = RemiRequest(**kwargs)
         nc: NuaClient = kwargs["nc"]
-        return nc.remi(request)
+        return nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
+
+    @nua
+    def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
+        """
+        Perform a reranking of the results based on the question and context provided.
+
+        :param request: RerankModel
+        :return: RerankResponse
+        """
+        nc: NuaClient = kwargs["nc"]
+        return nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )


 class AsyncNucliaPredict:
@@ -195,14 +276,26 @@ class AsyncNucliaPredict:

     @nua
     async def sentence(
-        self, text: str, model: Optional[str] = None, **kwargs
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Sentence:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.sentence_predict(text, model)
+        return await nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -214,11 +307,19 @@ class AsyncNucliaPredict:
             )
         else:
             body = text
-        return await nc.generate(body, model)
+        return await nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate_stream(
-        self, text: Union[str, ChatModel], model: Optional[str] = None, **kwargs
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> AsyncIterator[GenerativeChunk]:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -231,13 +332,27 @@ class AsyncNucliaPredict:
         else:
             body = text

-        async for chunk in nc.generate_stream(body, model):
+        async for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    async def tokens(self, text: str, model: Optional[str] = None, **kwargs) -> Tokens:
+    async def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.tokens_predict(text, model)
+        return await nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def query(
@@ -246,22 +361,32 @@ class AsyncNucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: AsyncNuaClient = kwargs["nc"]
         return await nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     async def summarize(
-        self, texts: dict[str, str], model: Optional[str] = None, **kwargs
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.summarize(texts, model)
+        return await nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def rephrase(
@@ -285,7 +410,10 @@ class AsyncNucliaPredict:

     @nua
     async def remi(
-        self, request: Optional[RemiRequest] = None, **kwargs
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience
@@ -298,4 +426,23 @@ class AsyncNucliaPredict:
             request = RemiRequest(**kwargs)

         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.remi(request)
+        return await nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
+
+    @nua
+    async def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
+        """
+        Perform a reranking of the results based on the question and context provided.
+
+        :param request: RerankModel
+        :return: RerankResponse
+        """
+        nc: AsyncNuaClient = kwargs["nc"]
+        return await nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
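Taken together, the predict.py changes thread an opt-in show_consumption flag through every NUA call (sent as the X-Show-Consumption header) and add a rerank method to both the sync and async clients. A minimal usage sketch, assuming a NUA key is already configured through the SDK's auth flow; the RerankModel field values below are illustrative and should be checked against nuclia.lib.nua_responses:

from nuclia import sdk
from nuclia.lib.nua_responses import RerankModel

predict = sdk.NucliaPredict()

# Opt in to consumption reporting; the SDK sends X-Show-Consumption: true
tokens = predict.tokens("Who is Hedy Lamarr?", show_consumption=True)

# New in this release: rerank candidate passages against a question
reranked = predict.rerank(
    RerankModel(
        user_id="demo",  # hypothetical values; verify the exact RerankModel schema
        question="Who is Hedy Lamarr?",
        context={
            "1": "Hedy Lamarr was an actress and inventor.",
            "2": "Paris is the capital of France.",
        },
    ),
    show_consumption=True,
)

Note that the flag defaults to False everywhere, so existing callers keep their current behavior and pay no extra response overhead.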
nuclia/sdk/search.py CHANGED
@@ -30,6 +30,7 @@ from nuclia.lib.kb import AsyncNucliaDBClient, NucliaDBClient
 from nuclia.sdk.logger import logger
 from nuclia.sdk.auth import AsyncNucliaAuth, NucliaAuth
 from nuclia.sdk.resource import RagImagesStrategiesParse, RagStrategiesParse
+from nuclia_models.common.consumption import Consumption, TokensDetail


 @dataclass
@@ -49,6 +50,7 @@ class AskAnswer:
     relations: Optional[Relations]
     predict_request: Optional[ChatModel]
     error_details: Optional[str]
+    consumption: Optional[Consumption]

     def __str__(self):
         if self.answer:
@@ -184,8 +186,9 @@ class NucliaSearch:
         filters: Optional[Union[List[str], List[Filter]]] = None,
         rag_strategies: Optional[list[RagStrategies]] = None,
         rag_images_strategies: Optional[list[RagImagesStrategies]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -217,7 +220,11 @@ class NucliaSearch:
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")

-        ask_response: SyncAskResponse = ndb.ndb.ask(kbid=ndb.kbid, content=req)
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -239,6 +246,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )

         if ask_response.prompt_context:
@@ -257,8 +265,9 @@ class NucliaSearch:
         schema: Union[str, Dict[str, Any]],
         query: Union[str, dict, AskRequest, None] = None,
         filters: Optional[Union[List[str], List[Filter]]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[AskAnswer]:
         """
         Answer a question.

@@ -272,10 +281,10 @@ class NucliaSearch:
                         schema_json = json.load(json_file_handler)
                 except Exception:
                     logger.exception("File format is not JSON")
-                    return
+                    return None
             else:
                 logger.exception("File not found")
-                return
+                return None
         else:
             schema_json = schema

@@ -303,7 +312,11 @@ class NucliaSearch:
             req.filters = filters
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_response: SyncAskResponse = ndb.ndb.ask(kbid=ndb.kbid, content=req)
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -325,6 +338,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
@@ -483,9 +497,10 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -509,7 +524,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(req, timeout=timeout)
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -526,6 +545,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -548,6 +568,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
@@ -569,6 +602,7 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
     ) -> AsyncIterator[AskResponseItem]:
@@ -593,7 +627,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(req, timeout=timeout)
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         async for line in ask_stream_response.aiter_lines():
             try:
                 ask_response_item = AskResponseItem.model_validate_json(line)
@@ -609,9 +647,10 @@ class AsyncNucliaSearch:
         query: Union[str, dict, AskRequest],
         schema: Dict[str, Any],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -635,7 +674,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(req, timeout=timeout)
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -652,6 +695,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -674,6 +718,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
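In search.py, the same flag is forwarded to the ask endpoint and the parsed AskAnswer now exposes a consumption attribute (a Consumption model from nuclia_models). A sketch for the synchronous client, assuming a default knowledge box is already authenticated; the test hunk below exercises the async ask_json variant end to end:

from nuclia import sdk

search = sdk.NucliaSearch()
answer = search.ask(query="Who is Hedy Lamarr?", show_consumption=True)
print(answer)  # AskAnswer.__str__ prints the answer text

# consumption stays None unless show_consumption=True was passed
if answer.consumption is not None:
    print(answer.consumption.normalized_tokens)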
@@ -132,7 +132,11 @@ def test_ask_json(testing_config):
 async def test_ask_json_async(testing_config):
     search = AsyncNucliaSearch()
     results = await search.ask_json(
-        query="Who is hedy Lamarr?", filters=["/icon/application/pdf"], schema=SCHEMA
+        query="Who is hedy Lamarr?",
+        filters=["/icon/application/pdf"],
+        schema=SCHEMA,
+        show_consumption=True,
     )

     assert "TECHNOLOGY" in results.object["document_type"]
+    assert results.consumption is not None
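On the async client the server reports consumption as a dedicated item of type "consumption" in the response stream, which ask folds into AskAnswer.consumption as shown in the diff above. A sketch under the same default knowledge box assumption:

import asyncio
from nuclia import sdk

async def main():
    search = sdk.AsyncNucliaSearch()
    answer = await search.ask(query="Who is Hedy Lamarr?", show_consumption=True)
    print(answer)
    if answer.consumption is not None:
        # normalized_tokens and customer_key_tokens are TokensDetail models
        print(answer.consumption.normalized_tokens)

asyncio.run(main())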