nucliadb 6.2.0.post2675__py3-none-any.whl → 6.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registry.
- migrations/0028_extracted_vectors_reference.py +61 -0
- migrations/0029_backfill_field_status.py +149 -0
- migrations/0030_label_deduplication.py +60 -0
- nucliadb/common/cluster/manager.py +41 -331
- nucliadb/common/cluster/rebalance.py +2 -2
- nucliadb/common/cluster/rollover.py +12 -71
- nucliadb/common/cluster/settings.py +3 -0
- nucliadb/common/cluster/standalone/utils.py +0 -43
- nucliadb/common/cluster/utils.py +0 -16
- nucliadb/common/counters.py +1 -0
- nucliadb/common/datamanagers/fields.py +48 -7
- nucliadb/common/datamanagers/vectorsets.py +11 -2
- nucliadb/common/external_index_providers/base.py +2 -1
- nucliadb/common/external_index_providers/pinecone.py +3 -5
- nucliadb/common/ids.py +18 -4
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +76 -37
- nucliadb/export_import/models.py +3 -3
- nucliadb/health.py +0 -7
- nucliadb/ingest/app.py +0 -8
- nucliadb/ingest/consumer/auditing.py +1 -1
- nucliadb/ingest/consumer/shard_creator.py +1 -1
- nucliadb/ingest/fields/base.py +83 -21
- nucliadb/ingest/orm/brain.py +55 -56
- nucliadb/ingest/orm/broker_message.py +12 -2
- nucliadb/ingest/orm/entities.py +6 -17
- nucliadb/ingest/orm/knowledgebox.py +44 -22
- nucliadb/ingest/orm/processor/data_augmentation.py +7 -29
- nucliadb/ingest/orm/processor/processor.py +5 -2
- nucliadb/ingest/orm/resource.py +222 -413
- nucliadb/ingest/processing.py +8 -2
- nucliadb/ingest/serialize.py +77 -46
- nucliadb/ingest/service/writer.py +2 -56
- nucliadb/ingest/settings.py +1 -4
- nucliadb/learning_proxy.py +6 -4
- nucliadb/purge/__init__.py +102 -12
- nucliadb/purge/orphan_shards.py +6 -4
- nucliadb/reader/api/models.py +3 -3
- nucliadb/reader/api/v1/__init__.py +1 -0
- nucliadb/reader/api/v1/download.py +2 -2
- nucliadb/reader/api/v1/knowledgebox.py +3 -3
- nucliadb/reader/api/v1/resource.py +23 -12
- nucliadb/reader/api/v1/services.py +4 -4
- nucliadb/reader/api/v1/vectorsets.py +48 -0
- nucliadb/search/api/v1/ask.py +11 -1
- nucliadb/search/api/v1/feedback.py +3 -3
- nucliadb/search/api/v1/knowledgebox.py +8 -13
- nucliadb/search/api/v1/search.py +3 -2
- nucliadb/search/api/v1/suggest.py +0 -2
- nucliadb/search/predict.py +6 -4
- nucliadb/search/requesters/utils.py +1 -2
- nucliadb/search/search/chat/ask.py +77 -13
- nucliadb/search/search/chat/prompt.py +16 -5
- nucliadb/search/search/chat/query.py +74 -34
- nucliadb/search/search/exceptions.py +2 -7
- nucliadb/search/search/find.py +9 -5
- nucliadb/search/search/find_merge.py +10 -4
- nucliadb/search/search/graph_strategy.py +884 -0
- nucliadb/search/search/hydrator.py +6 -0
- nucliadb/search/search/merge.py +79 -24
- nucliadb/search/search/query.py +74 -245
- nucliadb/search/search/query_parser/exceptions.py +11 -1
- nucliadb/search/search/query_parser/fetcher.py +405 -0
- nucliadb/search/search/query_parser/models.py +0 -3
- nucliadb/search/search/query_parser/parser.py +22 -21
- nucliadb/search/search/rerankers.py +1 -42
- nucliadb/search/search/shards.py +19 -0
- nucliadb/standalone/api_router.py +2 -14
- nucliadb/standalone/settings.py +4 -0
- nucliadb/train/generators/field_streaming.py +7 -3
- nucliadb/train/lifecycle.py +3 -6
- nucliadb/train/nodes.py +14 -12
- nucliadb/train/resource.py +380 -0
- nucliadb/writer/api/constants.py +20 -16
- nucliadb/writer/api/v1/__init__.py +1 -0
- nucliadb/writer/api/v1/export_import.py +1 -1
- nucliadb/writer/api/v1/field.py +13 -7
- nucliadb/writer/api/v1/knowledgebox.py +3 -46
- nucliadb/writer/api/v1/resource.py +20 -13
- nucliadb/writer/api/v1/services.py +10 -1
- nucliadb/writer/api/v1/upload.py +61 -34
- nucliadb/writer/{vectorsets.py → api/v1/vectorsets.py} +99 -47
- nucliadb/writer/back_pressure.py +17 -46
- nucliadb/writer/resource/basic.py +9 -7
- nucliadb/writer/resource/field.py +42 -9
- nucliadb/writer/settings.py +2 -2
- nucliadb/writer/tus/gcs.py +11 -10
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/METADATA +11 -14
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/RECORD +94 -96
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/WHEEL +1 -1
- nucliadb/common/cluster/discovery/base.py +0 -178
- nucliadb/common/cluster/discovery/k8s.py +0 -301
- nucliadb/common/cluster/discovery/manual.py +0 -57
- nucliadb/common/cluster/discovery/single.py +0 -51
- nucliadb/common/cluster/discovery/types.py +0 -32
- nucliadb/common/cluster/discovery/utils.py +0 -67
- nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
- nucliadb/common/cluster/standalone/index_node.py +0 -123
- nucliadb/common/cluster/standalone/service.py +0 -84
- nucliadb/standalone/introspect.py +0 -208
- nucliadb-6.2.0.post2675.dist-info/zip-safe +0 -1
- /nucliadb/common/{cluster/discovery → models_utils}/__init__.py +0 -0
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.2.0.post2675.dist-info → nucliadb-6.2.1.dist-info}/top_level.txt +0 -0
@@ -97,9 +97,13 @@ async def generate_field_streaming_payloads(
         tl.field = field
         tl.field_type = field_type
         tl.split = split
-
-        if
-        tl.text.
+
+        if trainset.exclude_text:
+            tl.text.text = ""
+        else:
+            extracted = await get_field_text(kbid, rid, field, field_type)
+            if extracted is not None:
+                tl.text.CopyFrom(extracted)

         metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
         if metadata_obj is not None:
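For readers skimming the hunk above: when the train set asks to exclude text, the generator now leaves the text payload empty instead of copying the extracted text. A minimal stand-alone sketch of that branch in plain Python — `TrainSet` and `TextPayload` below are illustrative stand-ins, not the nucliadb protobuf types:

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class TextPayload:
    """Stand-in for the protobuf text message."""
    text: str = ""


@dataclass
class TrainSet:
    """Stand-in for the trainset request; exclude_text mirrors the new flag."""
    exclude_text: bool = False


async def fill_text(tl_text: TextPayload, trainset: TrainSet, extracted: Optional[TextPayload]) -> None:
    # Mirrors the added branch: skip text entirely when requested, otherwise copy what was extracted.
    if trainset.exclude_text:
        tl_text.text = ""
    elif extracted is not None:
        tl_text.text = extracted.text  # stand-in for tl.text.CopyFrom(extracted)


async def main() -> None:
    payload = TextPayload()
    await fill_text(payload, TrainSet(exclude_text=True), TextPayload(text="hello"))
    assert payload.text == ""


asyncio.run(main())
```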
nucliadb/train/lifecycle.py
CHANGED
@@ -22,10 +22,7 @@ from contextlib import asynccontextmanager

 from fastapi import FastAPI

-from nucliadb.common.
-    setup_cluster_discovery,
-    teardown_cluster_discovery,
-)
+from nucliadb.common.nidx import start_nidx_utility, stop_nidx_utility
 from nucliadb.train import SERVICE_NAME
 from nucliadb.train.utils import (
     start_shard_manager,
@@ -40,7 +37,7 @@ from nucliadb_utils.utilities import start_audit_utility, stop_audit_utility
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await setup_telemetry(SERVICE_NAME)
-    await
+    await start_nidx_utility()
     await start_shard_manager()
     await start_train_grpc(SERVICE_NAME)
     await start_audit_utility(SERVICE_NAME)
@@ -50,5 +47,5 @@ async def lifespan(app: FastAPI):
     await stop_audit_utility()
     await stop_train_grpc()
     await stop_shard_manager()
-    await
+    await stop_nidx_utility()
     await clean_telemetry(SERVICE_NAME)
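The lifecycle hunks above follow FastAPI's lifespan pattern: dependencies are started before the app begins serving and stopped in reverse order on shutdown, with the cluster-discovery setup replaced by the nidx utility. A minimal self-contained sketch of that pattern, using placeholder start/stop coroutines rather than the actual nucliadb utilities:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


async def start_nidx() -> None:
    # Placeholder for start_nidx_utility(): bring the index dependency up.
    print("nidx started")


async def stop_nidx() -> None:
    # Placeholder for stop_nidx_utility(): tear the dependency down.
    print("nidx stopped")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Everything before the yield runs on startup, before requests are served.
    await start_nidx()
    yield
    # Everything after the yield runs on shutdown; tear down in reverse of start order.
    await stop_nidx()


app = FastAPI(lifespan=lifespan)
```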
nucliadb/train/nodes.py
CHANGED
@@ -28,6 +28,12 @@ from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG_BASE
 from nucliadb.common.maindb.driver import Driver, Transaction
 from nucliadb.ingest.orm.entities import EntitiesManager
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
+from nucliadb.train.resource import (
+    generate_train_resource,
+    iterate_fields,
+    iterate_paragraphs,
+    iterate_sentences,
+)
 from nucliadb_protos.train_pb2 import (
     GetFieldsRequest,
     GetParagraphsRequest,
@@ -39,9 +45,7 @@ from nucliadb_protos.train_pb2 import (
     TrainSentence,
 )
 from nucliadb_protos.writer_pb2 import ShardObject
-from nucliadb_utils import const
 from nucliadb_utils.storages.storage import Storage
-from nucliadb_utils.utilities import has_feature


 class TrainShardManager(manager.KBShardManager):
@@ -57,9 +61,7 @@ class TrainShardManager(manager.KBShardManager):
         except StopIteration:
             raise KeyError("Shard not found")

-        node_obj, shard_id = manager.choose_node(
-            shard_object, use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid})
-        )
+        node_obj, shard_id = manager.choose_node(shard_object)
         return node_obj, shard_id

     async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
@@ -87,11 +89,11 @@ class TrainShardManager(manager.KBShardManager):
         # Filter by uuid
         resource = await kb.get(request.uuid)
         if resource:
-            async for sentence in
+            async for sentence in iterate_sentences(resource, request.metadata):
                 yield sentence
         else:
             async for resource in kb.iterate_resources():
-                async for sentence in
+                async for sentence in iterate_sentences(resource, request.metadata):
                     yield sentence

     async def kb_paragraphs(self, request: GetParagraphsRequest) -> AsyncIterator[TrainParagraph]:
@@ -101,11 +103,11 @@ class TrainShardManager(manager.KBShardManager):
         # Filter by uuid
         resource = await kb.get(request.uuid)
         if resource:
-            async for paragraph in
+            async for paragraph in iterate_paragraphs(resource, request.metadata):
                 yield paragraph
         else:
             async for resource in kb.iterate_resources():
-                async for paragraph in
+                async for paragraph in iterate_paragraphs(resource, request.metadata):
                     yield paragraph

     async def kb_fields(self, request: GetFieldsRequest) -> AsyncIterator[TrainField]:
@@ -115,11 +117,11 @@ class TrainShardManager(manager.KBShardManager):
         # Filter by uuid
         resource = await kb.get(request.uuid)
         if resource:
-            async for field in
+            async for field in iterate_fields(resource, request.metadata):
                 yield field
         else:
             async for resource in kb.iterate_resources():
-                async for field in
+                async for field in iterate_fields(resource, request.metadata):
                     yield field

     async def kb_resources(self, request: GetResourcesRequest) -> AsyncIterator[TrainResource]:
@@ -132,4 +134,4 @@ class TrainShardManager(manager.KBShardManager):
         if rid is not None:
             resource = await kb.get(rid.decode())
             if resource is not None:
-                yield await
+                yield await generate_train_resource(resource, request.metadata)
nucliadb/train/resource.py
ADDED
@@ -0,0 +1,380 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import annotations
+
+from typing import AsyncIterator, MutableMapping, Optional
+
+from nucliadb.common import datamanagers
+from nucliadb.ingest.orm.resource import Resource
+from nucliadb_protos.resources_pb2 import (
+    FieldID,
+    FieldMetadata,
+    ParagraphAnnotation,
+)
+from nucliadb_protos.train_pb2 import (
+    EnabledMetadata,
+    TrainField,
+    TrainMetadata,
+    TrainParagraph,
+    TrainResource,
+    TrainSentence,
+)
+from nucliadb_protos.train_pb2 import Position as TrainPosition
+
+
+async def iterate_sentences(
+    resource: Resource,
+    enabled_metadata: EnabledMetadata,
+) -> AsyncIterator[TrainSentence]:  # pragma: no cover
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    userdefinedparagraphclass: dict[str, ParagraphAnnotation] = {}
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+            for fieldmetadata in resource.basic.fieldmetadata:
+                field_id = resource.generate_field_id(fieldmetadata.field)
+                for annotationparagraph in fieldmetadata.paragraphs:
+                    userdefinedparagraphclass[annotationparagraph.key] = annotationparagraph
+
+    for (type_id, field_id), field in fields.items():
+        fieldid = FieldID(field_type=type_id, field=field_id)
+        field_key = resource.generate_field_id(fieldid)
+        fm = await field.get_field_metadata()
+        extracted_text = None
+        vo = None
+        text = None
+
+        if enabled_metadata.vector:
+            # XXX: Given that nobody requested any particular vectorset, we'll
+            # return any
+            vectorset_id = None
+            async with datamanagers.with_ro_transaction() as txn:
+                async for vectorset_id, vs in datamanagers.vectorsets.iter(
+                    txn=txn, kbid=resource.kb.kbid
+                ):
+                    break
+            assert vectorset_id is not None, "All KBs must have at least a vectorset"
+            vo = await field.get_vectors(vectorset_id, vs.storage_key_kind)
+
+        extracted_text = await field.get_extracted_text()
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for subfield, field_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.ClearField("field")
+                metadata.labels.field.extend(field_metadata.classifications)
+
+            entities: dict[str, str] = {}
+            if enabled_metadata.entities:
+                _update_entities_dict(entities, field_metadata)
+
+            precomputed_vectors = {}
+            if vo is not None:
+                if subfield is not None:
+                    vectors = vo.split_vectors[subfield]
+                    base_vector_key = f"{resource.uuid}/{field_key}/{subfield}"
+                else:
+                    vectors = vo.vectors
+                    base_vector_key = f"{resource.uuid}/{field_key}"
+                for index, vector in enumerate(vectors.vectors):
+                    vector_key = f"{base_vector_key}/{index}/{vector.start}-{vector.end}"
+                    precomputed_vectors[vector_key] = vector.vector
+
+            if extracted_text is not None:
+                if subfield is not None:
+                    text = extracted_text.split_text[subfield]
+                else:
+                    text = extracted_text.text
+
+            for paragraph in field_metadata.paragraphs:
+                if subfield is not None:
+                    paragraph_key = (
+                        f"{resource.uuid}/{field_key}/{subfield}/{paragraph.start}-{paragraph.end}"
+                    )
+                else:
+                    paragraph_key = f"{resource.uuid}/{field_key}/{paragraph.start}-{paragraph.end}"
+
+                if enabled_metadata.labels:
+                    metadata.labels.ClearField("field")
+                    metadata.labels.paragraph.extend(paragraph.classifications)
+                    if paragraph_key in userdefinedparagraphclass:
+                        metadata.labels.paragraph.extend(
+                            userdefinedparagraphclass[paragraph_key].classifications
+                        )
+
+                for index, sentence in enumerate(paragraph.sentences):
+                    if subfield is not None:
+                        sentence_key = f"{resource.uuid}/{field_key}/{subfield}/{index}/{sentence.start}-{sentence.end}"
+                    else:
+                        sentence_key = (
+                            f"{resource.uuid}/{field_key}/{index}/{sentence.start}-{sentence.end}"
+                        )
+
+                    if vo is not None:
+                        metadata.ClearField("vector")
+                        vector_tmp = precomputed_vectors.get(sentence_key)
+                        if vector_tmp:
+                            metadata.vector.extend(vector_tmp)
+
+                    if extracted_text is not None and text is not None:
+                        metadata.text = text[sentence.start : sentence.end]
+
+                    metadata.ClearField("entities")
+                    metadata.ClearField("entity_positions")
+                    if enabled_metadata.entities and text is not None:
+                        local_text = text[sentence.start : sentence.end]
+                        add_entities_to_metadata(entities, local_text, metadata)
+
+                    pb_sentence = TrainSentence()
+                    pb_sentence.uuid = resource.uuid
+                    pb_sentence.field.CopyFrom(fieldid)
+                    pb_sentence.paragraph = paragraph_key
+                    pb_sentence.sentence = sentence_key
+                    pb_sentence.metadata.CopyFrom(metadata)
+                    yield pb_sentence
+
+
+async def iterate_paragraphs(
+    resource: Resource, enabled_metadata: EnabledMetadata
+) -> AsyncIterator[TrainParagraph]:
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    userdefinedparagraphclass: dict[str, ParagraphAnnotation] = {}
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+            for fieldmetadata in resource.basic.fieldmetadata:
+                field_id = resource.generate_field_id(fieldmetadata.field)
+                for annotationparagraph in fieldmetadata.paragraphs:
+                    userdefinedparagraphclass[annotationparagraph.key] = annotationparagraph
+
+    for (type_id, field_id), field in fields.items():
+        fieldid = FieldID(field_type=type_id, field=field_id)
+        field_key = resource.generate_field_id(fieldid)
+        fm = await field.get_field_metadata()
+        extracted_text = None
+        text = None
+
+        extracted_text = await field.get_extracted_text()
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for subfield, field_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.ClearField("field")
+                metadata.labels.field.extend(field_metadata.classifications)
+
+            entities: dict[str, str] = {}
+            if enabled_metadata.entities:
+                _update_entities_dict(entities, field_metadata)
+
+            if extracted_text is not None:
+                if subfield is not None:
+                    text = extracted_text.split_text[subfield]
+                else:
+                    text = extracted_text.text
+
+            for paragraph in field_metadata.paragraphs:
+                if subfield is not None:
+                    paragraph_key = (
+                        f"{resource.uuid}/{field_key}/{subfield}/{paragraph.start}-{paragraph.end}"
+                    )
+                else:
+                    paragraph_key = f"{resource.uuid}/{field_key}/{paragraph.start}-{paragraph.end}"
+
+                if enabled_metadata.labels:
+                    metadata.labels.ClearField("paragraph")
+                    metadata.labels.paragraph.extend(paragraph.classifications)
+
+                if extracted_text is not None and text is not None:
+                    metadata.text = text[paragraph.start : paragraph.end]
+
+                metadata.ClearField("entities")
+                metadata.ClearField("entity_positions")
+                if enabled_metadata.entities and text is not None:
+                    local_text = text[paragraph.start : paragraph.end]
+                    add_entities_to_metadata(entities, local_text, metadata)
+
+                if paragraph_key in userdefinedparagraphclass:
+                    metadata.labels.paragraph.extend(
+                        userdefinedparagraphclass[paragraph_key].classifications
+                    )
+
+                pb_paragraph = TrainParagraph()
+                pb_paragraph.uuid = resource.uuid
+                pb_paragraph.field.CopyFrom(fieldid)
+                pb_paragraph.paragraph = paragraph_key
+                pb_paragraph.metadata.CopyFrom(metadata)
+
+                yield pb_paragraph
+
+
+async def iterate_fields(
+    resource: Resource, enabled_metadata: EnabledMetadata
+) -> AsyncIterator[TrainField]:
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+
+    for (type_id, field_id), field in fields.items():
+        fieldid = FieldID(field_type=type_id, field=field_id)
+        fm = await field.get_field_metadata()
+        extracted_text = None
+
+        if enabled_metadata.text:
+            extracted_text = await field.get_extracted_text()
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for subfield, splitted_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.ClearField("field")
+                metadata.labels.field.extend(splitted_metadata.classifications)
+
+            if extracted_text is not None:
+                if subfield is not None:
+                    metadata.text = extracted_text.split_text[subfield]
+                else:
+                    metadata.text = extracted_text.text
+
+            if enabled_metadata.entities:
+                metadata.ClearField("entities")
+                _update_entities_dict(metadata.entities, splitted_metadata)
+
+            pb_field = TrainField()
+            pb_field.uuid = resource.uuid
+            pb_field.field.CopyFrom(fieldid)
+            pb_field.metadata.CopyFrom(metadata)
+            yield pb_field
+
+
+async def generate_train_resource(
+    resource: Resource, enabled_metadata: EnabledMetadata
+) -> TrainResource:
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+
+    metadata.labels.ClearField("field")
+    metadata.ClearField("entities")
+
+    for (_, _), field in fields.items():
+        extracted_text = None
+        fm = await field.get_field_metadata()
+
+        if enabled_metadata.text:
+            extracted_text = await field.get_extracted_text()
+
+        if extracted_text is not None:
+            metadata.text += extracted_text.text
+            for text in extracted_text.split_text.values():
+                metadata.text += f" {text}"
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for _, splitted_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.field.extend(splitted_metadata.classifications)
+
+            if enabled_metadata.entities:
+                _update_entities_dict(metadata.entities, splitted_metadata)
+
+    pb_resource = TrainResource()
+    pb_resource.uuid = resource.uuid
+    if resource.basic is not None:
+        pb_resource.title = resource.basic.title
+        pb_resource.icon = resource.basic.icon
+        pb_resource.slug = resource.basic.slug
+        pb_resource.modified.CopyFrom(resource.basic.modified)
+        pb_resource.created.CopyFrom(resource.basic.created)
+    pb_resource.metadata.CopyFrom(metadata)
+    return pb_resource
+
+
+def add_entities_to_metadata(entities: dict[str, str], local_text: str, metadata: TrainMetadata) -> None:
+    for entity_key, entity_value in entities.items():
+        if entity_key not in local_text:
+            # Add the entity only if found in text
+            continue
+        metadata.entities[entity_key] = entity_value
+
+        # Add positions for the entity relative to the local text
+        poskey = f"{entity_value}/{entity_key}"
+        metadata.entity_positions[poskey].entity = entity_key
+        last_occurrence_end = 0
+        for _ in range(local_text.count(entity_key)):
+            start = local_text.index(entity_key, last_occurrence_end)
+            end = start + len(entity_key)
+            metadata.entity_positions[poskey].positions.append(TrainPosition(start=start, end=end))
+            last_occurrence_end = end
+
+
+def _update_entities_dict(target_entites_dict: MutableMapping[str, str], field_metadata: FieldMetadata):
+    """
+    Update the entities dict with the entities from the field metadata.
+    Method created to ease the transition from legacy ner field to new entities field.
+    """
+    # Data Augmentation + Processor entities
+    # This will overwrite entities detected from more than one data augmentation task
+    # TODO: Change TrainMetadata proto to accept multiple entities with the same text
+    entity_map = {
+        entity.text: entity.label
+        for data_augmentation_task_id, entities_wrapper in field_metadata.entities.items()
+        for entity in entities_wrapper.entities
+    }
+    target_entites_dict.update(entity_map)
+
+    # Legacy processor entities
+    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+    target_entites_dict.update(field_metadata.ner)
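The position bookkeeping in `add_entities_to_metadata` above scans for every occurrence of an entity string by calling `str.index` repeatedly, resuming each search at the end of the previous match. A self-contained sketch of that scanning approach, without the protobuf plumbing:

```python
def find_occurrences(text: str, needle: str) -> list[tuple[int, int]]:
    """Return (start, end) for every non-overlapping occurrence of needle in text."""
    positions: list[tuple[int, int]] = []
    last_end = 0
    for _ in range(text.count(needle)):
        start = text.index(needle, last_end)  # resume searching after the previous match
        end = start + len(needle)
        positions.append((start, end))
        last_end = end
    return positions


assert find_occurrences("the cat sat on the mat", "the") == [(0, 3), (15, 18)]
```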
nucliadb/writer/api/constants.py
CHANGED
@@ -17,21 +17,25 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import TYPE_CHECKING
-
 from fastapi.params import Header

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+X_SKIP_STORE = Header(
+    description="If set to true, file fields will not be saved in the blob storage. They will only be sent to process.",
+)
+X_NUCLIADB_USER = Header()
+X_FILE_PASSWORD = Header(
+    description="If a file is password protected, the password must be provided here for the file to be processed",
+)
+X_EXTRACT_STRATEGY = Header(
+    description="Extract strategy to use when uploading a file. If not provided, the default strategy will be used.",
+)
+X_FILENAME = Header(min_length=1, description="Name of the file being uploaded.")
+X_MD5 = Header(
+    min_length=32,
+    max_length=32,
+    description="MD5 hash of the file being uploaded. This is used to check if the file has been uploaded before.",
+)
+X_PASSWORD = Header(
+    min_length=1, description="If the file is password protected, the password must be provided here."
+)
+X_LANGUAGE = Header()
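These module-level `Header(...)` constants are consumed via `typing.Annotated` in the endpoint signatures later in this diff (see the `field.py` hunks below). A minimal sketch of that wiring in a standalone FastAPI app — the route and response here are illustrative only, not part of nucliadb:

```python
from typing import Annotated

from fastapi import FastAPI
from fastapi.params import Header

# Shared header declaration, defined once and reused across endpoints.
X_SKIP_STORE = Header(
    description="If set to true, file fields will not be saved in the blob storage.",
)

app = FastAPI()


@app.post("/demo")
async def demo(x_skip_store: Annotated[bool, X_SKIP_STORE] = False) -> dict:
    # FastAPI maps the parameter name to the X-Skip-Store request header and coerces it to bool.
    return {"skip_store": x_skip_store}
```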
@@ -112,7 +112,7 @@ async def kb_create_and_import_endpoint(request: Request):
     now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     import_kb_config = KnowledgeBoxConfig(
         title=f"Imported KB - {now}",
-        learning_configuration=learning_config.
+        learning_configuration=learning_config.model_dump(),
     )
     kbid, slug = await create_kb(import_kb_config)

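`model_dump()` is Pydantic v2's serializer for turning a model into a plain dict (the v1 equivalent was `dict()`), and it is what the hunk above now passes as the imported KB's learning configuration. A quick illustration with a throwaway model, not a nucliadb class:

```python
from pydantic import BaseModel


class LearningConfig(BaseModel):
    # Illustrative fields only; nucliadb's learning configuration model differs.
    semantic_model: str = "multilingual"
    anonymization: bool = False


config = LearningConfig()
# model_dump() returns a plain dict suitable for storing as configuration.
assert config.model_dump() == {"semantic_model": "multilingual", "anonymization": False}
```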
nucliadb/writer/api/v1/field.py
CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from inspect import iscoroutinefunction
-from typing import TYPE_CHECKING, Callable, Optional, Type, Union
+from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, Union

 from fastapi import HTTPException, Response
 from fastapi_versioning import version
@@ -30,9 +30,9 @@ from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
 from nucliadb.ingest.processing import PushPayload, Source
 from nucliadb.writer import SERVICE_NAME
 from nucliadb.writer.api.constants import (
-    SKIP_STORE_DEFAULT,
     X_FILE_PASSWORD,
     X_NUCLIADB_USER,
+    X_SKIP_STORE,
 )
 from nucliadb.writer.api.v1 import transaction
 from nucliadb.writer.api.v1.resource import (
@@ -55,7 +55,7 @@ from nucliadb_models.utils import FieldIdString
 from nucliadb_models.writer import ResourceFieldAdded, ResourceUpdated
 from nucliadb_protos import resources_pb2
 from nucliadb_protos.resources_pb2 import FieldID, Metadata
-from nucliadb_protos.writer_pb2 import BrokerMessage
+from nucliadb_protos.writer_pb2 import BrokerMessage, FieldIDStatus, FieldStatus
 from nucliadb_utils.authentication import requires
 from nucliadb_utils.exceptions import LimitsExceededError, SendToProcessError
 from nucliadb_utils.utilities import (
@@ -380,7 +380,7 @@ async def add_resource_field_file_rslug_prefix(
     rslug: str,
     field_id: FieldIdString,
     field_payload: models.FileField,
-    x_skip_store: bool =
+    x_skip_store: Annotated[bool, X_SKIP_STORE] = False,
 ) -> ResourceFieldAdded:
     return await add_field_to_resource_by_slug(
         request, kbid, rslug, field_id, field_payload, skip_store=x_skip_store
@@ -402,7 +402,7 @@ async def add_resource_field_file_rid_prefix(
     rid: str,
     field_id: FieldIdString,
     field_payload: models.FileField,
-    x_skip_store: bool =
+    x_skip_store: Annotated[bool, X_SKIP_STORE] = False,
 ) -> ResourceFieldAdded:
     return await add_field_to_resource(
         request, kbid, rid, field_id, field_payload, skip_store=x_skip_store
@@ -503,8 +503,8 @@ async def reprocess_file_field(
     kbid: str,
     rid: str,
     field_id: FieldIdString,
-    x_nucliadb_user: str =
-    x_file_password: Optional[str] =
+    x_nucliadb_user: Annotated[str, X_NUCLIADB_USER] = "",
+    x_file_password: Annotated[Optional[str], X_FILE_PASSWORD] = None,
 ) -> ResourceUpdated:
     await maybe_back_pressure(request, kbid, resource_uuid=rid)

@@ -553,6 +553,12 @@ async def reprocess_file_field(
     writer.source = BrokerMessage.MessageSource.WRITER
     writer.basic.metadata.useful = True
     writer.basic.metadata.status = Metadata.Status.PENDING
+    writer.field_statuses.append(
+        FieldIDStatus(
+            id=FieldID(field_type=resources_pb2.FieldType.FILE, field=field_id),
+            status=FieldStatus.Status.PENDING,
+        )
+    )
     await transaction.commit(writer, partition, wait=False)
     # Send current resource to reprocess.
     try: