arkindex-base-worker 0.3.7rc4__py3-none-any.whl → 0.5.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
  2. arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +1 -1
  6. arkindex_worker/image.py +167 -2
  7. arkindex_worker/models.py +18 -0
  8. arkindex_worker/utils.py +98 -4
  9. arkindex_worker/worker/__init__.py +117 -218
  10. arkindex_worker/worker/base.py +39 -46
  11. arkindex_worker/worker/classification.py +45 -29
  12. arkindex_worker/worker/corpus.py +86 -0
  13. arkindex_worker/worker/dataset.py +89 -26
  14. arkindex_worker/worker/element.py +352 -91
  15. arkindex_worker/worker/entity.py +13 -11
  16. arkindex_worker/worker/image.py +21 -0
  17. arkindex_worker/worker/metadata.py +26 -16
  18. arkindex_worker/worker/process.py +92 -0
  19. arkindex_worker/worker/task.py +5 -4
  20. arkindex_worker/worker/training.py +25 -10
  21. arkindex_worker/worker/transcription.py +89 -68
  22. arkindex_worker/worker/version.py +3 -1
  23. hooks/pre_gen_project.py +3 -0
  24. tests/__init__.py +8 -0
  25. tests/conftest.py +47 -58
  26. tests/test_base_worker.py +212 -12
  27. tests/test_dataset_worker.py +294 -437
  28. tests/test_elements_worker/{test_classifications.py → test_classification.py} +313 -200
  29. tests/test_elements_worker/test_cli.py +3 -11
  30. tests/test_elements_worker/test_corpus.py +168 -0
  31. tests/test_elements_worker/test_dataset.py +106 -157
  32. tests/test_elements_worker/test_element.py +427 -0
  33. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  34. tests/test_elements_worker/test_element_create_single.py +528 -0
  35. tests/test_elements_worker/test_element_list_children.py +969 -0
  36. tests/test_elements_worker/test_element_list_parents.py +530 -0
  37. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  38. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  39. tests/test_elements_worker/test_image.py +66 -0
  40. tests/test_elements_worker/test_metadata.py +252 -161
  41. tests/test_elements_worker/test_process.py +89 -0
  42. tests/test_elements_worker/test_task.py +8 -18
  43. tests/test_elements_worker/test_training.py +17 -8
  44. tests/test_elements_worker/test_transcription_create.py +873 -0
  45. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  46. tests/test_elements_worker/test_transcription_list.py +450 -0
  47. tests/test_elements_worker/test_version.py +60 -0
  48. tests/test_elements_worker/test_worker.py +578 -293
  49. tests/test_image.py +542 -209
  50. tests/test_merge.py +1 -2
  51. tests/test_utils.py +89 -4
  52. worker-demo/tests/__init__.py +0 -0
  53. worker-demo/tests/conftest.py +32 -0
  54. worker-demo/tests/test_worker.py +12 -0
  55. worker-demo/worker_demo/__init__.py +6 -0
  56. worker-demo/worker_demo/worker.py +19 -0
  57. arkindex_base_worker-0.3.7rc4.dist-info/RECORD +0 -41
  58. tests/test_elements_worker/test_elements.py +0 -2713
  59. tests/test_elements_worker/test_transcriptions.py +0 -2119
  60. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
arkindex_worker/worker/classification.py
@@ -2,14 +2,18 @@
 ElementsWorker methods for classifications and ML classes.
 """
 
-from uuid import UUID
-
-from apistar.exceptions import ErrorResponse
 from peewee import IntegrityError
 
+from arkindex.exceptions import ErrorResponse
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedClassification, CachedElement
 from arkindex_worker.models import Element
+from arkindex_worker.utils import (
+    DEFAULT_BATCH_SIZE,
+    batch_publication,
+    make_batches,
+    pluralize,
+)
 
 
 class ClassificationMixin:
@@ -23,7 +27,7 @@ class ClassificationMixin:
         )
         self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
         logger.info(
-            f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
+            f'Loaded {len(self.classes)} ML {pluralize("class", len(self.classes))} in corpus ({self.corpus_id})'
         )
 
     def get_ml_class_id(self, ml_class: str) -> str:
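The `pluralize` helper is new in `arkindex_worker.utils` (whose diff is not shown here); inferring from this call site alone, it inflects a noun by count, along these lines:

    pluralize("class", 1)  # -> "class"
    pluralize("class", 3)  # -> "classes"  (assumed behaviour, not confirmed by this diff)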
@@ -41,7 +45,7 @@ class ClassificationMixin:
         if ml_class_id is None:
             logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
             try:
-                response = self.request(
+                response = self.api_client.request(
                     "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
                 )
                 ml_class_id = self.classes[ml_class] = response["id"]
@@ -121,7 +125,7 @@ class ClassificationMixin:
             )
             return
         try:
-            created = self.request(
+            created = self.api_client.request(
                 "CreateClassification",
                 body={
                     "element": str(element.id),
@@ -169,19 +173,27 @@ class ClassificationMixin:
 
         return created
 
+    @batch_publication
     def create_classifications(
         self,
        element: Element | CachedElement,
         classifications: list[dict[str, str | float | bool]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str | float | bool]]:
         """
         Create multiple classifications at once on the given element through the API.
 
         :param element: The element to create classifications on.
-        :param classifications: The classifications to create, a list of dicts. Each of them contains
-        a **ml_class_id** (str), the ID of the MLClass for this classification;
-        a **confidence** (float), the confidence score, between 0 and 1;
-        a **high_confidence** (bool), the high confidence state of the classification.
+        :param classifications: A list of dicts representing a classification each, with the following keys:
+
+            ml_class (str)
+                Required. Name of the MLClass to use.
+            confidence (float)
+                Required. Confidence score for the classification. Must be between 0 and 1.
+            high_confidence (bool)
+                Optional. Whether or not the classification is of high confidence.
+
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
 
         :returns: List of created classifications, as returned in the ``classifications`` field by
             the ``CreateClassifications`` API endpoint.
@@ -194,18 +206,10 @@ class ClassificationMixin:
         ), "classifications shouldn't be null and should be of type list"
 
         for index, classification in enumerate(classifications):
-            ml_class_id = classification.get("ml_class_id")
+            ml_class = classification.get("ml_class")
             assert (
-                ml_class_id and isinstance(ml_class_id, str)
-            ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
-
-            # Make sure it's a valid UUID
-            try:
-                UUID(ml_class_id)
-            except ValueError as e:
-                raise ValueError(
-                    f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
-                ) from e
+                ml_class and isinstance(ml_class, str)
+            ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"
 
             confidence = classification.get("confidence")
             assert (
@@ -226,14 +230,26 @@ class ClassificationMixin:
             )
             return
 
-        created_cls = self.request(
-            "CreateClassifications",
-            body={
-                "parent": str(element.id),
-                "worker_run_id": self.worker_run_id,
-                "classifications": classifications,
-            },
-        )["classifications"]
+        created_cls = [
+            created_cl
+            for batch in make_batches(classifications, "classification", batch_size)
+            for created_cl in self.api_client.request(
+                "CreateClassifications",
+                body={
+                    "parent": str(element.id),
+                    "worker_run_id": self.worker_run_id,
+                    "classifications": [
+                        {
+                            **classification,
+                            "ml_class": self.get_ml_class_id(
+                                classification["ml_class"]
+                            ),
+                        }
+                        for classification in batch
+                    ],
+                },
+            )["classifications"]
+        ]
 
         for created_cl in created_cls:
             created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
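Taken together, these hunks swap the UUID-based `ml_class_id` inputs for MLClass names and split the publication into batches. A minimal usage sketch against the new signature; the worker class and values here are hypothetical, only `create_classifications` and its parameters come from the diff above:

    from arkindex_worker.models import Element
    from arkindex_worker.worker import ElementsWorker


    class DemoWorker(ElementsWorker):
        def process_element(self, element: Element):
            # MLClasses are now referenced by name; get_ml_class_id() resolves
            # each name to an ID, creating the class on the corpus if needed.
            self.create_classifications(
                element,
                classifications=[
                    {"ml_class": "handwritten", "confidence": 0.87, "high_confidence": True},
                    {"ml_class": "typed", "confidence": 0.13},
                ],
                # The publication is split into chunks of at most batch_size
                # classifications, one CreateClassifications call per chunk.
                batch_size=50,
            )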
arkindex_worker/worker/corpus.py (new file)
@@ -0,0 +1,86 @@
+"""
+BaseWorker methods for corpora.
+"""
+
+from enum import Enum
+from operator import itemgetter
+from tempfile import _TemporaryFileWrapper
+from uuid import UUID
+
+from arkindex_worker import logger
+
+
+class CorpusExportState(Enum):
+    """
+    State of a corpus export.
+    """
+
+    Created = "created"
+    """
+    The corpus export is created, awaiting its processing.
+    """
+
+    Running = "running"
+    """
+    The corpus export is being built.
+    """
+
+    Failed = "failed"
+    """
+    The corpus export failed.
+    """
+
+    Done = "done"
+    """
+    The corpus export ended in success.
+    """
+
+
+class CorpusMixin:
+    def download_export(self, export_id: str) -> _TemporaryFileWrapper:
+        """
+        Download an export.
+
+        :param export_id: UUID of the export to download
+        :returns: The downloaded export stored in a temporary file.
+        """
+        try:
+            UUID(export_id)
+        except ValueError as e:
+            raise ValueError("export_id is not a valid uuid.") from e
+
+        logger.info(f"Downloading export ({export_id})...")
+        export: _TemporaryFileWrapper = self.api_client.request(
+            "DownloadExport", id=export_id
+        )
+        logger.info(f"Downloaded export ({export_id}) @ `{export.name}`")
+        return export
+
+    def download_latest_export(self) -> _TemporaryFileWrapper:
+        """
+        Download the latest export in `done` state of the current corpus.
+
+        :returns: The downloaded export stored in a temporary file.
+        """
+        # List all exports on the corpus
+        exports = self.api_client.paginate("ListExports", id=self.corpus_id)
+
+        # Find the latest that is in "done" state
+        exports: list[dict] = sorted(
+            list(
+                filter(
+                    lambda export: export["state"] == CorpusExportState.Done.value,
+                    exports,
+                )
+            ),
+            key=itemgetter("updated"),
+            reverse=True,
+        )
+        assert (
+            len(exports) > 0
+        ), f'No available exports found for the corpus ({self.corpus_id}) with state "{CorpusExportState.Done.value.capitalize()}".'
+
+        # Download latest export
+        export_id: str = exports[0]["id"]
+
+        return self.download_export(export_id)
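A short sketch of how a worker could consume these new corpus helpers, assuming `CorpusMixin` is mixed into the worker class and that the downloaded export is the usual Arkindex SQLite database (neither assumption is confirmed by this diff):

    import sqlite3

    from arkindex_worker.worker import BaseWorker


    class ExportReader(BaseWorker):
        def read_export(self):  # hypothetical helper, for illustration only
            # Fetch the most recent export in "done" state for self.corpus_id;
            # the returned wrapper points at a temporary file deleted on close.
            export_file = self.download_latest_export()
            # Assuming the export is an SQLite database, open it directly.
            return sqlite3.connect(export_file.name)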
arkindex_worker/worker/dataset.py
@@ -2,12 +2,14 @@
 BaseWorker methods for datasets.
 """
 
+import uuid
+from argparse import ArgumentTypeError
 from collections.abc import Iterator
 from enum import Enum
 
 from arkindex_worker import logger
 from arkindex_worker.cache import unsupported_cache
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.models import Dataset, Element, Set
 
 
 class DatasetState(Enum):
@@ -36,49 +38,110 @@ class DatasetState(Enum):
     """
 
 
+class MissingDatasetArchive(Exception):
+    """
+    Exception raised when the compressed archive associated to
+    a dataset isn't found in its task artifacts.
+    """
+
+
+def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
+    """The `--set` argument should have the following format:
+    <dataset_id>:<set_name>
+
+    Args:
+        value (str): Provided argument.
+
+    Raises:
+        ArgumentTypeError: When the value is invalid.
+
+    Returns:
+        tuple[uuid.UUID, str]: The ID of the dataset parsed as UUID and the name of the set.
+    """
+    values = value.split(":")
+    if len(values) != 2:
+        raise ArgumentTypeError(
+            f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
+        )
+
+    dataset_id, set_name = values
+    try:
+        dataset_id = uuid.UUID(dataset_id)
+        return (dataset_id, set_name)
+    except (TypeError, ValueError) as e:
+        raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
+
+
 class DatasetMixin:
-    def list_process_datasets(self) -> Iterator[Dataset]:
+    def add_arguments(self) -> None:
+        """Define specific ``argparse`` arguments for the worker using this mixin"""
+        self.parser.add_argument(
+            "--set",
+            type=check_dataset_set,
+            nargs="+",
+            help="""
+            One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
+            (e.g.: "12341234-1234-1234-1234-123412341234:train")
+            """,
+            default=[],
+        )
+        super().add_arguments()
+
+    def list_process_sets(self) -> Iterator[Set]:
         """
-        List datasets associated to the worker's process. This helper is not available in developer mode.
+        List dataset sets associated to the worker's process. This helper is not available in developer mode.
 
-        :returns: An iterator of ``Dataset`` objects built from the ``ListProcessDatasets`` API endpoint.
+        :returns: An iterator of ``Set`` objects built from the ``ListProcessSets`` API endpoint.
         """
         assert not self.is_read_only, "This helper is not available in read-only mode."
 
         results = self.api_client.paginate(
-            "ListProcessDatasets", id=self.process_information["id"]
+            "ListProcessSets", id=self.process_information["id"]
         )
 
         return map(
-            lambda result: Dataset(**result["dataset"], selected_sets=result["sets"]),
+            lambda result: Set(
+                name=result["set_name"], dataset=Dataset(**result["dataset"])
+            ),
             results,
         )
 
-    def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
+    def list_set_elements(self, dataset_set: Set) -> Iterator[Element]:
         """
-        List elements in a dataset.
+        List elements in a dataset set.
 
-        :param dataset: Dataset to find elements in.
-        :returns: An iterator of tuples built from the ``ListDatasetElements`` API endpoint.
+        :param dataset_set: Set to find elements in.
+        :returns: An iterator of Element built from the ``ListDatasetElements`` API endpoint.
         """
-        assert dataset and isinstance(
-            dataset, Dataset
-        ), "dataset shouldn't be null and should be a Dataset"
+        assert dataset_set and isinstance(
+            dataset_set, Set
+        ), "dataset_set shouldn't be null and should be a Set"
+
+        results = self.api_client.paginate(
+            "ListDatasetElements", id=dataset_set.dataset.id, set=dataset_set.name
+        )
 
-        if dataset.sets == dataset.selected_sets:
-            results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
-        else:
-            results = iter(
-                element
-                for selected_set in dataset.selected_sets
-                for element in self.api_client.paginate(
-                    "ListDatasetElements", id=dataset.id, set=selected_set
+        return map(lambda result: Element(**result["element"]), results)
+
+    def list_sets(self) -> Iterator[Set]:
+        """
+        List the sets to be processed, either from the CLI arguments or using the
+        [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.
+
+        :returns: An iterator of ``Set`` objects.
+        """
+        if not self.is_read_only:
+            yield from self.list_process_sets()
+
+        datasets: dict[uuid.UUID, Dataset] = {}
+        for dataset_id, set_name in self.args.set:
+            # Retrieving dataset information if not already cached
+            if dataset_id not in datasets:
+                datasets[dataset_id] = Dataset(
+                    **self.api_client.request("RetrieveDataset", id=dataset_id)
                 )
-            )
 
-        return map(
-            lambda result: (result["set"], Element(**result["element"])), results
-        )
+            yield Set(name=set_name, dataset=datasets[dataset_id])
 
     @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
@@ -100,7 +163,7 @@ class DatasetMixin:
             logger.warning("Cannot update dataset as this worker is in read-only mode")
             return
 
-        updated_dataset = self.request(
+        updated_dataset = self.api_client.request(
             "PartialUpdateDataset",
             id=dataset.id,
             body={"state": state.value},
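With datasets now handled per set, a worker built on `DatasetMixin` can combine `list_sets` and `list_set_elements`; a hedged sketch, in which the worker class and helper name are illustrative and only the mixin methods come from the diff above:

    from arkindex_worker.worker import DatasetWorker


    class SetLister(DatasetWorker):
        def dump_sets(self):  # hypothetical helper, for illustration only
            # list_sets() merges sets from the process (when not read-only)
            # with any --set CLI arguments parsed by check_dataset_set().
            for dataset_set in self.list_sets():
                for element in self.list_set_elements(dataset_set):
                    print(dataset_set.name, element.id)

Since `--set` is declared with `nargs="+"`, several sets can be passed in one invocation, e.g. `--set 12341234-1234-1234-1234-123412341234:train 12341234-1234-1234-1234-123412341234:validation`.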