nuclia 4.9.2__py3-none-any.whl → 4.9.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.
- nuclia/lib/kb.py +25 -6
- nuclia/lib/nua.py +141 -21
- nuclia/lib/nua_responses.py +19 -0
- nuclia/sdk/kb.py +10 -2
- nuclia/sdk/predict.py +175 -28
- nuclia/sdk/search.py +68 -11
- nuclia/tests/test_kb/test_search.py +5 -1
- nuclia/tests/test_nua/test_predict.py +93 -5
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/METADATA +4 -4
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/RECORD +14 -14
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/WHEEL +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/entry_points.txt +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/licenses/LICENSE +0 -0
- {nuclia-4.9.2.dist-info → nuclia-4.9.4.dist-info}/top_level.txt +0 -0
nuclia/lib/kb.py
CHANGED
@@ -324,13 +324,18 @@ class NucliaDBClient(BaseNucliaDBClient):
         handle_http_sync_errors(response)
         return int(response.headers.get("Upload-Offset"))

-    def summarize(self, request: SummarizeRequest, timeout: int = 1000):
+    def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         handle_http_sync_errors(response)
         return response
@@ -569,12 +574,21 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return response

-    async def ask(self, request: AskRequest, timeout: int = 1000):
+    async def ask(
+        self,
+        request: AskRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.reader_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{ASK_URL}"
         req = self.reader_session.build_request(
-            "POST", url, json=request.model_dump(), timeout=timeout
+            "POST",
+            url,
+            json=request.model_dump(),
+            headers=extra_headers,
+            timeout=timeout,
         )
         response = await self.reader_session.send(req, stream=True)
         await handle_http_async_errors(response)
@@ -681,13 +695,18 @@ class AsyncNucliaDBClient(BaseNucliaDBClient):
         await handle_http_async_errors(response)
         return int(response.headers.get("Upload-Offset"))

-    async def summarize(self, request: SummarizeRequest, timeout: int = 1000):
+    async def summarize(
+        self,
+        request: SummarizeRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 1000,
+    ):
         if self.url is None or self.writer_session is None:
             raise Exception("KB not configured")
         url = f"{self.url}{SUMMARIZE_URL}"
         assert self.reader_session
         response = await self.reader_session.post(
-            url, json=request.model_dump(), timeout=timeout
+            url, json=request.model_dump(), headers=extra_headers, timeout=timeout
         )
         await handle_http_async_errors(response)
         return response
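In practice this means callers of the NucliaDB client can now forward arbitrary HTTP headers on summarize and ask calls. A minimal usage sketch, assuming an already configured `NucliaDBClient` instance named `client` (the `SummarizeRequest` import path follows the `nucliadb_models` package and is an assumption, not shown in this diff):

from nucliadb_models.search import SummarizeRequest

# Sketch only: `client` is assumed to be a configured NucliaDBClient
# with valid reader/writer sessions.
request = SummarizeRequest(resources=["my-resource-id"])
response = client.summarize(
    request,
    # extra_headers is forwarded verbatim to the HTTP request; here it
    # opts in to the consumption reporting introduced in this release.
    extra_headers={"X-Show-Consumption": "true"},
)
print(response.json())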
nuclia/lib/nua.py
CHANGED
@@ -18,6 +18,7 @@ from pydantic import BaseModel

 from nuclia import REGIONAL
 from nuclia.exceptions import NuaAPIException
+from nuclia_models.common.consumption import ConsumptionGenerative
 from nuclia_models.predict.generative_responses import (
     GenerativeChunk,
     GenerativeFullResponse,
@@ -28,6 +29,7 @@ from nuclia_models.predict.generative_responses import (
     StatusGenerativeResponse,
     ToolsGenerativeResponse,
 )
+from nuclia_models.common.consumption import Consumption
 from nuclia.lib.nua_responses import (
     ChatModel,
     ChatResponse,
@@ -42,6 +44,8 @@ from nuclia.lib.nua_responses import (
     PushResponseV2,
     QueryInfo,
     RephraseModel,
+    RerankModel,
+    RerankResponse,
     RestrictedIDString,
     Sentence,
     Source,
@@ -77,6 +81,7 @@ PUSH_PROCESS = "/api/v2/processing/push"
 SCHEMA = "/api/v1/learning/configuration/schema"
 SCHEMA_KBID = "/api/v1/schema"
 CONFIG = "/api/v1/config"
+RERANK = "/api/v1/predict/rerank"

 ConvertType = TypeVar("ConvertType", bound=BaseModel)

@@ -125,9 +130,12 @@ class NuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = self.client.request(method, url, json=payload, timeout=timeout)
+        resp = self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -140,6 +148,7 @@
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> Iterator[GenerativeChunk]:
@@ -148,8 +157,9 @@
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("
+            if response.headers.get("transfer-encoding") == "chunked":
                 for json_body in response.iter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -191,17 +201,31 @@
         endpoint = f"{self.url}{CONFIG}/{kbid}"
         return self._request("GET", endpoint, output=StoredLearningConfiguration)

-    def sentence_predict(self, text: str, model: Optional[str] = None) -> Sentence:
+    def sentence_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request("GET", endpoint, output=Sentence)
+        return self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )

-    def tokens_predict(self, text: str, model: Optional[str] = None) -> Tokens:
+    def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return self._request("GET", endpoint, output=Tokens)
+        return self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )

     def query_predict(
         self,
@@ -209,6 +233,7 @@
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -217,10 +242,16 @@
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return self._request("GET", endpoint, output=QueryInfo)
+        return self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )

     def generate(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -232,6 +263,7 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -249,10 +281,19 @@
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )
         return result

     def generate_stream(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> Iterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -263,11 +304,16 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr

     def summarize(
-        self, documents: dict[str, str], model: Optional[str] = None, timeout: int = 300
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -285,6 +331,7 @@
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )

     def rephrase(
@@ -321,11 +368,13 @@
     def remi(
         self,
         request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return self._request(
             "POST",
             endpoint,
+            extra_headers=extra_headers,
             payload=request.model_dump(),
             output=RemiResponse,
         )
@@ -410,6 +459,20 @@
         activity_endpoint = f"{self.url}{STATUS_PROCESS}/{process_id}"
         return self._request("GET", activity_endpoint, ProcessRequestStatus)

+    def rerank(
+        self,
+        model: RerankModel,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RerankResponse:
+        endpoint = f"{self.url}{RERANK}"
+        return self._request(
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
+        )
+

 class AsyncNuaClient:
     def __init__(
@@ -445,9 +508,12 @@ class AsyncNuaClient:
         url: str,
         output: Type[ConvertType],
         payload: Optional[dict[Any, Any]] = None,
+        extra_headers: Optional[dict[str, str]] = None,
         timeout: int = 60,
     ) -> ConvertType:
-        resp = await self.client.request(method, url, json=payload, timeout=timeout)
+        resp = await self.client.request(
+            method, url, json=payload, timeout=timeout, headers=extra_headers
+        )
         if resp.status_code != 200:
             raise NuaAPIException(code=resp.status_code, detail=resp.content.decode())
         try:
@@ -460,6 +526,7 @@
         self,
         method: str,
         url: str,
+        extra_headers: Optional[dict[str, str]] = None,
         payload: Optional[dict[Any, Any]] = None,
         timeout: int = 60,
     ) -> AsyncIterator[GenerativeChunk]:
@@ -468,8 +535,9 @@
             url,
             json=payload,
             timeout=timeout,
+            headers=extra_headers,
         ) as response:
-            if response.headers.get("
+            if response.headers.get("transfer-encoding") == "chunked":
                 async for json_body in response.aiter_lines():
                     try:
                         yield GenerativeChunk.model_validate_json(json_body)  # type: ignore
@@ -518,18 +586,30 @@
         return await self._request("GET", endpoint, output=StoredLearningConfiguration)

     async def sentence_predict(
-        self, text: str, model: Optional[str] = None
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> Sentence:
         endpoint = f"{self.url}{SENTENCE_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request("GET", endpoint, output=Sentence)
+        return await self._request(
+            "GET", endpoint, output=Sentence, extra_headers=extra_headers
+        )

-    async def tokens_predict(self, text: str, model: Optional[str] = None) -> Tokens:
+    async def tokens_predict(
+        self,
+        text: str,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> Tokens:
         endpoint = f"{self.url}{TOKENS_PREDICT}?text={text}"
         if model:
             endpoint += f"&model={model}"
-        return await self._request("GET", endpoint, output=Tokens)
+        return await self._request(
+            "GET", endpoint, output=Tokens, extra_headers=extra_headers
+        )

     async def query_predict(
         self,
@@ -537,6 +617,7 @@
         semantic_model: Optional[str] = None,
         token_model: Optional[str] = None,
         generative_model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
     ) -> QueryInfo:
         endpoint = f"{self.url}{QUERY_PREDICT}?text={text}"
         if semantic_model:
@@ -545,7 +626,9 @@
             endpoint += f"&token_model={token_model}"
         if generative_model:
             endpoint += f"&generative_model={generative_model}"
-        return await self._request("GET", endpoint, output=QueryInfo)
+        return await self._request(
+            "GET", endpoint, output=QueryInfo, extra_headers=extra_headers
+        )

     @deprecated(version="2.1.0", reason="You should use generate function")
     async def generate_predict(
@@ -564,7 +647,11 @@
         )

     async def generate(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> GenerativeFullResponse:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -576,6 +663,7 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             if isinstance(chunk.chunk, TextGenerativeResponse):
                 result.answer += chunk.chunk.text
@@ -593,11 +681,20 @@
                 result.code = chunk.chunk.code
             elif isinstance(chunk.chunk, ToolsGenerativeResponse):
                 result.tools = chunk.chunk.tools
+            elif isinstance(chunk.chunk, ConsumptionGenerative):
+                result.consumption = Consumption(
+                    normalized_tokens=chunk.chunk.normalized_tokens,
+                    customer_key_tokens=chunk.chunk.customer_key_tokens,
+                )

         return result

     async def generate_stream(
-        self, body: ChatModel, model: Optional[str] = None, timeout: int = 300
+        self,
+        body: ChatModel,
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> AsyncIterator[GenerativeChunk]:
         endpoint = f"{self.url}{CHAT_PREDICT}"
         if model:
@@ -608,11 +705,16 @@
             endpoint,
             payload=body.model_dump(),
             timeout=timeout,
+            extra_headers=extra_headers,
         ):
             yield gr

     async def summarize(
-        self, documents: dict[str, str], model: Optional[str] = None, timeout: int = 300
+        self,
+        documents: dict[str, str],
+        model: Optional[str] = None,
+        extra_headers: Optional[dict[str, str]] = None,
+        timeout: int = 300,
     ) -> SummarizedModel:
         endpoint = f"{self.url}{SUMMARIZE_PREDICT}"
         if model:
@@ -630,6 +732,7 @@
             payload=body.model_dump(),
             output=SummarizedModel,
             timeout=timeout,
+            extra_headers=extra_headers,
         )

     async def rephrase(
@@ -663,13 +766,18 @@
             output=RephraseModel,
         )

-    async def remi(self, request: RemiRequest) -> RemiResponse:
+    async def remi(
+        self,
+        request: RemiRequest,
+        extra_headers: Optional[dict[str, str]] = None,
+    ) -> RemiResponse:
         endpoint = f"{self.url}{REMI_PREDICT}"
         return await self._request(
             "POST",
             endpoint,
             payload=request.model_dump(),
             output=RemiResponse,
+            extra_headers=extra_headers,
         )

     async def generate_retrieval(
@@ -792,3 +900,15 @@
         return await self._request(
             "GET", activity_endpoint, output=ProcessRequestStatus
         )
+
+    async def rerank(
+        self, model: RerankModel, extra_headers: Optional[dict[str, str]] = None
+    ) -> RerankResponse:
+        endpoint = f"{self.url}{RERANK}"
+        return await self._request(
+            "POST",
+            endpoint,
+            payload=model.model_dump(),
+            output=RerankResponse,
+            extra_headers=extra_headers,
+        )
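Taken together, these changes thread an optional `extra_headers` argument through every request path in `NuaClient` and `AsyncNuaClient`, surface token consumption on generative responses, and add a `rerank` call backed by the new `/api/v1/predict/rerank` endpoint. A minimal usage sketch, assuming `nc` is an already authenticated `NuaClient` (the question and context values are illustrative only):

from nuclia.lib.nua_responses import RerankModel

# Sketch only: `nc` is assumed to be an authenticated NuaClient.
reranked = nc.rerank(
    RerankModel(
        question="What is Nuclia?",
        user_id="user-1",
        context={
            "a": "Nuclia is a RAG-as-a-service platform.",
            "b": "The Eiffel Tower is in Paris.",
        },
    ),
    # Opt in to consumption reporting for this single call.
    extra_headers={"X-Show-Consumption": "true"},
)
print(reranked.context_scores)  # one score per context key
if reranked.consumption is not None:
    print(reranked.consumption)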
nuclia/lib/nua_responses.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union, cast
 import pydantic
 from pydantic import BaseModel, Field, RootModel, model_validator
 from typing_extensions import Annotated, Self
+from nuclia_models.common.consumption import Consumption


 class GenerativeOption(BaseModel):
@@ -51,6 +52,7 @@ class ConfigSchema(BaseModel):
 class Sentence(BaseModel):
     data: List[float]
     time: float
+    consumption: Optional[Consumption] = None


 class Author(str, Enum):
@@ -137,6 +139,7 @@ class Token(BaseModel):
 class Tokens(BaseModel):
     tokens: List[Token]
     time: float
+    consumption: Optional[Consumption] = None


 class SummarizeResource(BaseModel):
@@ -155,6 +158,7 @@ class SummarizedResource(BaseModel):
 class SummarizedModel(BaseModel):
     resources: Dict[str, SummarizedResource]
     summary: str = ""
+    consumption: Optional[Consumption] = None


 class RephraseModel(RootModel[str]):
@@ -535,6 +539,7 @@ class StoredLearningConfiguration(BaseModel):
 class SentenceSearch(BaseModel):
     data: List[float] = []
     time: float
+    consumption: Optional[Consumption] = None


 class Ner(BaseModel):
@@ -547,6 +552,7 @@ class Ner(BaseModel):
 class TokenSearch(BaseModel):
     tokens: List[Ner] = []
     time: float
+    consumption: Optional[Consumption] = None


 class QueryInfo(BaseModel):
@@ -557,3 +563,16 @@ class QueryInfo(BaseModel):
     max_context: int
     entities: Optional[TokenSearch]
     sentence: Optional[SentenceSearch]
+
+
+class RerankModel(BaseModel):
+    question: str
+    user_id: str
+    context: dict[str, str] = {}
+
+
+class RerankResponse(BaseModel):
+    context_scores: dict[str, float] = Field(
+        description="Scores for each context given by the reranker"
+    )
+    consumption: Optional[Consumption] = None
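Because every new `consumption` field defaults to `None`, payloads from servers that do not report consumption (or calls made without the `X-Show-Consumption` header) still validate unchanged. A quick sketch against the `Sentence` model defined above:

from nuclia.lib.nua_responses import Sentence

# A response without consumption data still parses; the field stays None.
s = Sentence.model_validate({"data": [0.1, 0.2], "time": 0.03})
assert s.consumption is None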
nuclia/sdk/kb.py
CHANGED
@@ -297,9 +297,15 @@ class NucliaKB:
         )

     @kb
-    def summarize(self, *, resources: List[str], **kwargs):
+    def summarize(
+        self, *, resources: List[str], show_consumption: bool = False, **kwargs
+    ):
         ndb: NucliaDBClient = kwargs["ndb"]
-        return ndb.ndb.summarize(kbid=ndb.kbid, resources=resources)
+        return ndb.ndb.summarize(
+            kbid=ndb.kbid,
+            resources=resources,
+            headers={"X-Show-Consumption": str(show_consumption).lower()},
+        )

     @kb
     def notifications(self, **kwargs):
@@ -715,6 +721,7 @@ class AsyncNucliaKB:
         resources: List[str],
         generative_model: Optional[str] = None,
         summary_kind: Optional[str] = None,
+        show_consumption: bool = False,
         timeout: int = 1000,
         **kwargs,
     ) -> SummarizedModel:
@@ -725,6 +732,7 @@ class AsyncNucliaKB:
                 generative_model=generative_model,
                 summary_kind=SummaryKind(summary_kind),
             ),
+            extra_headers={"X-Show-Consumption": str(show_consumption).lower()},
            timeout=timeout,
        )
        return SummarizedModel.model_validate(resp.json())