arkindex-base-worker 0.3.7rc4-py3-none-any.whl → 0.5.0a1-py3-none-any.whl
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
- arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
- arkindex_worker/cache.py +1 -1
- arkindex_worker/image.py +167 -2
- arkindex_worker/models.py +18 -0
- arkindex_worker/utils.py +98 -4
- arkindex_worker/worker/__init__.py +117 -218
- arkindex_worker/worker/base.py +39 -46
- arkindex_worker/worker/classification.py +45 -29
- arkindex_worker/worker/corpus.py +86 -0
- arkindex_worker/worker/dataset.py +89 -26
- arkindex_worker/worker/element.py +352 -91
- arkindex_worker/worker/entity.py +13 -11
- arkindex_worker/worker/image.py +21 -0
- arkindex_worker/worker/metadata.py +26 -16
- arkindex_worker/worker/process.py +92 -0
- arkindex_worker/worker/task.py +5 -4
- arkindex_worker/worker/training.py +25 -10
- arkindex_worker/worker/transcription.py +89 -68
- arkindex_worker/worker/version.py +3 -1
- hooks/pre_gen_project.py +3 -0
- tests/__init__.py +8 -0
- tests/conftest.py +47 -58
- tests/test_base_worker.py +212 -12
- tests/test_dataset_worker.py +294 -437
- tests/test_elements_worker/{test_classifications.py → test_classification.py} +313 -200
- tests/test_elements_worker/test_cli.py +3 -11
- tests/test_elements_worker/test_corpus.py +168 -0
- tests/test_elements_worker/test_dataset.py +106 -157
- tests/test_elements_worker/test_element.py +427 -0
- tests/test_elements_worker/test_element_create_multiple.py +715 -0
- tests/test_elements_worker/test_element_create_single.py +528 -0
- tests/test_elements_worker/test_element_list_children.py +969 -0
- tests/test_elements_worker/test_element_list_parents.py +530 -0
- tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
- tests/test_elements_worker/test_entity_list_and_check.py +160 -0
- tests/test_elements_worker/test_image.py +66 -0
- tests/test_elements_worker/test_metadata.py +252 -161
- tests/test_elements_worker/test_process.py +89 -0
- tests/test_elements_worker/test_task.py +8 -18
- tests/test_elements_worker/test_training.py +17 -8
- tests/test_elements_worker/test_transcription_create.py +873 -0
- tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
- tests/test_elements_worker/test_transcription_list.py +450 -0
- tests/test_elements_worker/test_version.py +60 -0
- tests/test_elements_worker/test_worker.py +578 -293
- tests/test_image.py +542 -209
- tests/test_merge.py +1 -2
- tests/test_utils.py +89 -4
- worker-demo/tests/__init__.py +0 -0
- worker-demo/tests/conftest.py +32 -0
- worker-demo/tests/test_worker.py +12 -0
- worker-demo/worker_demo/__init__.py +6 -0
- worker-demo/worker_demo/worker.py +19 -0
- arkindex_base_worker-0.3.7rc4.dist-info/RECORD +0 -41
- tests/test_elements_worker/test_elements.py +0 -2713
- tests/test_elements_worker/test_transcriptions.py +0 -2119
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
arkindex_worker/worker/classification.py

@@ -2,14 +2,18 @@
 ElementsWorker methods for classifications and ML classes.
 """
 
-from uuid import UUID
-
-from apistar.exceptions import ErrorResponse
 from peewee import IntegrityError
 
+from arkindex.exceptions import ErrorResponse
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedClassification, CachedElement
 from arkindex_worker.models import Element
+from arkindex_worker.utils import (
+    DEFAULT_BATCH_SIZE,
+    batch_publication,
+    make_batches,
+    pluralize,
+)
 
 
 class ClassificationMixin:
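The helpers newly imported from `arkindex_worker.utils` (a module that grew by +98 -4 in this release) are not shown in this diff. Judging from their call sites in the hunks below, they plausibly behave like this sketch; the signatures and the `DEFAULT_BATCH_SIZE` value are inferred, not authoritative:

    # Inferred behavior only: the real implementations live in
    # arkindex_worker/utils.py and may differ in details.
    from collections.abc import Iterator

    DEFAULT_BATCH_SIZE = 50  # assumption; the actual default is not visible here

    def pluralize(word: str, count: int) -> str:
        """Append an "s" to `word` when `count` is not 1."""
        return word if count == 1 else f"{word}s"

    def make_batches(items: list, label: str, batch_size: int) -> Iterator[list]:
        """Split `items` into chunks of at most `batch_size` elements; `label`
        names the items (e.g. "classification") for logging purposes."""
        for start in range(0, len(items), batch_size):
            yield items[start : start + batch_size]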
@@ -23,7 +27,7 @@ class ClassificationMixin:
         )
         self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
         logger.info(
-            f
+            f'Loaded {len(self.classes)} ML {pluralize("class", len(self.classes))} in corpus ({self.corpus_id})'
         )
 
     def get_ml_class_id(self, ml_class: str) -> str:
@@ -41,7 +45,7 @@ class ClassificationMixin:
         if ml_class_id is None:
             logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
             try:
-                response = self.request(
+                response = self.api_client.request(
                     "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
                 )
                 ml_class_id = self.classes[ml_class] = response["id"]
@@ -121,7 +125,7 @@ class ClassificationMixin:
             )
             return
         try:
-            created = self.request(
+            created = self.api_client.request(
                 "CreateClassification",
                 body={
                     "element": str(element.id),
@@ -169,19 +173,27 @@ class ClassificationMixin:
 
         return created
 
+    @batch_publication
     def create_classifications(
         self,
        element: Element | CachedElement,
         classifications: list[dict[str, str | float | bool]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str | float | bool]]:
         """
         Create multiple classifications at once on the given element through the API.
 
         :param element: The element to create classifications on.
-        :param classifications:
-
-
-
+        :param classifications: A list of dicts representing a classification each, with the following keys:
+
+            ml_class (str)
+                Required. Name of the MLClass to use.
+            confidence (float)
+                Required. Confidence score for the classification. Must be between 0 and 1.
+            high_confidence (bool)
+                Optional. Whether or not the classification is of high confidence.
+
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
 
         :returns: List of created classifications, as returned in the ``classifications`` field by
             the ``CreateClassifications`` API endpoint.
@@ -194,18 +206,10 @@ class ClassificationMixin:
         ), "classifications shouldn't be null and should be of type list"
 
         for index, classification in enumerate(classifications):
-
+            ml_class = classification.get("ml_class")
             assert (
-
-            ), f"Classification at index {index} in classifications:
-
-            # Make sure it's a valid UUID
-            try:
-                UUID(ml_class_id)
-            except ValueError as e:
-                raise ValueError(
-                    f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
-                ) from e
+                ml_class and isinstance(ml_class, str)
+            ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"
 
             confidence = classification.get("confidence")
             assert (
@@ -226,14 +230,26 @@ class ClassificationMixin:
             )
             return
 
-        created_cls =
-
-
-
-        "
-
-
-
+        created_cls = [
+            created_cl
+            for batch in make_batches(classifications, "classification", batch_size)
+            for created_cl in self.api_client.request(
+                "CreateClassifications",
+                body={
+                    "parent": str(element.id),
+                    "worker_run_id": self.worker_run_id,
+                    "classifications": [
+                        {
+                            **classification,
+                            "ml_class": self.get_ml_class_id(
+                                classification["ml_class"]
+                            ),
+                        }
+                        for classification in batch
+                    ],
+                },
+            )["classifications"]
+        ]
 
         for created_cl in created_cls:
             created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
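Downstream callers change accordingly: `create_classifications` now takes the MLClass *name* (resolved to an ID internally through `get_ml_class_id`) instead of a pre-validated `ml_class_id` UUID, and accepts a `batch_size`. A minimal usage sketch, where the `worker` instance and the class names are illustrative:

    # 0.5.0a1 calling convention: "ml_class" is the class name, and the
    # publication is split into batches of at most `batch_size` items.
    created = worker.create_classifications(
        element,
        classifications=[
            {"ml_class": "handwritten", "confidence": 0.91, "high_confidence": True},
            {"ml_class": "typed", "confidence": 0.09},
        ],
        batch_size=100,  # optional; defaults to DEFAULT_BATCH_SIZE
    )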
arkindex_worker/worker/corpus.py (new file)

@@ -0,0 +1,86 @@
+"""
+BaseWorker methods for corpora.
+"""
+
+from enum import Enum
+from operator import itemgetter
+from tempfile import _TemporaryFileWrapper
+from uuid import UUID
+
+from arkindex_worker import logger
+
+
+class CorpusExportState(Enum):
+    """
+    State of a corpus export.
+    """
+
+    Created = "created"
+    """
+    The corpus export is created, awaiting its processing.
+    """
+
+    Running = "running"
+    """
+    The corpus export is being built.
+    """
+
+    Failed = "failed"
+    """
+    The corpus export failed.
+    """
+
+    Done = "done"
+    """
+    The corpus export ended in success.
+    """
+
+
+class CorpusMixin:
+    def download_export(self, export_id: str) -> _TemporaryFileWrapper:
+        """
+        Download an export.
+
+        :param export_id: UUID of the export to download
+        :returns: The downloaded export stored in a temporary file.
+        """
+        try:
+            UUID(export_id)
+        except ValueError as e:
+            raise ValueError("export_id is not a valid uuid.") from e
+
+        logger.info(f"Downloading export ({export_id})...")
+        export: _TemporaryFileWrapper = self.api_client.request(
+            "DownloadExport", id=export_id
+        )
+        logger.info(f"Downloaded export ({export_id}) @ `{export.name}`")
+        return export
+
+    def download_latest_export(self) -> _TemporaryFileWrapper:
+        """
+        Download the latest export in `done` state of the current corpus.
+
+        :returns: The downloaded export stored in a temporary file.
+        """
+        # List all exports on the corpus
+        exports = self.api_client.paginate("ListExports", id=self.corpus_id)
+
+        # Find the latest that is in "done" state
+        exports: list[dict] = sorted(
+            list(
+                filter(
+                    lambda export: export["state"] == CorpusExportState.Done.value,
+                    exports,
+                )
+            ),
+            key=itemgetter("updated"),
+            reverse=True,
+        )
+        assert (
+            len(exports) > 0
+        ), f'No available exports found for the corpus ({self.corpus_id}) with state "{CorpusExportState.Done.value.capitalize()}".'
+
+        # Download latest export
+        export_id: str = exports[0]["id"]
+
+        return self.download_export(export_id)
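The new `CorpusMixin` gives workers direct access to corpus exports. A short usage sketch; it assumes the mixin is part of the running worker class and that the downloaded export is the usual Arkindex SQLite dump, neither of which is stated in this diff:

    import sqlite3

    # Fetch the most recent export of the worker's corpus that reached the
    # "done" state; the helper returns a tempfile wrapper whose .name holds
    # the on-disk path.
    export = worker.download_latest_export()

    # Assumption: Arkindex corpus exports are SQLite databases.
    db = sqlite3.connect(export.name)

    # A specific export can also be fetched by UUID; an invalid UUID raises
    # ValueError before any API call is made.
    export = worker.download_export("12341234-1234-1234-1234-123412341234")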
arkindex_worker/worker/dataset.py

@@ -2,12 +2,14 @@
 BaseWorker methods for datasets.
 """
 
+import uuid
+from argparse import ArgumentTypeError
 from collections.abc import Iterator
 from enum import Enum
 
 from arkindex_worker import logger
 from arkindex_worker.cache import unsupported_cache
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.models import Dataset, Element, Set
 
 
 class DatasetState(Enum):
@@ -36,49 +38,110 @@ class DatasetState(Enum):
     """
 
 
+class MissingDatasetArchive(Exception):
+    """
+    Exception raised when the compressed archive associated to
+    a dataset isn't found in its task artifacts.
+    """
+
+
+def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
+    """The `--set` argument should have the following format:
+    <dataset_id>:<set_name>
+
+    Args:
+        value (str): Provided argument.
+
+    Raises:
+        ArgumentTypeError: When the value is invalid.
+
+    Returns:
+        tuple[uuid.UUID, str]: The ID of the dataset parsed as UUID and the name of the set.
+    """
+    values = value.split(":")
+    if len(values) != 2:
+        raise ArgumentTypeError(
+            f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
+        )
+
+    dataset_id, set_name = values
+    try:
+        dataset_id = uuid.UUID(dataset_id)
+        return (dataset_id, set_name)
+    except (TypeError, ValueError) as e:
+        raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
+
+
 class DatasetMixin:
-    def
+    def add_arguments(self) -> None:
+        """Define specific ``argparse`` arguments for the worker using this mixin"""
+        self.parser.add_argument(
+            "--set",
+            type=check_dataset_set,
+            nargs="+",
+            help="""
+            One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
+            (e.g.: "12341234-1234-1234-1234-123412341234:train")
+            """,
+            default=[],
+        )
+        super().add_arguments()
+
+    def list_process_sets(self) -> Iterator[Set]:
         """
-        List
+        List dataset sets associated to the worker's process. This helper is not available in developer mode.
 
-        :returns: An iterator of ``
+        :returns: An iterator of ``Set`` objects built from the ``ListProcessSets`` API endpoint.
         """
         assert not self.is_read_only, "This helper is not available in read-only mode."
 
         results = self.api_client.paginate(
-            "
+            "ListProcessSets", id=self.process_information["id"]
         )
 
         return map(
-            lambda result:
+            lambda result: Set(
+                name=result["set_name"], dataset=Dataset(**result["dataset"])
+            ),
             results,
         )
 
-    def
+    def list_set_elements(self, dataset_set: Set) -> Iterator[Element]:
         """
-        List elements in a dataset.
+        List elements in a dataset set.
 
-        :param
-        :returns: An iterator of
+        :param dataset_set: Set to find elements in.
+        :returns: An iterator of Element built from the ``ListDatasetElements`` API endpoint.
         """
-        assert
-
-        ), "
+        assert dataset_set and isinstance(
+            dataset_set, Set
+        ), "dataset_set shouldn't be null and should be a Set"
+
+        results = self.api_client.paginate(
+            "ListDatasetElements", id=dataset_set.dataset.id, set=dataset_set.name
+        )
 
-
-
-
-
-
-
-
-
+        return map(lambda result: Element(**result["element"]), results)
+
+    def list_sets(self) -> Iterator[Set]:
+        """
+        List the sets to be processed, either from the CLI arguments or using the
+        [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.
+
+        :returns: An iterator of ``Set`` objects.
+        """
+        if not self.is_read_only:
+            yield from self.list_process_sets()
+
+        datasets: dict[uuid.UUID, Dataset] = {}
+        for dataset_id, set_name in self.args.set:
+            # Retrieving dataset information if not already cached
+            if dataset_id not in datasets:
+                datasets[dataset_id] = Dataset(
+                    **self.api_client.request("RetrieveDataset", id=dataset_id)
                 )
-            )
 
-
-            lambda result: (result["set"], Element(**result["element"])), results
-        )
+            yield Set(name=set_name, dataset=datasets[dataset_id])
 
     @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
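Together, these helpers let a dataset worker run the same way locally and on Arkindex: in developer (read-only) mode the sets come from the new `--set` argument, otherwise from the process via `ListProcessSets`. A usage sketch with an illustrative `worker` instance:

    # Developer-mode invocation; nargs="+" accepts several values after one flag:
    #   worker.py --set 12341234-1234-1234-1234-123412341234:train \
    #                   12341234-1234-1234-1234-123412341234:validation
    #
    # check_dataset_set("12341234-1234-1234-1234-123412341234:train")
    # returns (UUID("12341234-1234-1234-1234-123412341234"), "train")

    for dataset_set in worker.list_sets():
        logger.info(f"Processing set {dataset_set.name} of dataset {dataset_set.dataset.id}")
        for element in worker.list_set_elements(dataset_set):
            ...  # run task-specific processing on each element of the set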
@@ -100,7 +163,7 @@ class DatasetMixin:
             logger.warning("Cannot update dataset as this worker is in read-only mode")
             return
 
-        updated_dataset = self.request(
+        updated_dataset = self.api_client.request(
             "PartialUpdateDataset",
             id=dataset.id,
             body={"state": state.value},