nuclia 4.9.3__py3-none-any.whl → 4.9.5__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
- nuclia/lib/kb.py +37 -10
- nuclia/lib/nua.py +130 -25
- nuclia/lib/nua_responses.py +7 -0
- nuclia/sdk/kb.py +10 -2
- nuclia/sdk/predict.py +155 -32
- nuclia/sdk/search.py +68 -11
- nuclia/sdk/task.py +5 -5
- nuclia/tests/test_kb/test_search.py +5 -1
- nuclia/tests/test_nua/test_predict.py +77 -4
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/METADATA +4 -4
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/RECORD +15 -15
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/WHEEL +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/entry_points.txt +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/licenses/LICENSE +0 -0
- {nuclia-4.9.3.dist-info → nuclia-4.9.5.dist-info}/top_level.txt +0 -0
nuclia/sdk/predict.py
CHANGED
@@ -52,9 +52,19 @@ class NucliaPredict:
         nc.del_config_predict(kbid)
 
     @nua
-    def sentence(
+    def sentence(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Sentence:
         nc: NuaClient = kwargs["nc"]
-        return nc.sentence_predict(
+        return nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     def query(
@@ -63,19 +73,25 @@ class NucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: NuaClient = kwargs["nc"]
         return nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )
 
     @nua
     def generate(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -88,11 +104,19 @@ class NucliaPredict:
         else:
             body = text
 
-        return nc.generate(
+        return nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     def generate_stream(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Iterator[GenerativeChunk]:
         nc: NuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -105,20 +129,42 @@ class NucliaPredict:
         else:
             body = text
 
-        for chunk in nc.generate_stream(
+        for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk
 
     @nua
-    def tokens(
+    def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: NuaClient = kwargs["nc"]
-        return nc.tokens_predict(
+        return nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     def summarize(
-        self,
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> SummarizedModel:
         nc: NuaClient = kwargs["nc"]
-        return nc.summarize(
+        return nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     def rephrase(
@@ -135,7 +181,12 @@ class NucliaPredict:
 
     @nua
     def rag(
-        self,
+        self,
+        question: str,
+        context: list[str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: NuaClient = kwargs["nc"]
         body = ChatModel(
@@ -145,10 +196,19 @@ class NucliaPredict:
             query_context=context,
         )
 
-        return nc.generate(
+        return nc.generate(
+            body,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
-    def remi(
+    def remi(
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience
 
@@ -162,10 +222,15 @@ class NucliaPredict:
         if request is None:
             request = RemiRequest(**kwargs)
         nc: NuaClient = kwargs["nc"]
-        return nc.remi(
+        return nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
-    def rerank(
+    def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
         """
         Perform a reranking of the results based on the question and context provided.
 
@@ -173,7 +238,10 @@ class NucliaPredict:
         :return: RerankResponse
         """
         nc: NuaClient = kwargs["nc"]
-        return nc.rerank(
+        return nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
 
 class AsyncNucliaPredict:
@@ -208,14 +276,26 @@ class AsyncNucliaPredict:
 
     @nua
     async def sentence(
-        self,
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> Sentence:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.sentence_predict(
+        return await nc.sentence_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     async def generate(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> GenerativeFullResponse:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -227,11 +307,19 @@ class AsyncNucliaPredict:
             )
         else:
             body = text
-        return await nc.generate(
+        return await nc.generate(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     async def generate_stream(
-        self,
+        self,
+        text: Union[str, ChatModel],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> AsyncIterator[GenerativeChunk]:
         nc: AsyncNuaClient = kwargs["nc"]
         if isinstance(text, str):
@@ -244,13 +332,27 @@ class AsyncNucliaPredict:
         else:
             body = text
 
-        async for chunk in nc.generate_stream(
+        async for chunk in nc.generate_stream(
+            body=body,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        ):
             yield chunk
 
     @nua
-    async def tokens(
+    async def tokens(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
+    ) -> Tokens:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.tokens_predict(
+        return await nc.tokens_predict(
+            text,
+            model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     async def query(
@@ -259,22 +361,32 @@ class AsyncNucliaPredict:
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        show_consumption: bool = False,
         **kwargs,
     ) -> QueryInfo:
         nc: AsyncNuaClient = kwargs["nc"]
         return await nc.query_predict(
-            text,
+            text=text,
             semantic_model=semantic_model,
             token_model=token_model,
             generative_model=generative_model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
         )
 
     @nua
     async def summarize(
-        self,
+        self,
+        texts: dict[str, str],
+        model: Optional[str] = None,
+        show_consumption: bool = False,
+        **kwargs,
    ) -> SummarizedModel:
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.summarize(
+        return await nc.summarize(
+            documents=texts,
+            model=model,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
     async def rephrase(
@@ -298,7 +410,10 @@ class AsyncNucliaPredict:
 
     @nua
     async def remi(
-        self,
+        self,
+        request: Optional[RemiRequest] = None,
+        show_consumption: bool = False,
+        **kwargs,
     ) -> RemiResponse:
         """
         Perform a REMi evaluation over a RAG experience
@@ -311,10 +426,15 @@ class AsyncNucliaPredict:
             request = RemiRequest(**kwargs)
 
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.remi(
+        return await nc.remi(
+            request=request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
     @nua
-    async def rerank(
+    async def rerank(
+        self, request: RerankModel, show_consumption: bool = False, **kwargs
+    ) -> RerankResponse:
         """
         Perform a reranking of the results based on the question and context provided.
 
@@ -322,4 +442,7 @@ class AsyncNucliaPredict:
         :return: RerankResponse
         """
         nc: AsyncNuaClient = kwargs["nc"]
-        return await nc.rerank(
+        return await nc.rerank(
+            request,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
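
Every sync and async predict method above now takes a `show_consumption` flag, which the SDK translates into an `X-Show-Consumption: true|false` header on the NUA request. A minimal usage sketch (the text, document key, and entry point `from nuclia import sdk` are illustrative assumptions, not part of this diff; a configured NUA key is assumed):

from nuclia import sdk

predict = sdk.NucliaPredict()

# Opting in sends X-Show-Consumption: true with the request,
# asking the service to report token consumption alongside the result.
tokens = predict.tokens(
    text="Barcelona is the capital of Catalonia.",
    show_consumption=True,
)

# The flag defaults to False, so existing callers keep the old behavior.
summary = predict.summarize(
    texts={"doc-1": "Hedy Lamarr co-invented frequency-hopping spread spectrum."},
)
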
nuclia/sdk/search.py
CHANGED
@@ -30,6 +30,7 @@ from nuclia.lib.kb import AsyncNucliaDBClient, NucliaDBClient
 from nuclia.sdk.logger import logger
 from nuclia.sdk.auth import AsyncNucliaAuth, NucliaAuth
 from nuclia.sdk.resource import RagImagesStrategiesParse, RagStrategiesParse
+from nuclia_models.common.consumption import Consumption, TokensDetail
 
 
 @dataclass
@@ -49,6 +50,7 @@ class AskAnswer:
     relations: Optional[Relations]
     predict_request: Optional[ChatModel]
     error_details: Optional[str]
+    consumption: Optional[Consumption]
 
     def __str__(self):
         if self.answer:
@@ -184,8 +186,9 @@ class NucliaSearch:
         filters: Optional[Union[List[str], List[Filter]]] = None,
         rag_strategies: Optional[list[RagStrategies]] = None,
         rag_images_strategies: Optional[list[RagImagesStrategies]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.
 
@@ -217,7 +220,11 @@ class NucliaSearch:
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
 
-        ask_response: SyncAskResponse = ndb.ndb.ask(
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -239,6 +246,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )
 
         if ask_response.prompt_context:
@@ -257,8 +265,9 @@ class NucliaSearch:
         schema: Union[str, Dict[str, Any]],
         query: Union[str, dict, AskRequest, None] = None,
         filters: Optional[Union[List[str], List[Filter]]] = None,
+        show_consumption: bool = False,
         **kwargs,
-    ):
+    ) -> Optional[AskAnswer]:
         """
         Answer a question.
 
@@ -272,10 +281,10 @@ class NucliaSearch:
                         schema_json = json.load(json_file_handler)
                 except Exception:
                     logger.exception("File format is not JSON")
-                    return
+                    return None
             else:
                 logger.exception("File not found")
-                return
+                return None
         else:
             schema_json = schema
 
@@ -303,7 +312,11 @@ class NucliaSearch:
             req.filters = filters
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_response: SyncAskResponse = ndb.ndb.ask(
+        ask_response: SyncAskResponse = ndb.ndb.ask(
+            kbid=ndb.kbid,
+            content=req,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
 
         result = AskAnswer(
             answer=ask_response.answer.encode(),
@@ -325,6 +338,7 @@ class NucliaSearch:
             else None,
             relations=ask_response.relations,
             prompt_context=ask_response.prompt_context,
+            consumption=ask_response.consumption,
         )
         if ask_response.metadata is not None:
             if ask_response.metadata.timings is not None:
@@ -483,9 +497,10 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.
 
@@ -509,7 +524,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -526,6 +545,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -548,6 +568,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
@@ -569,6 +602,7 @@ class AsyncNucliaSearch:
         *,
         query: Union[str, dict, AskRequest],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
     ) -> AsyncIterator[AskResponseItem]:
@@ -593,7 +627,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         async for line in ask_stream_response.aiter_lines():
             try:
                 ask_response_item = AskResponseItem.model_validate_json(line)
@@ -609,9 +647,10 @@ class AsyncNucliaSearch:
         query: Union[str, dict, AskRequest],
         schema: Dict[str, Any],
         filters: Optional[List[str]] = None,
+        show_consumption: bool = False,
         timeout: int = 100,
         **kwargs,
-    ):
+    ) -> AskAnswer:
         """
         Answer a question.
 
@@ -635,7 +674,11 @@ class AsyncNucliaSearch:
             req = query
         else:
             raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
-        ask_stream_response = await ndb.ask(
+        ask_stream_response = await ndb.ask(
+            req,
+            timeout=timeout,
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )
         result = AskAnswer(
             answer=b"",
             learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
@@ -652,6 +695,7 @@ class AsyncNucliaSearch:
             predict_request=None,
             relations=None,
             prompt_context=None,
+            consumption=None,
         )
         async for line in ask_stream_response.aiter_lines():
             try:
@@ -674,6 +718,19 @@ class AsyncNucliaSearch:
                     result.timings = ask_response_item.timings.model_dump()
                 if ask_response_item.tokens:
                     result.tokens = ask_response_item.tokens.model_dump()
+            elif ask_response_item.type == "consumption":
+                result.consumption = Consumption(
+                    normalized_tokens=TokensDetail(
+                        input=ask_response_item.normalized_tokens.input,
+                        output=ask_response_item.normalized_tokens.output,
+                        image=ask_response_item.normalized_tokens.image,
+                    ),
+                    customer_key_tokens=TokensDetail(
+                        input=ask_response_item.customer_key_tokens.input,
+                        output=ask_response_item.customer_key_tokens.output,
+                        image=ask_response_item.customer_key_tokens.image,
+                    ),
+                )
             elif ask_response_item.type == "status":
                 result.status = ask_response_item.status
             elif ask_response_item.type == "prequeries":
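
`ask`, `ask_json`, and the streaming variants accept the same flag, and `AskAnswer` gains a `consumption` field: the sync path copies it from the response object, while the async paths fill it from the new `consumption` stream item. A sketch of reading it (the query string is illustrative, and `from nuclia import sdk` with a default knowledge box already authenticated is an assumption):

from nuclia import sdk

search = sdk.NucliaSearch()
answer = search.ask(query="Who is Hedy Lamarr?", show_consumption=True)

# consumption stays None unless the header was sent and the
# server reported usage for this request.
if answer.consumption is not None:
    print(answer.consumption.normalized_tokens)
    print(answer.consumption.customer_key_tokens)
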
nuclia/sdk/task.py
CHANGED
@@ -6,9 +6,9 @@ from nuclia.sdk.auth import NucliaAuth, AsyncNucliaAuth
 from nuclia_models.worker.tasks import (
     ApplyOptions,
     TaskStartKB,
-    TaskResponse,
     TaskList,
     TaskName,
+    TaskResponse,
     PARAMETERS_TYPING,
     PublicTaskSet,
     TASKS,
@@ -64,7 +64,7 @@ class NucliaTask:
         return TaskResponse.model_validate(response.json())
 
     @kb
-    def delete(self, *args, task_id: str, **kwargs):
+    def delete(self, *args, task_id: str, cleanup: bool = False, **kwargs):
         """
         Delete task
 
@@ -72,7 +72,7 @@ class NucliaTask:
         """
         ndb: NucliaDBClient = kwargs["ndb"]
         try:
-            _ = ndb.delete_task(task_id=task_id)
+            _ = ndb.delete_task(task_id=task_id, cleanup=cleanup)
         except InvalidPayload:
             pass
 
@@ -158,7 +158,7 @@ class AsyncNucliaTask:
         return TaskResponse.model_validate(response.json())
 
     @kb
-    async def delete(self, *args, task_id: str, **kwargs):
+    async def delete(self, *args, task_id: str, cleanup: bool = False, **kwargs):
         """
         Delete task
 
@@ -166,7 +166,7 @@ class AsyncNucliaTask:
         """
         ndb: AsyncNucliaDBClient = kwargs["ndb"]
         try:
-            _ = await ndb.delete_task(task_id=task_id)
+            _ = await ndb.delete_task(task_id=task_id, cleanup=cleanup)
         except InvalidPayload:
             pass
 
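
Both task classes forward the new `cleanup` flag to `delete_task` on the NucliaDB client. A sketch (the task id is a placeholder, and the server-side effect of `cleanup` is not shown in this diff; presumably it also removes data the task produced):

from nuclia import sdk

task = sdk.NucliaTask()

# cleanup defaults to False, so plain deletes behave as before.
task.delete(task_id="my-task-id", cleanup=True)
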
nuclia/tests/test_kb/test_search.py
CHANGED
@@ -132,7 +132,11 @@ def test_ask_json(testing_config):
 async def test_ask_json_async(testing_config):
     search = AsyncNucliaSearch()
     results = await search.ask_json(
-        query="Who is hedy Lamarr?",
+        query="Who is hedy Lamarr?",
+        filters=["/icon/application/pdf"],
+        schema=SCHEMA,
+        show_consumption=True,
     )
 
     assert "TECHNOLOGY" in results.object["document_type"]
+    assert results.consumption is not None