nucliadb 6.2.1.post3312__py3-none-any.whl → 6.2.1.post3326__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/common/models_utils/from_proto.py +16 -7
- nucliadb/search/api/v1/__init__.py +1 -0
- nucliadb/search/api/v1/resource/ask.py +2 -7
- nucliadb/search/api/v1/resource/ingestion_agents.py +123 -0
- nucliadb/search/api/v1/resource/utils.py +28 -0
- nucliadb/search/predict.py +41 -0
- nucliadb/search/predict_models.py +160 -0
- nucliadb/search/search/ingestion_agents.py +88 -0
- {nucliadb-6.2.1.post3312.dist-info → nucliadb-6.2.1.post3326.dist-info}/METADATA +6 -6
- {nucliadb-6.2.1.post3312.dist-info → nucliadb-6.2.1.post3326.dist-info}/RECORD +13 -9
- {nucliadb-6.2.1.post3312.dist-info → nucliadb-6.2.1.post3326.dist-info}/WHEEL +0 -0
- {nucliadb-6.2.1.post3312.dist-info → nucliadb-6.2.1.post3326.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.2.1.post3312.dist-info → nucliadb-6.2.1.post3326.dist-info}/top_level.txt +0 -0
@@ -23,12 +23,13 @@ from typing import Any
|
|
23
23
|
|
24
24
|
from google.protobuf.json_format import MessageToDict
|
25
25
|
|
26
|
-
from nucliadb_models.common import Classification, FieldID, FieldTypeName
|
26
|
+
from nucliadb_models.common import Classification, FieldID, FieldTypeName, QuestionAnswers
|
27
27
|
from nucliadb_models.conversation import Conversation, FieldConversation
|
28
28
|
from nucliadb_models.entities import EntitiesGroup, EntitiesGroupSummary, Entity
|
29
29
|
from nucliadb_models.extracted import (
|
30
30
|
ExtractedText,
|
31
31
|
FieldComputedMetadata,
|
32
|
+
FieldMetadata,
|
32
33
|
FieldQuestionAnswers,
|
33
34
|
FileExtractedData,
|
34
35
|
LargeComputedMetadata,
|
@@ -236,6 +237,15 @@ def field_question_answers(
|
|
236
237
|
return FieldQuestionAnswers(**value)
|
237
238
|
|
238
239
|
|
240
|
+
def question_answers(message: resources_pb2.QuestionAnswers) -> QuestionAnswers:
|
241
|
+
value = MessageToDict(
|
242
|
+
message,
|
243
|
+
preserving_proto_field_name=True,
|
244
|
+
including_default_value_fields=True,
|
245
|
+
)
|
246
|
+
return QuestionAnswers(**value)
|
247
|
+
|
248
|
+
|
239
249
|
def extracted_text(message: resources_pb2.ExtractedText) -> ExtractedText:
|
240
250
|
return ExtractedText(
|
241
251
|
**MessageToDict(
|
@@ -304,10 +314,9 @@ def field_computed_metadata(
|
|
304
314
|
) -> FieldComputedMetadata:
|
305
315
|
if shortened:
|
306
316
|
shorten_fieldmetadata(message)
|
307
|
-
metadata = convert_fieldmetadata_pb_to_dict(message.metadata)
|
317
|
+
metadata = field_metadata(message.metadata)
|
308
318
|
split_metadata = {
|
309
|
-
split: convert_fieldmetadata_pb_to_dict(metadata_split)
|
310
|
-
for split, metadata_split in message.split_metadata.items()
|
319
|
+
split: field_metadata(metadata_split) for split, metadata_split in message.split_metadata.items()
|
311
320
|
}
|
312
321
|
value = MessageToDict(
|
313
322
|
message,
|
@@ -319,9 +328,9 @@ def field_computed_metadata(
|
|
319
328
|
return FieldComputedMetadata(**value)
|
320
329
|
|
321
330
|
|
322
|
-
def convert_fieldmetadata_pb_to_dict(
|
331
|
+
def field_metadata(
|
323
332
|
message: resources_pb2.FieldMetadata,
|
324
|
-
) ->
|
333
|
+
) -> FieldMetadata:
|
325
334
|
# Backwards compatibility with old entities format
|
326
335
|
# TODO: Remove once deprecated fields are removed
|
327
336
|
# If we received processor entities in the new field and the old field is empty, we copy them to the old field
|
@@ -347,7 +356,7 @@ def convert_fieldmetadata_pb_to_dict(
|
|
347
356
|
value["relations"] = [
|
348
357
|
convert_pb_relation_to_api(rel) for relations in message.relations for rel in relations.relations
|
349
358
|
]
|
350
|
-
return value
|
359
|
+
return FieldMetadata(**value)
|
351
360
|
|
352
361
|
|
353
362
|
def conversation(message: resources_pb2.Conversation) -> Conversation:
|
@@ -28,4 +28,5 @@ from . import suggest # noqa
|
|
28
28
|
from . import summarize # noqa
|
29
29
|
from .resource import ask as ask_resource # noqa
|
30
30
|
from .resource import search as search_resource # noqa
|
31
|
+
from .resource import ingestion_agents as ingestion_agents_resource # noqa
|
31
32
|
from .router import api # noqa
|
@@ -17,14 +17,14 @@
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
|
-
from typing import Optional, Union
|
20
|
+
from typing import Union
|
21
21
|
|
22
22
|
from fastapi import Header, Request, Response
|
23
23
|
from fastapi_versioning import version
|
24
24
|
from starlette.responses import StreamingResponse
|
25
25
|
|
26
|
-
from nucliadb.common import datamanagers
|
27
26
|
from nucliadb.models.responses import HTTPClientError
|
27
|
+
from nucliadb.search.api.v1.resource.utils import get_resource_uuid_by_slug
|
28
28
|
from nucliadb.search.api.v1.router import KB_PREFIX, RESOURCE_SLUG_PREFIX, api
|
29
29
|
from nucliadb_models.resource import NucliaDBRoles
|
30
30
|
from nucliadb_models.search import AskRequest, NucliaDBClientType, SyncAskResponse
|
@@ -104,8 +104,3 @@ async def resource_ask_endpoint_by_slug(
|
|
104
104
|
x_synchronous,
|
105
105
|
resource=resource_id,
|
106
106
|
)
|
107
|
-
|
108
|
-
|
109
|
-
async def get_resource_uuid_by_slug(kbid: str, slug: str) -> Optional[str]:
|
110
|
-
async with datamanagers.with_ro_transaction() as txn:
|
111
|
-
return await datamanagers.resources.get_resource_uuid_from_slug(txn, kbid=kbid, slug=slug)
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
2
|
+
#
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
5
|
+
#
|
6
|
+
# AGPL:
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
10
|
+
# License, or (at your option) any later version.
|
11
|
+
#
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15
|
+
# GNU Affero General Public License for more details.
|
16
|
+
#
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
|
+
#
|
20
|
+
from typing import Union
|
21
|
+
|
22
|
+
from fastapi import Header, Request, Response
|
23
|
+
from fastapi_versioning import version
|
24
|
+
|
25
|
+
from nucliadb.common.models_utils import from_proto
|
26
|
+
from nucliadb.models.responses import HTTPClientError
|
27
|
+
from nucliadb.search.api.v1.resource.utils import get_resource_uuid_by_slug
|
28
|
+
from nucliadb.search.api.v1.router import KB_PREFIX, RESOURCE_PREFIX, RESOURCE_SLUG_PREFIX, api
|
29
|
+
from nucliadb.search.predict_models import AugmentedField, RunAgentsResponse
|
30
|
+
from nucliadb.search.search.exceptions import ResourceNotFoundError
|
31
|
+
from nucliadb.search.search.ingestion_agents import run_agents
|
32
|
+
from nucliadb_models.agents.ingestion import (
|
33
|
+
AppliedDataAugmentation,
|
34
|
+
NewTextField,
|
35
|
+
ResourceAgentsRequest,
|
36
|
+
ResourceAgentsResponse,
|
37
|
+
)
|
38
|
+
from nucliadb_models.agents.ingestion import AugmentedField as PublicAugmentedField
|
39
|
+
from nucliadb_models.resource import NucliaDBRoles
|
40
|
+
from nucliadb_utils.authentication import requires_one
|
41
|
+
|
42
|
+
|
43
|
+
@api.post(
|
44
|
+
f"/{KB_PREFIX}/{{kbid}}/{RESOURCE_PREFIX}/{{rid}}/run-agents",
|
45
|
+
status_code=200,
|
46
|
+
summary="Run Agents on Resource",
|
47
|
+
description="Run Agents on Resource",
|
48
|
+
tags=["Ingestion Agents"],
|
49
|
+
response_model_exclude_unset=True,
|
50
|
+
response_model=ResourceAgentsResponse,
|
51
|
+
)
|
52
|
+
@requires_one([NucliaDBRoles.READER])
|
53
|
+
@version(1)
|
54
|
+
async def run_agents_by_uuid(
|
55
|
+
request: Request,
|
56
|
+
response: Response,
|
57
|
+
kbid: str,
|
58
|
+
rid: str,
|
59
|
+
item: ResourceAgentsRequest,
|
60
|
+
x_nucliadb_user: str = Header(""),
|
61
|
+
) -> Union[ResourceAgentsResponse, HTTPClientError]:
|
62
|
+
return await _run_agents_endpoint(kbid, rid, x_nucliadb_user, item)
|
63
|
+
|
64
|
+
|
65
|
+
@api.post(
|
66
|
+
f"/{KB_PREFIX}/{{kbid}}/{RESOURCE_SLUG_PREFIX}/{{slug}}/run-agents",
|
67
|
+
status_code=200,
|
68
|
+
summary="Run Agents on Resource (by slug)",
|
69
|
+
description="Run Agents on Resource (by slug)",
|
70
|
+
tags=["Ingestion Agents"],
|
71
|
+
response_model_exclude_unset=True,
|
72
|
+
response_model=ResourceAgentsResponse,
|
73
|
+
)
|
74
|
+
@requires_one([NucliaDBRoles.READER])
|
75
|
+
@version(1)
|
76
|
+
async def run_agents_by_slug(
|
77
|
+
request: Request,
|
78
|
+
response: Response,
|
79
|
+
kbid: str,
|
80
|
+
slug: str,
|
81
|
+
item: ResourceAgentsRequest,
|
82
|
+
x_nucliadb_user: str = Header(""),
|
83
|
+
) -> Union[ResourceAgentsResponse, HTTPClientError]:
|
84
|
+
resource_id = await get_resource_uuid_by_slug(kbid, slug)
|
85
|
+
if resource_id is None:
|
86
|
+
return HTTPClientError(status_code=404, detail="Resource not found")
|
87
|
+
return await _run_agents_endpoint(kbid, resource_id, x_nucliadb_user, item)
|
88
|
+
|
89
|
+
|
90
|
+
async def _run_agents_endpoint(
|
91
|
+
kbid: str, resource_id: str, user_id: str, item: ResourceAgentsRequest
|
92
|
+
) -> Union[ResourceAgentsResponse, HTTPClientError]:
|
93
|
+
try:
|
94
|
+
run_agents_response: RunAgentsResponse = await run_agents(
|
95
|
+
kbid, resource_id, user_id, filters=item.filters
|
96
|
+
)
|
97
|
+
except ResourceNotFoundError:
|
98
|
+
return HTTPClientError(status_code=404, detail="Resource not found")
|
99
|
+
response = ResourceAgentsResponse(results={})
|
100
|
+
for field_id, augmented_field in run_agents_response.results.items():
|
101
|
+
response.results[field_id] = _parse_augmented_field(augmented_field)
|
102
|
+
return response
|
103
|
+
|
104
|
+
|
105
|
+
def _parse_augmented_field(augmented_field: AugmentedField) -> PublicAugmentedField:
|
106
|
+
ada = augmented_field.applied_data_augmentation
|
107
|
+
return PublicAugmentedField(
|
108
|
+
metadata=from_proto.field_metadata(augmented_field.metadata),
|
109
|
+
applied_data_augmentation=AppliedDataAugmentation(
|
110
|
+
changed=ada.changed,
|
111
|
+
qas=from_proto.question_answers(ada.qas) if ada.qas else None,
|
112
|
+
new_text_fields=[
|
113
|
+
NewTextField(
|
114
|
+
text_field=from_proto.field_text(ntf["text_field"]),
|
115
|
+
destination=ntf["destination"],
|
116
|
+
)
|
117
|
+
for ntf in ada.new_text_fields
|
118
|
+
],
|
119
|
+
),
|
120
|
+
input_nuclia_tokens=augmented_field.input_nuclia_tokens,
|
121
|
+
output_nuclia_tokens=augmented_field.output_nuclia_tokens,
|
122
|
+
time=augmented_field.time,
|
123
|
+
)
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
2
|
+
#
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
5
|
+
#
|
6
|
+
# AGPL:
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
10
|
+
# License, or (at your option) any later version.
|
11
|
+
#
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15
|
+
# GNU Affero General Public License for more details.
|
16
|
+
#
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
|
+
#
|
20
|
+
|
21
|
+
from typing import Optional
|
22
|
+
|
23
|
+
from nucliadb.common import datamanagers
|
24
|
+
|
25
|
+
|
26
|
+
async def get_resource_uuid_by_slug(kbid: str, slug: str) -> Optional[str]:
|
27
|
+
async with datamanagers.with_ro_transaction() as txn:
|
28
|
+
return await datamanagers.resources.get_resource_uuid_from_slug(txn, kbid=kbid, slug=slug)
|
nucliadb/search/predict.py
CHANGED
@@ -17,6 +17,7 @@
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
|
+
import base64
|
20
21
|
import json
|
21
22
|
import os
|
22
23
|
import random
|
@@ -31,6 +32,12 @@ from pydantic import ValidationError
|
|
31
32
|
|
32
33
|
from nucliadb.common import datamanagers
|
33
34
|
from nucliadb.search import logger
|
35
|
+
from nucliadb.search.predict_models import (
|
36
|
+
AppliedDataAugmentation,
|
37
|
+
AugmentedField,
|
38
|
+
RunAgentsRequest,
|
39
|
+
RunAgentsResponse,
|
40
|
+
)
|
34
41
|
from nucliadb.tests.vectors import Q, Qm2023
|
35
42
|
from nucliadb_models.internal.predict import (
|
36
43
|
Ner,
|
@@ -47,6 +54,7 @@ from nucliadb_models.search import (
|
|
47
54
|
SummarizedResponse,
|
48
55
|
SummarizeModel,
|
49
56
|
)
|
57
|
+
from nucliadb_protos.resources_pb2 import FieldMetadata
|
50
58
|
from nucliadb_protos.utils_pb2 import RelationNode
|
51
59
|
from nucliadb_telemetry import errors, metrics
|
52
60
|
from nucliadb_utils.exceptions import LimitsExceededError
|
@@ -401,6 +409,24 @@ class PredictEngine:
|
|
401
409
|
data = await resp.json()
|
402
410
|
return RerankResponse.model_validate(data)
|
403
411
|
|
412
|
+
@predict_observer.wrap({"type": "run_agents"})
|
413
|
+
async def run_agents(self, kbid: str, item: RunAgentsRequest) -> RunAgentsResponse:
|
414
|
+
try:
|
415
|
+
self.check_nua_key_is_configured_for_onprem()
|
416
|
+
except NUAKeyMissingError:
|
417
|
+
error = "Nuclia Service account is not defined. Summarize operation could not be performed"
|
418
|
+
logger.warning(error)
|
419
|
+
raise SendToPredictError(error)
|
420
|
+
resp = await self.make_request(
|
421
|
+
"POST",
|
422
|
+
url=self.get_predict_url("/run-agents", kbid),
|
423
|
+
json=item.model_dump(),
|
424
|
+
headers=self.get_predict_headers(kbid),
|
425
|
+
)
|
426
|
+
await self.check_response(resp, expected_status=200)
|
427
|
+
data = await resp.json()
|
428
|
+
return RunAgentsResponse.model_validate(data)
|
429
|
+
|
404
430
|
|
405
431
|
class DummyPredictEngine(PredictEngine):
|
406
432
|
default_semantic_threshold = 0.7
|
@@ -529,6 +555,21 @@ class DummyPredictEngine(PredictEngine):
|
|
529
555
|
)
|
530
556
|
return response
|
531
557
|
|
558
|
+
async def run_agents(self, kbid: str, item: RunAgentsRequest) -> RunAgentsResponse:
|
559
|
+
self.calls.append(("run_agents", (kbid, item)))
|
560
|
+
fm = FieldMetadata()
|
561
|
+
ada = AppliedDataAugmentation()
|
562
|
+
serialized_fm = base64.b64encode(fm.SerializeToString()).decode("utf-8")
|
563
|
+
augmented_field = AugmentedField(
|
564
|
+
metadata=serialized_fm, # type: ignore
|
565
|
+
applied_data_augmentation=ada,
|
566
|
+
input_nuclia_tokens=1.0,
|
567
|
+
output_nuclia_tokens=1.0,
|
568
|
+
time=1.0,
|
569
|
+
)
|
570
|
+
response = RunAgentsResponse(results={"field_id": augmented_field})
|
571
|
+
return response
|
572
|
+
|
532
573
|
|
533
574
|
def get_answer_generator(response: aiohttp.ClientResponse):
|
534
575
|
"""
|
@@ -0,0 +1,160 @@
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
2
|
+
#
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
5
|
+
#
|
6
|
+
# AGPL:
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
10
|
+
# License, or (at your option) any later version.
|
11
|
+
#
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15
|
+
# GNU Affero General Public License for more details.
|
16
|
+
#
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
|
+
|
20
|
+
from base64 import b64decode, b64encode
|
21
|
+
from enum import Enum
|
22
|
+
from typing import Optional, TypedDict
|
23
|
+
|
24
|
+
from google.protobuf.message import DecodeError, Message
|
25
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
26
|
+
|
27
|
+
from nucliadb_protos.resources_pb2 import FieldMetadata, FieldText, QuestionAnswers
|
28
|
+
|
29
|
+
|
30
|
+
class FieldInfo(BaseModel):
|
31
|
+
"""
|
32
|
+
Model to represent the field information required
|
33
|
+
"""
|
34
|
+
|
35
|
+
text: str = Field(..., title="The text of the field")
|
36
|
+
metadata: str = Field(
|
37
|
+
title="The metadata of the field as a base64 string serialized nucliadb_protos.resources.FieldMetadata protobuf",
|
38
|
+
)
|
39
|
+
field_id: str = Field(
|
40
|
+
...,
|
41
|
+
title="The field ID of the field (rid/field_type/field[/split]) or any unique identifier",
|
42
|
+
)
|
43
|
+
|
44
|
+
|
45
|
+
class OperationType(str, Enum):
|
46
|
+
graph = "graph"
|
47
|
+
label = "label"
|
48
|
+
ask = "ask"
|
49
|
+
qa = "qa"
|
50
|
+
extract = "extract"
|
51
|
+
prompt_guard = "prompt_guard"
|
52
|
+
llama_guard = "llama_guard"
|
53
|
+
|
54
|
+
|
55
|
+
class NameOperationFilter(BaseModel):
|
56
|
+
operation_type: OperationType = Field(..., description="Type of the operation")
|
57
|
+
task_names: list[str] = Field(
|
58
|
+
default_factory=list,
|
59
|
+
description="list of task names. If None or empty, all tasks for that operation are applied.",
|
60
|
+
)
|
61
|
+
|
62
|
+
|
63
|
+
class RunAgentsRequest(BaseModel):
|
64
|
+
"""
|
65
|
+
Model to represent a request for the Augment model
|
66
|
+
The text will be augmented with the Knowledge Box's configured Data Augmentation Agents
|
67
|
+
"""
|
68
|
+
|
69
|
+
fields: list[FieldInfo] = Field(
|
70
|
+
...,
|
71
|
+
title="The fields to be augmented with the Knowledge Box's Data Augmentation Agents",
|
72
|
+
)
|
73
|
+
user_id: str = Field(..., title="The user ID of the user making the request")
|
74
|
+
filters: Optional[list[NameOperationFilter]] = Field(
|
75
|
+
default=None,
|
76
|
+
title="Filters to select which Data Augmentation Agents are applied to the text. If empty, all configured agents for the Knowledge Box are applied.",
|
77
|
+
)
|
78
|
+
|
79
|
+
|
80
|
+
class NewTextField(TypedDict):
|
81
|
+
text_field: FieldText
|
82
|
+
destination: str
|
83
|
+
|
84
|
+
|
85
|
+
class AppliedDataAugmentation(BaseModel):
|
86
|
+
model_config = ConfigDict(
|
87
|
+
# Since we have protos as fields, we need to enable arbitrary_types_allowed
|
88
|
+
arbitrary_types_allowed=True,
|
89
|
+
)
|
90
|
+
qas: Optional[QuestionAnswers] = Field(
|
91
|
+
default=None,
|
92
|
+
description="Question and answers generated by the Question Answers agent",
|
93
|
+
)
|
94
|
+
new_text_fields: list[NewTextField] = Field(
|
95
|
+
default_factory=list,
|
96
|
+
description="New text fields. Only generated by the Labeler agent as of now.",
|
97
|
+
)
|
98
|
+
changed: bool = Field(
|
99
|
+
default=True,
|
100
|
+
description="Indicates if the FieldMetadata was changed by the agents",
|
101
|
+
)
|
102
|
+
|
103
|
+
@field_validator("qas", mode="before")
|
104
|
+
def validate_qas(cls, qas: Optional[str]) -> Optional[QuestionAnswers]:
|
105
|
+
if qas is None:
|
106
|
+
return None
|
107
|
+
try:
|
108
|
+
return QuestionAnswers.FromString(b64decode(qas))
|
109
|
+
except DecodeError:
|
110
|
+
raise ValueError("Invalid QuestionAnswers protobuf")
|
111
|
+
|
112
|
+
@field_validator("new_text_fields", mode="before")
|
113
|
+
def validate_new_text_fields(cls, new_text_fields: list[NewTextField]) -> list[NewTextField]:
|
114
|
+
try:
|
115
|
+
return [
|
116
|
+
NewTextField(
|
117
|
+
text_field=FieldText.FromString(b64decode(text_field["text_field"])), # type: ignore
|
118
|
+
destination=text_field["destination"],
|
119
|
+
)
|
120
|
+
for text_field in new_text_fields
|
121
|
+
]
|
122
|
+
except DecodeError:
|
123
|
+
raise ValueError("Invalid NewTextField value")
|
124
|
+
|
125
|
+
|
126
|
+
class AugmentedField(BaseModel):
|
127
|
+
# Since we have protos as fields, we need to enable arbitrary_types_allowed
|
128
|
+
model_config = ConfigDict(
|
129
|
+
arbitrary_types_allowed=True,
|
130
|
+
# Custom encoding to be able to serialize protobuf messages
|
131
|
+
json_encoders={Message: lambda m: b64encode(m.SerializeToString()).decode()},
|
132
|
+
)
|
133
|
+
metadata: FieldMetadata = Field(
|
134
|
+
...,
|
135
|
+
title="The updated metadata of the field as a base64 string serialized nucliadb_protos.resources.FieldMetadata protobuf",
|
136
|
+
)
|
137
|
+
applied_data_augmentation: AppliedDataAugmentation = Field(
|
138
|
+
..., title="The results of the Applied Data Augmentation"
|
139
|
+
)
|
140
|
+
input_nuclia_tokens: float = Field(
|
141
|
+
..., title="The number of input Nuclia tokens consumed for the field"
|
142
|
+
)
|
143
|
+
output_nuclia_tokens: float = Field(
|
144
|
+
..., title="The number of output Nuclia tokens consumed for the field"
|
145
|
+
)
|
146
|
+
time: float = Field(..., title="The time taken to execute the Data Augmentation agents to the field")
|
147
|
+
|
148
|
+
@field_validator("metadata", mode="before")
|
149
|
+
def validate_metadata(cls, metadata: str) -> FieldMetadata:
|
150
|
+
try:
|
151
|
+
return FieldMetadata.FromString(b64decode(metadata))
|
152
|
+
except DecodeError:
|
153
|
+
raise ValueError("Invalid FieldMetadata protobuf")
|
154
|
+
|
155
|
+
|
156
|
+
class RunAgentsResponse(BaseModel):
|
157
|
+
results: dict[str, AugmentedField] = Field(
|
158
|
+
...,
|
159
|
+
title="Pairs of augmented FieldMetadata and Data Augmentation results by field id",
|
160
|
+
)
|
@@ -0,0 +1,88 @@
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
2
|
+
#
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
5
|
+
#
|
6
|
+
# AGPL:
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
10
|
+
# License, or (at your option) any later version.
|
11
|
+
#
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
15
|
+
# GNU Affero General Public License for more details.
|
16
|
+
#
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
|
+
#
|
20
|
+
import asyncio
|
21
|
+
from base64 import b64encode
|
22
|
+
from typing import Optional
|
23
|
+
|
24
|
+
from nucliadb.common import datamanagers
|
25
|
+
from nucliadb.ingest.fields.base import Field
|
26
|
+
from nucliadb.search.predict_models import (
|
27
|
+
FieldInfo,
|
28
|
+
NameOperationFilter,
|
29
|
+
OperationType,
|
30
|
+
RunAgentsRequest,
|
31
|
+
RunAgentsResponse,
|
32
|
+
)
|
33
|
+
from nucliadb.search.search.exceptions import ResourceNotFoundError
|
34
|
+
from nucliadb.search.utilities import get_predict
|
35
|
+
from nucliadb_models.agents.ingestion import AgentsFilter
|
36
|
+
from nucliadb_protos.resources_pb2 import FieldMetadata
|
37
|
+
|
38
|
+
|
39
|
+
async def run_agents(
|
40
|
+
kbid: str, rid: str, user_id: str, filters: Optional[list[AgentsFilter]] = None
|
41
|
+
) -> RunAgentsResponse:
|
42
|
+
fields = await fetch_resource_fields(kbid, rid)
|
43
|
+
|
44
|
+
item = RunAgentsRequest(user_id=user_id, filters=_parse_filters(filters), fields=fields)
|
45
|
+
|
46
|
+
predict = get_predict()
|
47
|
+
return await predict.run_agents(kbid, item)
|
48
|
+
|
49
|
+
|
50
|
+
def _parse_filters(filters: Optional[list[AgentsFilter]]) -> Optional[list[NameOperationFilter]]:
|
51
|
+
if filters is None:
|
52
|
+
return None
|
53
|
+
return [
|
54
|
+
NameOperationFilter(
|
55
|
+
operation_type=OperationType(filter.type.value), task_names=filter.task_names
|
56
|
+
)
|
57
|
+
for filter in filters
|
58
|
+
]
|
59
|
+
|
60
|
+
|
61
|
+
async def fetch_resource_fields(kbid: str, rid: str) -> list[FieldInfo]:
|
62
|
+
async with datamanagers.with_ro_transaction() as txn:
|
63
|
+
resource = await datamanagers.resources.get_resource(txn, kbid=kbid, rid=rid)
|
64
|
+
if resource is None:
|
65
|
+
raise ResourceNotFoundError()
|
66
|
+
fields = await resource.get_fields(force=True)
|
67
|
+
|
68
|
+
tasks: list[asyncio.Task] = []
|
69
|
+
for field in fields.values():
|
70
|
+
tasks.append(asyncio.create_task(_hydrate_field_info(field)))
|
71
|
+
await asyncio.gather(*tasks)
|
72
|
+
|
73
|
+
return [task.result() for task in tasks]
|
74
|
+
|
75
|
+
|
76
|
+
async def _hydrate_field_info(field_obj: Field) -> FieldInfo:
|
77
|
+
extracted_text = await field_obj.get_extracted_text(force=True)
|
78
|
+
fcmw = await field_obj.get_field_metadata(force=True)
|
79
|
+
if fcmw is None:
|
80
|
+
metadata = FieldMetadata()
|
81
|
+
else:
|
82
|
+
metadata = fcmw.metadata
|
83
|
+
serialized_metadata = b64encode(metadata.SerializeToString()).decode()
|
84
|
+
return FieldInfo(
|
85
|
+
text=extracted_text.text if extracted_text is not None else "",
|
86
|
+
metadata=serialized_metadata,
|
87
|
+
field_id=f"{field_obj.type}/{field_obj.id}",
|
88
|
+
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: nucliadb
|
3
|
-
Version: 6.2.1.post3312
|
3
|
+
Version: 6.2.1.post3326
|
4
4
|
Summary: NucliaDB
|
5
5
|
Author-email: Nuclia <nucliadb@nuclia.com>
|
6
6
|
License: AGPL
|
@@ -20,11 +20,11 @@ Classifier: Programming Language :: Python :: 3.12
|
|
20
20
|
Classifier: Programming Language :: Python :: 3 :: Only
|
21
21
|
Requires-Python: <4,>=3.9
|
22
22
|
Description-Content-Type: text/markdown
|
23
|
-
Requires-Dist: nucliadb-telemetry[all]>=6.2.1.post3312
|
24
|
-
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.2.1.post3312
|
25
|
-
Requires-Dist: nucliadb-protos>=6.2.1.post3312
|
26
|
-
Requires-Dist: nucliadb-models>=6.2.1.post3312
|
27
|
-
Requires-Dist: nidx-protos>=6.2.1.post3312
|
23
|
+
Requires-Dist: nucliadb-telemetry[all]>=6.2.1.post3326
|
24
|
+
Requires-Dist: nucliadb-utils[cache,fastapi,storages]>=6.2.1.post3326
|
25
|
+
Requires-Dist: nucliadb-protos>=6.2.1.post3326
|
26
|
+
Requires-Dist: nucliadb-models>=6.2.1.post3326
|
27
|
+
Requires-Dist: nidx-protos>=6.2.1.post3326
|
28
28
|
Requires-Dist: nucliadb-admin-assets>=1.0.0.post1224
|
29
29
|
Requires-Dist: nuclia-models>=0.24.2
|
30
30
|
Requires-Dist: uvicorn
|
@@ -90,7 +90,7 @@ nucliadb/common/maindb/local.py,sha256=uE9DIQX1yCNHNN8Tx4fPgSiuTtWpQhlfWkMJ8QZPa
|
|
90
90
|
nucliadb/common/maindb/pg.py,sha256=FNq2clckJYj4Te-1svjQblqGoAF5OwJ5nwz2JtxD0d4,13645
|
91
91
|
nucliadb/common/maindb/utils.py,sha256=zWLs82rWEVhpc1dYvdqTZiAcjZroB6Oo5MQaxMeFuKk,3301
|
92
92
|
nucliadb/common/models_utils/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
93
|
-
nucliadb/common/models_utils/from_proto.py,sha256=
|
93
|
+
nucliadb/common/models_utils/from_proto.py,sha256=d-q6-JWvjVVxtI57itQVKRIY8Hzv-tBhDt-tr4aIKNM,15979
|
94
94
|
nucliadb/common/models_utils/to_proto.py,sha256=97JvOR_3odu50YvzLa2CERfEN3w_QPmAVcCJwJB5m5A,2438
|
95
95
|
nucliadb/export_import/__init__.py,sha256=y-Is0Bxa8TMV6UiOW0deC_D3U465P65CQ5RjBjIWnow,932
|
96
96
|
nucliadb/export_import/datamanager.py,sha256=b9Vhf-WqJ8HosTdNpKXlGj-Vi7MHyMoPxL0VpfPlYlE,6720
|
@@ -181,13 +181,14 @@ nucliadb/search/__init__.py,sha256=tnypbqcH4nBHbGpkINudhKgdLKpwXQCvDtPchUlsyY4,1
|
|
181
181
|
nucliadb/search/app.py,sha256=6UV7rO0f3w5bNFXLdQM8bwUwXayMGnM4hF6GGv7WPv4,4260
|
182
182
|
nucliadb/search/lifecycle.py,sha256=DW8v4WUi4rZqc7xTOi3rE67W7877WG7fH9oTZbolHdE,2099
|
183
183
|
nucliadb/search/openapi.py,sha256=t3Wo_4baTrfPftg2BHsyLWNZ1MYn7ZRdW7ht-wFOgRs,1016
|
184
|
-
nucliadb/search/predict.py,sha256=
|
184
|
+
nucliadb/search/predict.py,sha256=z2-RkhMkH-5T6PtFkfESxNof07XiS5FxicLHPRyCUXc,22284
|
185
|
+
nucliadb/search/predict_models.py,sha256=m3abP2J-xZIhKa230OvnnBRXxL4oSlozl5PpA7VeVvE,5908
|
185
186
|
nucliadb/search/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
186
187
|
nucliadb/search/run.py,sha256=aFb-CXRi_C8YMpP_ivNj8KW1BYhADj88y8K9Lr_nUPI,1402
|
187
188
|
nucliadb/search/settings.py,sha256=vem3EcyYlTPSim0kEK-xe-erF4BZg0CT_LAb8ZRQAE8,1684
|
188
189
|
nucliadb/search/utilities.py,sha256=9SsRDw0rJVXVoLBfF7rBb6q080h-thZc7u8uRcTiBeY,1037
|
189
190
|
nucliadb/search/api/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
190
|
-
nucliadb/search/api/v1/__init__.py,sha256=
|
191
|
+
nucliadb/search/api/v1/__init__.py,sha256=8w6VhZ5rbzX1xLSXr336d2IE-O0dQiv-ba6UYdRKnHA,1325
|
191
192
|
nucliadb/search/api/v1/ask.py,sha256=F2dR3-swb3Xz8MfZPYL3G65KY2R_mgef4YVBbu8kLi4,4352
|
192
193
|
nucliadb/search/api/v1/catalog.py,sha256=TF19WN-qgZZLkqBwVH5xNsMxYTrmdEflPvy7qft_4lE,7010
|
193
194
|
nucliadb/search/api/v1/feedback.py,sha256=kNLc4dHz2SXHzV0PwC1WiRAwY88fDptPcP-kO0q-FrQ,2620
|
@@ -200,8 +201,10 @@ nucliadb/search/api/v1/suggest.py,sha256=S0YUTAWukzZSYZJzN3T5MUgPM3599HQvG76GOCB
|
|
200
201
|
nucliadb/search/api/v1/summarize.py,sha256=VAHJvE6V3xUgEBfqNKhgoxmDqCvh30RnrEIBVhMcNLU,2499
|
201
202
|
nucliadb/search/api/v1/utils.py,sha256=5Ve-frn7LAE2jqAgB85F8RSeqxDlyA08--gS-AdOLS4,1434
|
202
203
|
nucliadb/search/api/v1/resource/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
203
|
-
nucliadb/search/api/v1/resource/ask.py,sha256=
|
204
|
+
nucliadb/search/api/v1/resource/ask.py,sha256=nsVzBSanSSlf0Ody6LSTjdEy75Vg283_YhbkAtWEjh8,3637
|
205
|
+
nucliadb/search/api/v1/resource/ingestion_agents.py,sha256=fqqRCd8Wc9GciS5P98lcnihvTKStsZYYtOU-T1bc-6E,4771
|
204
206
|
nucliadb/search/api/v1/resource/search.py,sha256=oSU5lwG7XRnD7oBFct31JaECGjTjX5R8mxNF1mskINc,4715
|
207
|
+
nucliadb/search/api/v1/resource/utils.py,sha256=-NjZqAQtFEXKpIh8ui5S26ItnJ5rzmmG0BHxGSS9QPw,1141
|
205
208
|
nucliadb/search/requesters/__init__.py,sha256=itSI7dtTwFP55YMX4iK7JzdMHS5CQVUiB1XzQu4UBh8,833
|
206
209
|
nucliadb/search/requesters/utils.py,sha256=qL81UVPNgBftUMLpcxIYVr7ILsMqpKCo-9SY2EvAaXw,6681
|
207
210
|
nucliadb/search/search/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
|
@@ -214,6 +217,7 @@ nucliadb/search/search/find.py,sha256=AocqiH_mWvF_szUaW0ONqWrZAbX-k_VhM0Lpv7D669
|
|
214
217
|
nucliadb/search/search/find_merge.py,sha256=3FnzKFEnVemg6FO_6zveulbAU7klvsiPEBvLrpBBMg8,17450
|
215
218
|
nucliadb/search/search/graph_strategy.py,sha256=ahwcUTQZ0Ll-rnS285DO9PmRyiM-1p4BM3UvmOYVwhM,31750
|
216
219
|
nucliadb/search/search/hydrator.py,sha256=-R37gCrGxkyaiHQalnTWHNG_FCx11Zucd7qA1vQCxuw,6985
|
220
|
+
nucliadb/search/search/ingestion_agents.py,sha256=NeJr4EEX-bvFFMGvXOOwLv8uU7NuQ-ntJnnrhnKfMzY,3174
|
217
221
|
nucliadb/search/search/merge.py,sha256=i_PTBFRqC5iTTziOMEltxLIlmokIou5hjjgR4BnoLBE,22635
|
218
222
|
nucliadb/search/search/metrics.py,sha256=81X-tahGW4n2CLvUzCPdNxNClmZqUWZjcVOGCUHoiUM,2872
|
219
223
|
nucliadb/search/search/paragraphs.py,sha256=pNAEiYqJGGUVcEf7xf-PFMVqz0PX4Qb-WNG-_zPGN2o,7799
|
@@ -331,8 +335,8 @@ nucliadb/writer/tus/local.py,sha256=7jYa_w9b-N90jWgN2sQKkNcomqn6JMVBOVeDOVYJHto,
|
|
331
335
|
nucliadb/writer/tus/s3.py,sha256=vF0NkFTXiXhXq3bCVXXVV-ED38ECVoUeeYViP8uMqcU,8357
|
332
336
|
nucliadb/writer/tus/storage.py,sha256=ToqwjoYnjI4oIcwzkhha_MPxi-k4Jk3Lt55zRwaC1SM,2903
|
333
337
|
nucliadb/writer/tus/utils.py,sha256=MSdVbRsRSZVdkaum69_0wku7X3p5wlZf4nr6E0GMKbw,2556
|
334
|
-
nucliadb-6.2.1.
|
335
|
-
nucliadb-6.2.1.
|
336
|
-
nucliadb-6.2.1.
|
337
|
-
nucliadb-6.2.1.
|
338
|
-
nucliadb-6.2.1.
|
338
|
+
nucliadb-6.2.1.post3326.dist-info/METADATA,sha256=dlbBvdHdR_9AzOSB7w415CnvT_1xpxCNMDsxwPHoyek,4291
|
339
|
+
nucliadb-6.2.1.post3326.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
340
|
+
nucliadb-6.2.1.post3326.dist-info/entry_points.txt,sha256=XqGfgFDuY3zXQc8ewXM2TRVjTModIq851zOsgrmaXx4,1268
|
341
|
+
nucliadb-6.2.1.post3326.dist-info/top_level.txt,sha256=hwYhTVnX7jkQ9gJCkVrbqEG1M4lT2F_iPQND1fCzF80,20
|
342
|
+
nucliadb-6.2.1.post3326.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|