arkindex-base-worker 0.3.7rc9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/METADATA +16 -20
  2. arkindex_base_worker-0.4.0.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/WHEEL +1 -1
  4. arkindex_worker/cache.py +1 -1
  5. arkindex_worker/image.py +120 -1
  6. arkindex_worker/models.py +6 -0
  7. arkindex_worker/utils.py +85 -4
  8. arkindex_worker/worker/__init__.py +68 -162
  9. arkindex_worker/worker/base.py +39 -34
  10. arkindex_worker/worker/classification.py +34 -18
  11. arkindex_worker/worker/corpus.py +86 -0
  12. arkindex_worker/worker/dataset.py +71 -1
  13. arkindex_worker/worker/element.py +352 -91
  14. arkindex_worker/worker/entity.py +11 -11
  15. arkindex_worker/worker/image.py +21 -0
  16. arkindex_worker/worker/metadata.py +19 -9
  17. arkindex_worker/worker/process.py +92 -0
  18. arkindex_worker/worker/task.py +5 -4
  19. arkindex_worker/worker/training.py +25 -10
  20. arkindex_worker/worker/transcription.py +89 -68
  21. arkindex_worker/worker/version.py +3 -1
  22. tests/__init__.py +8 -0
  23. tests/conftest.py +36 -52
  24. tests/test_base_worker.py +212 -12
  25. tests/test_dataset_worker.py +21 -45
  26. tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
  27. tests/test_elements_worker/test_cli.py +3 -11
  28. tests/test_elements_worker/test_corpus.py +168 -0
  29. tests/test_elements_worker/test_dataset.py +7 -12
  30. tests/test_elements_worker/test_element.py +427 -0
  31. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  32. tests/test_elements_worker/test_element_create_single.py +528 -0
  33. tests/test_elements_worker/test_element_list_children.py +969 -0
  34. tests/test_elements_worker/test_element_list_parents.py +530 -0
  35. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  36. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  37. tests/test_elements_worker/test_image.py +66 -0
  38. tests/test_elements_worker/test_metadata.py +230 -139
  39. tests/test_elements_worker/test_process.py +89 -0
  40. tests/test_elements_worker/test_task.py +8 -18
  41. tests/test_elements_worker/test_training.py +17 -8
  42. tests/test_elements_worker/test_transcription_create.py +873 -0
  43. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  44. tests/test_elements_worker/test_transcription_list.py +450 -0
  45. tests/test_elements_worker/test_version.py +60 -0
  46. tests/test_elements_worker/test_worker.py +563 -279
  47. tests/test_image.py +432 -209
  48. tests/test_merge.py +1 -2
  49. tests/test_utils.py +66 -3
  50. arkindex_base_worker-0.3.7rc9.dist-info/RECORD +0 -47
  51. tests/test_elements_worker/test_elements.py +0 -2713
  52. tests/test_elements_worker/test_transcriptions.py +0 -2119
  53. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/LICENSE +0 -0
  54. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/top_level.txt +0 -0
@@ -4,64 +4,47 @@ Base classes to implement Arkindex workers.
4
4
 
5
5
  import contextlib
6
6
  import json
7
- import os
8
7
  import sys
9
8
  import uuid
10
- from argparse import ArgumentTypeError
11
- from collections.abc import Iterable, Iterator
12
- from enum import Enum
9
+ from collections.abc import Iterable
10
+ from itertools import chain
13
11
  from pathlib import Path
14
12
 
15
- from apistar.exceptions import ErrorResponse
16
-
13
+ from arkindex.exceptions import ErrorResponse
17
14
  from arkindex_worker import logger
18
15
  from arkindex_worker.cache import CachedElement
19
16
  from arkindex_worker.models import Dataset, Element, Set
17
+ from arkindex_worker.utils import pluralize
20
18
  from arkindex_worker.worker.base import BaseWorker
21
19
  from arkindex_worker.worker.classification import ClassificationMixin
22
- from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
20
+ from arkindex_worker.worker.corpus import CorpusMixin
21
+ from arkindex_worker.worker.dataset import (
22
+ DatasetMixin,
23
+ DatasetState,
24
+ MissingDatasetArchive,
25
+ )
23
26
  from arkindex_worker.worker.element import ElementMixin
24
27
  from arkindex_worker.worker.entity import EntityMixin
28
+ from arkindex_worker.worker.image import ImageMixin
25
29
  from arkindex_worker.worker.metadata import MetaDataMixin, MetaType # noqa: F401
30
+ from arkindex_worker.worker.process import ActivityState, ProcessMixin, ProcessMode
26
31
  from arkindex_worker.worker.task import TaskMixin
27
32
  from arkindex_worker.worker.transcription import TranscriptionMixin
28
33
  from arkindex_worker.worker.version import WorkerVersionMixin
29
34
 
30
35
 
31
- class ActivityState(Enum):
32
- """
33
- Processing state of an element.
34
- """
35
-
36
- Queued = "queued"
37
- """
38
- The element has not yet been processed by a worker.
39
- """
40
-
41
- Started = "started"
42
- """
43
- The element is being processed by a worker.
44
- """
45
-
46
- Processed = "processed"
47
- """
48
- The element has been successfully processed by a worker.
49
- """
50
-
51
- Error = "error"
52
- """
53
- An error occurred while processing this element.
54
- """
55
-
56
-
57
36
  class ElementsWorker(
37
+ ElementMixin,
38
+ DatasetMixin,
58
39
  BaseWorker,
59
40
  ClassificationMixin,
60
- ElementMixin,
41
+ CorpusMixin,
61
42
  TranscriptionMixin,
62
43
  WorkerVersionMixin,
63
44
  EntityMixin,
64
45
  MetaDataMixin,
46
+ ImageMixin,
47
+ ProcessMixin,
65
48
  ):
66
49
  """
67
50
  Base class for ML workers that operate on Arkindex elements.
@@ -79,39 +62,41 @@ class ElementsWorker(
79
62
  """
80
63
  super().__init__(description, support_cache)
81
64
 
82
- # Add mandatory argument to process elements
83
- self.parser.add_argument(
84
- "--elements-list",
85
- help="JSON elements list to use",
86
- type=open,
87
- default=os.environ.get("TASK_ELEMENTS"),
88
- )
89
- self.parser.add_argument(
90
- "--element",
91
- type=uuid.UUID,
92
- nargs="+",
93
- help="One or more Arkindex element ID",
94
- )
95
-
96
65
  self.classes = {}
97
66
 
98
67
  self.entity_types = {}
99
68
  """Known and available entity types in processed corpus
100
69
  """
101
70
 
71
+ self.corpus_types = {}
72
+ """Known and available element types in processed corpus
73
+ """
74
+
102
75
  self._worker_version_cache = {}
103
76
 
104
- def list_elements(self) -> Iterable[CachedElement] | list[str]:
77
+ def get_elements(self) -> Iterable[CachedElement] | list[str] | list[Element]:
105
78
  """
106
79
  List the elements to be processed, either from the CLI arguments or
107
80
  the cache database when enabled.
108
81
 
109
82
  :return: An iterable of [CachedElement][arkindex_worker.cache.CachedElement] when cache support is enabled,
110
- and a list of strings representing element IDs otherwise.
83
+ or a list of strings representing element IDs otherwise.
111
84
  """
112
85
  assert not (
113
86
  self.args.elements_list and self.args.element
114
87
  ), "elements-list and element CLI args shouldn't be both set"
88
+
89
+ def invalid_element_id(value: str) -> bool:
90
+ """
91
+ Return whether the ID of an element is a valid UUID or not
92
+ """
93
+ try:
94
+ uuid.UUID(value)
95
+ except Exception:
96
+ return True
97
+
98
+ return False
99
+
115
100
  out = []
116
101
 
117
102
  # Load from the cache when available
@@ -121,15 +106,28 @@ class ElementsWorker(
121
106
  )
122
107
  if self.use_cache and cache_query.exists():
123
108
  return cache_query
124
- # Process elements from JSON file
125
109
  elif self.args.elements_list:
110
+ # Process elements from JSON file
126
111
  data = json.load(self.args.elements_list)
127
112
  assert isinstance(data, list), "Elements list must be a list"
128
113
  assert len(data), "No elements in elements list"
129
114
  out += list(filter(None, [element.get("id") for element in data]))
130
- # Add any extra element from CLI
131
115
  elif self.args.element:
116
+ # Add any extra element from CLI
132
117
  out += self.args.element
118
+ elif self.process_mode == ProcessMode.Dataset or self.args.set:
119
+ # Elements from datasets
120
+ return list(
121
+ chain.from_iterable(map(self.list_set_elements, self.list_sets()))
122
+ )
123
+ elif self.process_mode == ProcessMode.Export:
124
+ # For export mode processes, use list_process_elements and return element IDs
125
+ return {item["id"] for item in self.list_process_elements()}
126
+
127
+ invalid_element_ids = list(filter(invalid_element_id, out))
128
+ assert (
129
+ not invalid_element_ids
130
+ ), f"These element IDs are invalid: {', '.join(invalid_element_ids)}"
133
131
 
134
132
  return out
135
133
 
@@ -139,40 +137,22 @@ class ElementsWorker(
139
137
  Whether or not WorkerActivity support has been enabled on the DataImport
140
138
  used to run this worker.
141
139
  """
142
- if self.is_read_only:
140
+ if self.is_read_only or self.process_mode in [
141
+ ProcessMode.Dataset,
142
+ ProcessMode.Export,
143
+ ]:
144
+ # Worker activities are also disabled when running an ElementsWorker in a Dataset process
145
+ # and when running export processes.
143
146
  return False
144
147
  assert (
145
148
  self.process_information
146
149
  ), "Worker must be configured to access its process activity state"
147
150
  return self.process_information.get("activity_state") == "ready"
148
151
 
149
- def configure(self):
150
- """
151
- Setup the worker using CLI arguments and environment variables.
152
- """
153
- # CLI args are stored on the instance so that implementations can access them
154
- self.args = self.parser.parse_args()
155
-
156
- if self.is_read_only:
157
- super().configure_for_developers()
158
- else:
159
- super().configure()
160
- super().configure_cache()
161
-
162
- # Retrieve the model configuration
163
- if self.model_configuration:
164
- self.config.update(self.model_configuration)
165
- logger.info("Model version configuration retrieved")
166
-
167
- # Retrieve the user configuration
168
- if self.user_configuration:
169
- self.config.update(self.user_configuration)
170
- logger.info("User configuration retrieved")
171
-
172
152
  def run(self):
173
153
  """
174
154
  Implements an Arkindex worker that goes through each element returned by
175
- [list_elements][arkindex_worker.worker.ElementsWorker.list_elements].
155
+ [get_elements][arkindex_worker.worker.ElementsWorker.get_elements].
176
156
  It calls [process_element][arkindex_worker.worker.ElementsWorker.process_element],
177
157
  catching exceptions, and handles saving WorkerActivity updates when enabled.
178
158
  """
@@ -180,7 +160,7 @@ class ElementsWorker(
180
160
 
181
161
  # List all elements either from JSON file
182
162
  # or direct list of elements on CLI
183
- elements = self.list_elements()
163
+ elements = self.get_elements()
184
164
  if not elements:
185
165
  logger.warning("No elements to process, stopping.")
186
166
  sys.exit(1)
@@ -196,12 +176,14 @@ class ElementsWorker(
196
176
  for i, item in enumerate(elements, start=1):
197
177
  element = None
198
178
  try:
199
- if self.use_cache:
200
- # Just use the result of list_elements as the element
179
+ if isinstance(item, CachedElement | Element):
180
+ # Just use the result of get_elements as the element
201
181
  element = item
202
182
  else:
203
183
  # Load element using the Arkindex API
204
- element = Element(**self.request("RetrieveElement", id=item))
184
+ element = Element(
185
+ **self.api_client.request("RetrieveElement", id=item)
186
+ )
205
187
 
206
188
  logger.info(f"Processing {element} ({i}/{count})")
207
189
 
@@ -239,7 +221,7 @@ class ElementsWorker(
239
221
  with contextlib.suppress(Exception):
240
222
  self.update_activity(element.id, ActivityState.Error)
241
223
 
242
- message = f'Ran on {count} element{"s"[:count>1]}: {count - failed} completed, {failed} failed'
224
+ message = f'Ran on {count} {pluralize("element", count)}: {count - failed} completed, {failed} failed'
243
225
  if failed:
244
226
  logger.error(message)
245
227
  if failed >= count: # Everything failed!
@@ -280,7 +262,7 @@ class ElementsWorker(
280
262
  assert isinstance(state, ActivityState), "state should be an ActivityState"
281
263
 
282
264
  try:
283
- self.request(
265
+ self.api_client.request(
284
266
  "UpdateWorkerActivity",
285
267
  id=self.worker_run_id,
286
268
  body={
@@ -310,29 +292,7 @@ class ElementsWorker(
310
292
  return True
311
293
 
312
294
 
313
- def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
314
- values = value.split(":")
315
- if len(values) != 2:
316
- raise ArgumentTypeError(
317
- f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
318
- )
319
-
320
- dataset_id, set_name = values
321
- try:
322
- dataset_id = uuid.UUID(dataset_id)
323
- return (dataset_id, set_name)
324
- except (TypeError, ValueError) as e:
325
- raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
326
-
327
-
328
- class MissingDatasetArchive(Exception):
329
- """
330
- Exception raised when the compressed archive associated to
331
- a dataset isn't found in its task artifacts.
332
- """
333
-
334
-
335
- class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
295
+ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):
336
296
  """
337
297
  Base class for ML workers that operate on Arkindex dataset sets.
338
298
 
@@ -355,40 +315,6 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
355
315
  # Set as an instance variable as dataset workers might use it to easily extract its content
356
316
  self.downloaded_dataset_artifact: Path | None = None
357
317
 
358
- self.parser.add_argument(
359
- "--set",
360
- type=check_dataset_set,
361
- nargs="+",
362
- help="""
363
- One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
364
- (e.g.: "12341234-1234-1234-1234-123412341234:train")
365
- """,
366
- default=[],
367
- )
368
-
369
- def configure(self):
370
- """
371
- Setup the worker using CLI arguments and environment variables.
372
- """
373
- # CLI args are stored on the instance so that implementations can access them
374
- self.args = self.parser.parse_args()
375
-
376
- if self.is_read_only:
377
- super().configure_for_developers()
378
- else:
379
- super().configure()
380
- super().configure_cache()
381
-
382
- # Retrieve the model configuration
383
- if self.model_configuration:
384
- self.config.update(self.model_configuration)
385
- logger.info("Model version configuration retrieved")
386
-
387
- # Retrieve the user configuration
388
- if self.user_configuration:
389
- self.config.update(self.user_configuration)
390
- logger.info("User configuration retrieved")
391
-
392
318
  def cleanup_downloaded_artifact(self) -> None:
393
319
  """
394
320
  Cleanup the downloaded dataset artifact if any
@@ -436,30 +362,10 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
436
362
  :param set: The set to process.
437
363
  """
438
364
 
439
- def list_sets(self) -> Iterator[Set]:
440
- """
441
- List the sets to be processed, either from the CLI arguments or using the
442
- [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.
443
-
444
- :returns: An iterator of ``Set`` objects.
445
- """
446
- if not self.is_read_only:
447
- yield from self.list_process_sets()
448
-
449
- datasets: dict[uuid.UUID, Dataset] = {}
450
- for dataset_id, set_name in self.args.set:
451
- # Retrieving dataset information is not already cached
452
- if dataset_id not in datasets:
453
- datasets[dataset_id] = Dataset(
454
- **self.request("RetrieveDataset", id=dataset_id)
455
- )
456
-
457
- yield Set(name=set_name, dataset=datasets[dataset_id])
458
-
459
365
  def run(self):
460
366
  """
461
367
  Implements an Arkindex worker that goes through each dataset set returned by
462
- [list_sets][arkindex_worker.worker.DatasetWorker.list_sets].
368
+ [list_sets][arkindex_worker.worker.dataset.DatasetMixin.list_sets].
463
369
 
464
370
  It calls [process_set][arkindex_worker.worker.DatasetWorker.process_set],
465
371
  catching exceptions.
@@ -499,7 +405,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
499
405
  # Cleanup the latest downloaded dataset artifact
500
406
  self.cleanup_downloaded_artifact()
501
407
 
502
- message = f'Ran on {count} set{"s"[:count>1]}: {count - failed} completed, {failed} failed'
408
+ message = f'Ran on {count} {pluralize("set", count)}: {count - failed} completed, {failed} failed'
503
409
  if failed:
504
410
  logger.error(message)
505
411
  if failed >= count: # Everything failed!
@@ -12,15 +12,9 @@ from tempfile import mkdtemp
12
12
 
13
13
  import gnupg
14
14
  import yaml
15
- from apistar.exceptions import ErrorResponse
16
- from tenacity import (
17
- before_sleep_log,
18
- retry,
19
- retry_if_exception,
20
- stop_after_attempt,
21
- wait_exponential,
22
- )
23
15
 
16
+ from arkindex import options_from_env
17
+ from arkindex.exceptions import ErrorResponse
24
18
  from arkindex_worker import logger
25
19
  from arkindex_worker.cache import (
26
20
  check_version,
@@ -30,7 +24,8 @@ from arkindex_worker.cache import (
30
24
  merge_parents_cache,
31
25
  )
32
26
  from arkindex_worker.utils import close_delete_file, extract_tar_zst_archive
33
- from teklia_toolbox.requests import _get_arkindex_client, _is_500_error
27
+ from arkindex_worker.worker.process import ProcessMode
28
+ from teklia_toolbox.requests import get_arkindex_client
34
29
 
35
30
 
36
31
  class ExtrasDirNotFoundError(Exception):
@@ -162,6 +157,13 @@ class BaseWorker:
162
157
  raise Exception("Missing ARKINDEX_CORPUS_ID environment variable")
163
158
  return self._corpus_id
164
159
 
160
+ @property
161
+ def process_mode(self) -> ProcessMode | None:
162
+ """Mode of the process being run. Returns None when read-only."""
163
+ if self.is_read_only:
164
+ return
165
+ return ProcessMode(self.process_information["mode"])
166
+
165
167
  @property
166
168
  def is_read_only(self) -> bool:
167
169
  """
@@ -185,7 +187,7 @@ class BaseWorker:
185
187
  Create an ArkindexClient to make API requests towards Arkindex instances.
186
188
  """
187
189
  # Build Arkindex API client from environment variables
188
- self.api_client = _get_arkindex_client()
190
+ self.api_client = get_arkindex_client(**options_from_env())
189
191
  logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
190
192
 
191
193
  def configure_for_developers(self):
@@ -225,7 +227,7 @@ class BaseWorker:
225
227
  # Load all required secrets
226
228
  self.secrets = {name: self.load_secret(Path(name)) for name in required_secrets}
227
229
 
228
- def configure(self):
230
+ def configure_worker_run(self):
229
231
  """
230
232
  Setup the necessary configuration needed using CLI args and environment variables.
231
233
  This is the method called when running a worker on Arkindex.
@@ -237,7 +239,7 @@ class BaseWorker:
237
239
  logger.debug("Debug output enabled")
238
240
 
239
241
  # Load worker run information
240
- worker_run = self.request("RetrieveWorkerRun", id=self.worker_run_id)
242
+ worker_run = self.api_client.request("RetrieveWorkerRun", id=self.worker_run_id)
241
243
 
242
244
  # Load process information
243
245
  self.process_information = worker_run["process"]
@@ -296,7 +298,7 @@ class BaseWorker:
296
298
  if self.support_cache and self.args.database is not None:
297
299
  self.use_cache = True
298
300
  elif self.support_cache and self.task_id:
299
- task = self.request("RetrieveTaskFromAgent", id=self.task_id)
301
+ task = self.api_client.request("RetrieveTask", id=self.task_id)
300
302
  self.task_parents = task["parents"]
301
303
  paths = self.find_parents_file_paths(Path("db.sqlite"))
302
304
  self.use_cache = len(paths) > 0
@@ -326,6 +328,29 @@ class BaseWorker:
326
328
  else:
327
329
  logger.debug("Cache is disabled")
328
330
 
331
+ def configure(self):
332
+ """
333
+ Setup the worker using CLI arguments and environment variables.
334
+ """
335
+ # CLI args are stored on the instance so that implementations can access them
336
+ self.args = self.parser.parse_args()
337
+
338
+ if self.is_read_only:
339
+ self.configure_for_developers()
340
+ else:
341
+ self.configure_worker_run()
342
+ self.configure_cache()
343
+
344
+ # Retrieve the model configuration
345
+ if self.model_configuration:
346
+ self.config.update(self.model_configuration)
347
+ logger.info("Model version configuration retrieved")
348
+
349
+ # Retrieve the user configuration
350
+ if self.user_configuration:
351
+ self.config.update(self.user_configuration)
352
+ logger.info("User configuration retrieved")
353
+
329
354
  def load_secret(self, name: Path):
330
355
  """
331
356
  Load a Ponos secret by name.
@@ -337,7 +362,7 @@ class BaseWorker:
337
362
 
338
363
  # Load from the backend
339
364
  try:
340
- resp = self.request("RetrieveSecret", name=str(name))
365
+ resp = self.api_client.request("RetrieveSecret", name=str(name))
341
366
  secret = resp["content"]
342
367
  logging.info(f"Loaded API secret {name}")
343
368
  except ErrorResponse as e:
@@ -477,26 +502,6 @@ class BaseWorker:
477
502
  # Clean up
478
503
  shutil.rmtree(base_extracted_path)
479
504
 
480
- @retry(
481
- retry=retry_if_exception(_is_500_error),
482
- wait=wait_exponential(multiplier=2, min=3),
483
- reraise=True,
484
- stop=stop_after_attempt(5),
485
- before_sleep=before_sleep_log(logger, logging.INFO),
486
- )
487
- def request(self, *args, **kwargs):
488
- """
489
- Wrapper around the ``ArkindexClient.request`` method.
490
-
491
- The API call will be retried up to 5 times in case of HTTP 5xx errors,
492
- with an exponential sleep time of 3, 4, 8 and 16 seconds between calls.
493
- If the 5th call still causes an HTTP 5xx error, the exception is re-raised
494
- and the caller should catch it.
495
-
496
- Log messages are displayed when an HTTP 5xx error occurs, before waiting for the next call.
497
- """
498
- return self.api_client.request(*args, **kwargs)
499
-
500
505
  def add_arguments(self):
501
506
  """Override this method to add ``argparse`` arguments to this worker"""
502
507
 
@@ -2,12 +2,18 @@
2
2
  ElementsWorker methods for classifications and ML classes.
3
3
  """
4
4
 
5
- from apistar.exceptions import ErrorResponse
6
5
  from peewee import IntegrityError
7
6
 
7
+ from arkindex.exceptions import ErrorResponse
8
8
  from arkindex_worker import logger
9
9
  from arkindex_worker.cache import CachedClassification, CachedElement
10
10
  from arkindex_worker.models import Element
11
+ from arkindex_worker.utils import (
12
+ DEFAULT_BATCH_SIZE,
13
+ batch_publication,
14
+ make_batches,
15
+ pluralize,
16
+ )
11
17
 
12
18
 
13
19
  class ClassificationMixin:
@@ -21,7 +27,7 @@ class ClassificationMixin:
21
27
  )
22
28
  self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
23
29
  logger.info(
24
- f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
30
+ f'Loaded {len(self.classes)} ML {pluralize("class", len(self.classes))} in corpus ({self.corpus_id})'
25
31
  )
26
32
 
27
33
  def get_ml_class_id(self, ml_class: str) -> str:
@@ -39,7 +45,7 @@ class ClassificationMixin:
39
45
  if ml_class_id is None:
40
46
  logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
41
47
  try:
42
- response = self.request(
48
+ response = self.api_client.request(
43
49
  "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
44
50
  )
45
51
  ml_class_id = self.classes[ml_class] = response["id"]
@@ -119,7 +125,7 @@ class ClassificationMixin:
119
125
  )
120
126
  return
121
127
  try:
122
- created = self.request(
128
+ created = self.api_client.request(
123
129
  "CreateClassification",
124
130
  body={
125
131
  "element": str(element.id),
@@ -167,10 +173,12 @@ class ClassificationMixin:
167
173
 
168
174
  return created
169
175
 
176
+ @batch_publication
170
177
  def create_classifications(
171
178
  self,
172
179
  element: Element | CachedElement,
173
180
  classifications: list[dict[str, str | float | bool]],
181
+ batch_size: int = DEFAULT_BATCH_SIZE,
174
182
  ) -> list[dict[str, str | float | bool]]:
175
183
  """
176
184
  Create multiple classifications at once on the given element through the API.
@@ -185,6 +193,8 @@ class ClassificationMixin:
185
193
  high_confidence (bool)
186
194
  Optional. Whether or not the classification is of high confidence.
187
195
 
196
+ :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
197
+
188
198
  :returns: List of created classifications, as returned in the ``classifications`` field by
189
199
  the ``CreateClassifications`` API endpoint.
190
200
  """
@@ -220,20 +230,26 @@ class ClassificationMixin:
220
230
  )
221
231
  return
222
232
 
223
- created_cls = self.request(
224
- "CreateClassifications",
225
- body={
226
- "parent": str(element.id),
227
- "worker_run_id": self.worker_run_id,
228
- "classifications": [
229
- {
230
- **classification,
231
- "ml_class": self.get_ml_class_id(classification["ml_class"]),
232
- }
233
- for classification in classifications
234
- ],
235
- },
236
- )["classifications"]
233
+ created_cls = [
234
+ created_cl
235
+ for batch in make_batches(classifications, "classification", batch_size)
236
+ for created_cl in self.api_client.request(
237
+ "CreateClassifications",
238
+ body={
239
+ "parent": str(element.id),
240
+ "worker_run_id": self.worker_run_id,
241
+ "classifications": [
242
+ {
243
+ **classification,
244
+ "ml_class": self.get_ml_class_id(
245
+ classification["ml_class"]
246
+ ),
247
+ }
248
+ for classification in batch
249
+ ],
250
+ },
251
+ )["classifications"]
252
+ ]
237
253
 
238
254
  for created_cl in created_cls:
239
255
  created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])