nuclia 4.9.2__py3-none-any.whl → 4.9.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- nuclia/lib/kb.py +25 -6
- nuclia/lib/nua.py +141 -21
- nuclia/lib/nua_responses.py +19 -0
- nuclia/sdk/kb.py +10 -2
- nuclia/sdk/predict.py +175 -28
- nuclia/sdk/search.py +68 -11
- nuclia/tests/test_kb/test_search.py +5 -1
- nuclia/tests/test_nua/test_predict.py +93 -5
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/METADATA +4 -4
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/RECORD +14 -14
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/WHEEL +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/entry_points.txt +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/licenses/LICENSE +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/top_level.txt +0 -0
nuclia/sdk/predict.py
CHANGED
```diff
@@ -13,6 +13,8 @@ from nuclia.lib.nua_responses import (
     ConfigSchema,
     LearningConfigurationCreation,
     QueryInfo,
+    RerankModel,
+    RerankResponse,
     Sentence,
     StoredLearningConfiguration,
     SummarizedModel,
@@ -50,9 +52,19 @@ class NucliaPredict:
         nc.del_config_predict(kbid)

     @nua
-    def sentence(
+    def sentence(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Sentence:
         nc: NuaClient = kwargs["nc"]
-        return nc.sentence_predict(
+        return nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def query(
@@ -61,19 +73,25 @@ class NucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: NuaClient = kwargs["nc"]
         return nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     def generate(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -86,11 +104,19 @@ class NucliaPredict:
         else:
             body = text

-        return nc.generate(
+        return nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def generate_stream(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Iterator[GenerativeChunk]:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -103,20 +129,42 @@ class NucliaPredict:
         else:
             body = text

-        for chunk in nc.generate_stream(
+        for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    def tokens(
+    def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: NuaClient = kwargs["nc"]
-        return nc.tokens_predict(
+        return nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def summarize(
-        self,
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: NuaClient = kwargs["nc"]
-        return nc.summarize(
+        return nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     def rephrase(
@@ -133,7 +181,12 @@ class NucliaPredict:

     @nua
     def rag(
-        self,
+        self,
+        question: str,
+        context: list[str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         body = ChatModel(
@@ -143,10 +196,19 @@ class NucliaPredict:
             query_context=context,
         )

-        return nc.generate(
+        return nc.generate(
+            body,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
-    def remi(
+    def remi(
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience

@@ -160,7 +222,26 @@ class NucliaPredict:
         if request is None:
             request = RemiRequest(**kwargs)
         nc: NuaClient = kwargs["nc"]
-        return nc.remi(
+        return nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
+
+    @nua
+    def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
+        """
+        Perform a reranking of the results based on the question and context provided.
+
+        :param request: RerankModel
+        :return: RerankResponse
+        """
+        nc: NuaClient = kwargs["nc"]
+        return nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )


 class AsyncNucliaPredict:
@@ -195,14 +276,26 @@ class AsyncNucliaPredict:

     @nua
     async def sentence(
-        self,
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Sentence:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.sentence_predict(
+        return await nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -214,11 +307,19 @@ class AsyncNucliaPredict:
             )
         else:
             body = text
-        return await nc.generate(
+        return await nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def generate_stream(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> AsyncIterator[GenerativeChunk]:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -231,13 +332,27 @@ class AsyncNucliaPredict:
         else:
             body = text

-        async for chunk in nc.generate_stream(
+        async for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk

     @nua
-    async def tokens(
+    async def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.tokens_predict(
+        return await nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def query(
@@ -246,22 +361,32 @@ class AsyncNucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: AsyncNuaClient = kwargs["nc"]
         return await nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )

     @nua
     async def summarize(
-        self,
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.summarize(
+        return await nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @nua
     async def rephrase(
@@ -285,7 +410,10 @@ class AsyncNucliaPredict:

     @nua
     async def remi(
-        self,
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience
@@ -298,4 +426,23 @@ class AsyncNucliaPredict:
             request = RemiRequest(**kwargs)

         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.remi(
+        return await nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
+
+    @nua
+    async def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
+        """
+        Perform a reranking of the results based on the question and context provided.
+
+        :param request: RerankModel
+        :return: RerankResponse
+        """
+        nc: AsyncNuaClient = kwargs["nc"]
+        return await nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
```
nuclia/sdk/search.py
CHANGED
```diff
@@ -30,6 +30,7 @@ from nuclia.lib.kb import AsyncNucliaDBClient, NucliaDBClient
 from nuclia.sdk.logger import logger
 from nuclia.sdk.auth import AsyncNucliaAuth, NucliaAuth
 from nuclia.sdk.resource import RagImagesStrategiesParse, RagStrategiesParse
+from nuclia_models.common.consumption import Consumption, TokensDetail


 @dataclass
@@ -49,6 +50,7 @@ class AskAnswer:
     relations: Optional[Relations]
     predict_request: Optional[ChatModel]
     error_details: Optional[str]
+    consumption: Optional[Consumption]

     def __str__(self):
         if self.answer:
@@ -184,8 +186,9 @@ class NucliaSearch:
         filters: Optional[Union[List[str], List[Filter]]] = None,
         rag_strategies: Optional[list[RagStrategies]] = None,
         rag_images_strategies: Optional[list[RagImagesStrategies]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -217,7 +220,11 @@ class NucliaSearch:
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")

-        ask_response: SyncAskResponse = ndb.ndb.ask(
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -239,6 +246,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )

         if ask_response.prompt_context:
@@ -257,8 +265,9 @@ class NucliaSearch:
         schema: Union[str, Dict[str, Any]],
         query: Union[str, dict, AskRequest, None] = None,
         filters: Optional[Union[List[str], List[Filter]]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[AskAnswer]:
         """
         Answer a question.

@@ -272,10 +281,10 @@ class NucliaSearch:
                         schema_json = json.load(json_file_handler)
                     except Exception:
                         logger.exception("File format is not JSON")
-                        return
+                        return None
             else:
                 logger.exception("File not found")
-                return
+                return None
         else:
             schema_json = schema

@@ -303,7 +312,11 @@ class NucliaSearch:
             req.filters = filters
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_response: SyncAskResponse = ndb.ndb.ask(
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -325,6 +338,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
@@ -483,9 +497,10 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -509,7 +524,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -526,6 +545,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -548,6 +568,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
@@ -569,6 +602,7 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
     ) -> AsyncIterator[AskResponseItem]:
@@ -593,7 +627,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         async for line in ask_stream_response.aiter_lines():
             try:
                 ask_response_item = AskResponseItem.model_validate_json(line)
@@ -609,9 +647,10 @@ class AsyncNucliaSearch:
         query: Union[str, dict, AskRequest],
         schema: Dict[str, Any],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.

@@ -635,7 +674,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -652,6 +695,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -674,6 +718,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
```
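On the search side, `ask` and `ask_json` gain the same `show_consumption` flag, and `AskAnswer` gains a `consumption` field (a `Consumption` model with `normalized_tokens` and `customer_key_tokens` `TokensDetail` entries, per the import added above). A minimal sketch, assuming a default Knowledge Box is already configured:

```python
from nuclia.sdk.search import NucliaSearch

search = NucliaSearch()

# The flag is forwarded as the "X-Show-Consumption" header; when the backend
# honors it, the parsed response carries token usage in AskAnswer.consumption.
answer = search.ask(query="Who is Hedy Lamarr?", show_consumption=True)
if answer.consumption is not None:
    # TokensDetail exposes input/output/image counts, per the diff above.
    print(answer.consumption.normalized_tokens)
```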
nuclia/tests/test_kb/test_search.py
CHANGED
```diff
@@ -132,7 +132,11 @@ def test_ask_json(testing_config):
 async def test_ask_json_async(testing_config):
     search = AsyncNucliaSearch()
     results = await search.ask_json(
-        query="Who is hedy Lamarr?",
+        query="Who is hedy Lamarr?",
+        filters=["/icon/application/pdf"],
+        schema=SCHEMA,
+        show_consumption=True,
     )

     assert "TECHNOLOGY" in results.object["document_type"]
+    assert results.consumption is not None
```