arkindex-base-worker 0.3.7rc5__py3-none-any.whl → 0.5.0a1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (60)
  1. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
  2. arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +1 -1
  6. arkindex_worker/image.py +167 -2
  7. arkindex_worker/models.py +18 -0
  8. arkindex_worker/utils.py +98 -4
  9. arkindex_worker/worker/__init__.py +117 -218
  10. arkindex_worker/worker/base.py +39 -46
  11. arkindex_worker/worker/classification.py +34 -18
  12. arkindex_worker/worker/corpus.py +86 -0
  13. arkindex_worker/worker/dataset.py +89 -26
  14. arkindex_worker/worker/element.py +352 -91
  15. arkindex_worker/worker/entity.py +13 -11
  16. arkindex_worker/worker/image.py +21 -0
  17. arkindex_worker/worker/metadata.py +26 -16
  18. arkindex_worker/worker/process.py +92 -0
  19. arkindex_worker/worker/task.py +5 -4
  20. arkindex_worker/worker/training.py +25 -10
  21. arkindex_worker/worker/transcription.py +89 -68
  22. arkindex_worker/worker/version.py +3 -1
  23. hooks/pre_gen_project.py +3 -0
  24. tests/__init__.py +8 -0
  25. tests/conftest.py +47 -58
  26. tests/test_base_worker.py +212 -12
  27. tests/test_dataset_worker.py +294 -437
  28. tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
  29. tests/test_elements_worker/test_cli.py +3 -11
  30. tests/test_elements_worker/test_corpus.py +168 -0
  31. tests/test_elements_worker/test_dataset.py +106 -157
  32. tests/test_elements_worker/test_element.py +427 -0
  33. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  34. tests/test_elements_worker/test_element_create_single.py +528 -0
  35. tests/test_elements_worker/test_element_list_children.py +969 -0
  36. tests/test_elements_worker/test_element_list_parents.py +530 -0
  37. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  38. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  39. tests/test_elements_worker/test_image.py +66 -0
  40. tests/test_elements_worker/test_metadata.py +252 -161
  41. tests/test_elements_worker/test_process.py +89 -0
  42. tests/test_elements_worker/test_task.py +8 -18
  43. tests/test_elements_worker/test_training.py +17 -8
  44. tests/test_elements_worker/test_transcription_create.py +873 -0
  45. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  46. tests/test_elements_worker/test_transcription_list.py +450 -0
  47. tests/test_elements_worker/test_version.py +60 -0
  48. tests/test_elements_worker/test_worker.py +578 -293
  49. tests/test_image.py +542 -209
  50. tests/test_merge.py +1 -2
  51. tests/test_utils.py +89 -4
  52. worker-demo/tests/__init__.py +0 -0
  53. worker-demo/tests/conftest.py +32 -0
  54. worker-demo/tests/test_worker.py +12 -0
  55. worker-demo/worker_demo/__init__.py +6 -0
  56. worker-demo/worker_demo/worker.py +19 -0
  57. arkindex_base_worker-0.3.7rc5.dist-info/RECORD +0 -41
  58. tests/test_elements_worker/test_elements.py +0 -2713
  59. tests/test_elements_worker/test_transcriptions.py +0 -2119
  60. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
arkindex_worker/worker/classification.py

@@ -2,12 +2,18 @@
 ElementsWorker methods for classifications and ML classes.
 """
 
-from apistar.exceptions import ErrorResponse
 from peewee import IntegrityError
 
+from arkindex.exceptions import ErrorResponse
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedClassification, CachedElement
 from arkindex_worker.models import Element
+from arkindex_worker.utils import (
+    DEFAULT_BATCH_SIZE,
+    batch_publication,
+    make_batches,
+    pluralize,
+)
 
 
 class ClassificationMixin:
@@ -21,7 +27,7 @@ class ClassificationMixin:
         )
         self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
         logger.info(
-            f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
+            f'Loaded {len(self.classes)} ML {pluralize("class", len(self.classes))} in corpus ({self.corpus_id})'
         )
 
     def get_ml_class_id(self, ml_class: str) -> str:
@@ -39,7 +45,7 @@ class ClassificationMixin:
         if ml_class_id is None:
             logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
             try:
-                response = self.request(
+                response = self.api_client.request(
                     "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
                 )
                 ml_class_id = self.classes[ml_class] = response["id"]
@@ -119,7 +125,7 @@ class ClassificationMixin:
            )
            return
        try:
-            created = self.request(
+            created = self.api_client.request(
                "CreateClassification",
                body={
                    "element": str(element.id),
@@ -167,10 +173,12 @@ class ClassificationMixin:
 
         return created
 
+    @batch_publication
     def create_classifications(
         self,
         element: Element | CachedElement,
         classifications: list[dict[str, str | float | bool]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str | float | bool]]:
         """
         Create multiple classifications at once on the given element through the API.
@@ -185,6 +193,8 @@ class ClassificationMixin:
            high_confidence (bool)
                Optional. Whether or not the classification is of high confidence.
 
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
        :returns: List of created classifications, as returned in the ``classifications`` field by
            the ``CreateClassifications`` API endpoint.
        """
@@ -220,20 +230,26 @@ class ClassificationMixin:
            )
            return
 
-        created_cls = self.request(
-            "CreateClassifications",
-            body={
-                "parent": str(element.id),
-                "worker_run_id": self.worker_run_id,
-                "classifications": [
-                    {
-                        **classification,
-                        "ml_class": self.get_ml_class_id(classification["ml_class"]),
-                    }
-                    for classification in classifications
-                ],
-            },
-        )["classifications"]
+        created_cls = [
+            created_cl
+            for batch in make_batches(classifications, "classification", batch_size)
+            for created_cl in self.api_client.request(
+                "CreateClassifications",
+                body={
+                    "parent": str(element.id),
+                    "worker_run_id": self.worker_run_id,
+                    "classifications": [
+                        {
+                            **classification,
+                            "ml_class": self.get_ml_class_id(
+                                classification["ml_class"]
+                            ),
+                        }
+                        for classification in batch
+                    ],
+                },
+            )["classifications"]
+        ]
 
        for created_cl in created_cls:
            created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
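The batching change above leans on the helpers newly imported from arkindex_worker.utils: create_classifications is wrapped in @batch_publication and slices its payload with make_batches, so no single CreateClassifications call carries more than batch_size items. Only the helper names come from this diff; the sketch below shows plausible implementations, and the default of 50 is an assumption.

from collections.abc import Iterator
from itertools import islice

# Assumed value; the real default lives in arkindex_worker.utils
DEFAULT_BATCH_SIZE = 50


def pluralize(word: str, count: int) -> str:
    # Naive sketch: "class" -> "classes" when count != 1
    if count == 1:
        return word
    return f"{word}es" if word.endswith("s") else f"{word}s"


def make_batches(items: list, label: str, batch_size: int) -> Iterator[list]:
    # Yield successive slices of at most batch_size items;
    # the real helper presumably also logs progress using `label`
    iterator = iter(items)
    while batch := list(islice(iterator, batch_size)):
        yield batch

With helpers like these, create_classifications(element, classifications, batch_size=100) publishes at most 100 classifications per API request instead of sending the whole list in one call.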
arkindex_worker/worker/corpus.py (new file)

@@ -0,0 +1,86 @@
+"""
+BaseWorker methods for corpora.
+"""
+
+from enum import Enum
+from operator import itemgetter
+from tempfile import _TemporaryFileWrapper
+from uuid import UUID
+
+from arkindex_worker import logger
+
+
+class CorpusExportState(Enum):
+    """
+    State of a corpus export.
+    """
+
+    Created = "created"
+    """
+    The corpus export is created, awaiting its processing.
+    """
+
+    Running = "running"
+    """
+    The corpus export is being built.
+    """
+
+    Failed = "failed"
+    """
+    The corpus export failed.
+    """
+
+    Done = "done"
+    """
+    The corpus export ended in success.
+    """
+
+
+class CorpusMixin:
+    def download_export(self, export_id: str) -> _TemporaryFileWrapper:
+        """
+        Download an export.
+
+        :param export_id: UUID of the export to download
+        :returns: The downloaded export stored in a temporary file.
+        """
+        try:
+            UUID(export_id)
+        except ValueError as e:
+            raise ValueError("export_id is not a valid uuid.") from e
+
+        logger.info(f"Downloading export ({export_id})...")
+        export: _TemporaryFileWrapper = self.api_client.request(
+            "DownloadExport", id=export_id
+        )
+        logger.info(f"Downloaded export ({export_id}) @ `{export.name}`")
+        return export
+
+    def download_latest_export(self) -> _TemporaryFileWrapper:
+        """
+        Download the latest export in `done` state of the current corpus.
+
+        :returns: The downloaded export stored in a temporary file.
+        """
+        # List all exports on the corpus
+        exports = self.api_client.paginate("ListExports", id=self.corpus_id)
+
+        # Find the latest that is in "done" state
+        exports: list[dict] = sorted(
+            list(
+                filter(
+                    lambda export: export["state"] == CorpusExportState.Done.value,
+                    exports,
+                )
+            ),
+            key=itemgetter("updated"),
+            reverse=True,
+        )
+        assert (
+            len(exports) > 0
+        ), f'No available exports found for the corpus ({self.corpus_id}) with state "{CorpusExportState.Done.value.capitalize()}".'
+
+        # Download latest export
+        export_id: str = exports[0]["id"]
+
+        return self.download_export(export_id)
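Since download_export returns the _TemporaryFileWrapper handed back by the API client, the downloaded file disappears once it is closed. A hedged usage sketch from a worker that includes CorpusMixin follows; that the export is a SQLite database with an element table matches the usual Arkindex export format, but is an assumption here.

import sqlite3

# `worker` is assumed to be an instance of a worker class using CorpusMixin
export = worker.download_latest_export()
try:
    # Assumption: the export is a SQLite database following the
    # standard Arkindex export schema, which has an `element` table
    db = sqlite3.connect(export.name)
    (count,) = db.execute("SELECT COUNT(*) FROM element").fetchone()
    print(f"Export contains {count} elements")
    db.close()
finally:
    # The temporary file backing the export is deleted on close
    export.close()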
arkindex_worker/worker/dataset.py

@@ -2,12 +2,14 @@
 BaseWorker methods for datasets.
 """
 
+import uuid
+from argparse import ArgumentTypeError
 from collections.abc import Iterator
 from enum import Enum
 
 from arkindex_worker import logger
 from arkindex_worker.cache import unsupported_cache
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.models import Dataset, Element, Set
 
 
 class DatasetState(Enum):
@@ -36,49 +38,110 @@ class DatasetState(Enum):
     """
 
 
+class MissingDatasetArchive(Exception):
+    """
+    Exception raised when the compressed archive associated to
+    a dataset isn't found in its task artifacts.
+    """
+
+
+def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
+    """The `--set` argument should have the following format:
+    <dataset_id>:<set_name>
+
+    Args:
+        value (str): Provided argument.
+
+    Raises:
+        ArgumentTypeError: When the value is invalid.
+
+    Returns:
+        tuple[uuid.UUID, str]: The ID of the dataset parsed as UUID and the name of the set.
+    """
+    values = value.split(":")
+    if len(values) != 2:
+        raise ArgumentTypeError(
+            f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
+        )
+
+    dataset_id, set_name = values
+    try:
+        dataset_id = uuid.UUID(dataset_id)
+        return (dataset_id, set_name)
+    except (TypeError, ValueError) as e:
+        raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
+
+
 class DatasetMixin:
-    def list_process_datasets(self) -> Iterator[Dataset]:
+    def add_arguments(self) -> None:
+        """Define specific ``argparse`` arguments for the worker using this mixin"""
+        self.parser.add_argument(
+            "--set",
+            type=check_dataset_set,
+            nargs="+",
+            help="""
+            One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
+            (e.g.: "12341234-1234-1234-1234-123412341234:train")
+            """,
+            default=[],
+        )
+        super().add_arguments()
+
+    def list_process_sets(self) -> Iterator[Set]:
         """
-        List datasets associated to the worker's process. This helper is not available in developer mode.
+        List dataset sets associated to the worker's process. This helper is not available in developer mode.
 
-        :returns: An iterator of ``Dataset`` objects built from the ``ListProcessDatasets`` API endpoint.
+        :returns: An iterator of ``Set`` objects built from the ``ListProcessSets`` API endpoint.
         """
         assert not self.is_read_only, "This helper is not available in read-only mode."
 
         results = self.api_client.paginate(
-            "ListProcessDatasets", id=self.process_information["id"]
+            "ListProcessSets", id=self.process_information["id"]
         )
 
         return map(
-            lambda result: Dataset(**result["dataset"], selected_sets=result["sets"]),
+            lambda result: Set(
+                name=result["set_name"], dataset=Dataset(**result["dataset"])
+            ),
             results,
         )
 
-    def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
+    def list_set_elements(self, dataset_set: Set) -> Iterator[Element]:
         """
-        List elements in a dataset.
+        List elements in a dataset set.
 
-        :param dataset: Dataset to find elements in.
-        :returns: An iterator of tuples built from the ``ListDatasetElements`` API endpoint.
+        :param dataset_set: Set to find elements in.
+        :returns: An iterator of Element built from the ``ListDatasetElements`` API endpoint.
         """
-        assert dataset and isinstance(
-            dataset, Dataset
-        ), "dataset shouldn't be null and should be a Dataset"
+        assert dataset_set and isinstance(
+            dataset_set, Set
+        ), "dataset_set shouldn't be null and should be a Set"
+
+        results = self.api_client.paginate(
+            "ListDatasetElements", id=dataset_set.dataset.id, set=dataset_set.name
+        )
 
-        if dataset.sets == dataset.selected_sets:
-            results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
-        else:
-            results = iter(
-                element
-                for selected_set in dataset.selected_sets
-                for element in self.api_client.paginate(
-                    "ListDatasetElements", id=dataset.id, set=selected_set
+        return map(lambda result: Element(**result["element"]), results)
+
+    def list_sets(self) -> Iterator[Set]:
+        """
+        List the sets to be processed, either from the CLI arguments or using the
+        [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.
+
+        :returns: An iterator of ``Set`` objects.
+        """
+        if not self.is_read_only:
+            yield from self.list_process_sets()
+
+        datasets: dict[uuid.UUID, Dataset] = {}
+        for dataset_id, set_name in self.args.set:
+            # Retrieving dataset information if not already cached
+            if dataset_id not in datasets:
+                datasets[dataset_id] = Dataset(
+                    **self.api_client.request("RetrieveDataset", id=dataset_id)
                 )
-            )
 
-        return map(
-            lambda result: (result["set"], Element(**result["element"])), results
-        )
+            yield Set(name=set_name, dataset=datasets[dataset_id])
 
     @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
@@ -100,7 +163,7 @@ class DatasetMixin:
            logger.warning("Cannot update dataset as this worker is in read-only mode")
            return
        try:
-            updated_dataset = self.request(
+            updated_dataset = self.api_client.request(
                "PartialUpdateDataset",
                id=dataset.id,
                body={"state": state.value},