arkindex-base-worker 0.4.0a2__py3-none-any.whl → 0.4.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
{arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: arkindex-base-worker
- Version: 0.4.0a2
+ Version: 0.4.0b2
  Summary: Base Worker to easily build Arkindex ML workflows
  Author-email: Teklia <contact@teklia.com>
  Maintainer-email: Teklia <contact@teklia.com>
@@ -41,17 +41,17 @@ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: peewee ~=3.17
- Requires-Dist: Pillow ==10.3.0
+ Requires-Dist: Pillow ==10.4.0
  Requires-Dist: python-gnupg ==0.5.2
- Requires-Dist: shapely ==2.0.3
+ Requires-Dist: shapely ==2.0.5
  Requires-Dist: teklia-toolbox ==0.1.5
  Requires-Dist: zstandard ==0.22.0
  Provides-Extra: docs
- Requires-Dist: black ==24.4.0 ; extra == 'docs'
- Requires-Dist: mkdocs-material ==9.5.17 ; extra == 'docs'
- Requires-Dist: mkdocstrings-python ==1.9.2 ; extra == 'docs'
+ Requires-Dist: black ==24.4.2 ; extra == 'docs'
+ Requires-Dist: mkdocs-material ==9.5.31 ; extra == 'docs'
+ Requires-Dist: mkdocstrings-python ==1.10.7 ; extra == 'docs'
  Provides-Extra: tests
- Requires-Dist: pytest ==8.1.1 ; extra == 'tests'
+ Requires-Dist: pytest ==8.3.2 ; extra == 'tests'
  Requires-Dist: pytest-mock ==3.14.0 ; extra == 'tests'
  Requires-Dist: pytest-responses ==0.5.1 ; extra == 'tests'

{arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/RECORD RENAMED
@@ -0,0 +1,51 @@
+ arkindex_worker/__init__.py,sha256=OlgCtTC9MaWeejviY0a3iQpALcRQGMVArFVVYwTF6I8,162
+ arkindex_worker/cache.py,sha256=FTlB0coXofn5zTNRTcVIvh709mcw4a1bPGqkwWjKs3w,11248
+ arkindex_worker/image.py,sha256=8Y0PYMbTEsFUv8lCNLBu7UaDy6um5YfHCefyXL2jpnE,14347
+ arkindex_worker/models.py,sha256=bPQzGZNs5a6z6DEcygsa8T33VOqPlMUbwKzHqlKzwbw,9923
+ arkindex_worker/utils.py,sha256=zrtMChXx_HGu4UkqXZBKPg3ys0UBFmQaizoX1riM3D4,9824
+ arkindex_worker/worker/__init__.py,sha256=w1VlDzERabXIp625kkHnojyu5ctCM11WLw4ARh1ja3k,19818
+ arkindex_worker/worker/base.py,sha256=JStHpwSP3bis9LLvV2C2n6GTWtLUVIDA9JPgPJEt17o,18717
+ arkindex_worker/worker/classification.py,sha256=ECm1cnQPOj_9m-CoO0e182ElSySAUOoyddHrORbShhc,10951
+ arkindex_worker/worker/corpus.py,sha256=s9bCxOszJMwRq1WWAmKjWq888mjDfbaJ18Wo7h-rNOw,1827
+ arkindex_worker/worker/dataset.py,sha256=UXElhhARca9m7Himp-yxD5dAqWbdxDKWOUJUGgeCZXI,2934
+ arkindex_worker/worker/element.py,sha256=5knEFHc0LDRRHI8IbJbiiQsOAoW7qYPf9lcVXsFlUEQ,34798
+ arkindex_worker/worker/entity.py,sha256=qGjQvOVXfP84rER0Dkui6q-rb9nTWerHVG0Z5voB8pU,15229
+ arkindex_worker/worker/image.py,sha256=t_Az6IGnj0EZyvcA4XxfPikOUjn_pztgsyxTkFZhaXU,621
+ arkindex_worker/worker/metadata.py,sha256=VRajtd2kaBvar9GercX4knvR6l1WFYjoCdJWU9ccKgk,7291
+ arkindex_worker/worker/task.py,sha256=1O9zrWXxe3na3TOcoHX5Pxn1875v7EU08BSsCPnb62g,1519
+ arkindex_worker/worker/training.py,sha256=qnBFEk11JOWWPLTbjF-lZ9iFBdTPpQzZAzQ9a03J1j4,10874
+ arkindex_worker/worker/transcription.py,sha256=8ho-8zmF9LgP86oS59ZZLv5I7tfnZ1yNO2A3pY_9GQ8,21353
+ arkindex_worker/worker/version.py,sha256=JIT7OI3Mo7RPkNrjOB9hfqrsG-FYygz_zi4l8PbkuAo,1960
+ hooks/pre_gen_project.py,sha256=xQJERv3vv9VzIqcBHI281eeWLWREXUF4mMw7PvJHHXM,269
+ tests/__init__.py,sha256=6aeTMHf4q_dKY4jIZWg1KT70VKaLvVlzCxh-Uu_cWiQ,241
+ tests/conftest.py,sha256=-ZQTV4rg7TgW84-5Ioqndqv8byNILfDOpyUt8wecEiI,21967
+ tests/test_base_worker.py,sha256=LdFV0LFdNU2IOyEKlX59MB1kuyxHCuhy4Tm7eE_iPiU,24281
+ tests/test_cache.py,sha256=ii0gyr0DrG7ChEs7pmT8hMdSguAOAcCze4bRMiFQxuk,10640
+ tests/test_dataset_worker.py,sha256=d9HG36qnO5HXu9vQ0UTBvdTSRR21FVq1FNoXM-vZbPk,22105
+ tests/test_element.py,sha256=2G9M15TLxQRmvrWM9Kw2ucnElh4kSv_oF_5FYwwAxTY,13181
+ tests/test_image.py,sha256=Fs9vKYgQ7mEFylbzI4YIO_JyOLeAcs-WxUXpzewxCd8,16188
+ tests/test_merge.py,sha256=FMdpsm_ncHNmIvOrJ1vcwlyn8o9-SPcpFTcbAsXwK-w,8320
+ tests/test_utils.py,sha256=zbJC24NyTc3slz3Ed3gJDswjRChjkR5oHEgDoQMOBiE,2588
+ tests/test_elements_worker/__init__.py,sha256=Fh4nkbbyJSMv_VtjQxnWrOqTnxXaaWI8S9WU0VrzCHs,179
+ tests/test_elements_worker/test_classifications.py,sha256=fXZ8cSzIWwZ6LHsY7tKsy9-Pp9fKyKUStIXS4ViBcek,27779
+ tests/test_elements_worker/test_cli.py,sha256=a23i1pUDbXi23MUtbWwGEcLLrmc_YlrbDgOG3h66wLM,2620
+ tests/test_elements_worker/test_corpus.py,sha256=c_LUHvkJIYgk_wXF06VQPNOoWfiZ06XpjOXrJ7MRiBc,4479
+ tests/test_elements_worker/test_dataset.py,sha256=lSXqubhg1EEq2Y2goE8Y2RYaqIpM9Iejq6fGNW2BczU,11411
+ tests/test_elements_worker/test_elements.py,sha256=dBhjQ8XZNIE7bjx5AaGaclPLZr1Ur_-tQ-ebS3S_Zn0,89142
+ tests/test_elements_worker/test_entities.py,sha256=oav2dtvWWavQe1l3Drbxw1Ta2ocUJEVxJfDQ_r6-rYQ,36181
+ tests/test_elements_worker/test_image.py,sha256=_E3UGdDOwTo1MW5KMS81PrdeSPBPWinWYoQPNy2F9Ro,2077
+ tests/test_elements_worker/test_metadata.py,sha256=cm2NNaXxBYmYMkPexSPVTAqb2skDTB4mliwQCLz8Y98,22293
+ tests/test_elements_worker/test_task.py,sha256=7Sr3fbjdgWUXJUhJEiC9CwnbhQIQX3rCInmHMIrmA38,5573
+ tests/test_elements_worker/test_training.py,sha256=Qxi9EzGr_uKcn2Fh5aE6jNrq1K8QKLiOiSew4upASPs,8721
+ tests/test_elements_worker/test_transcriptions.py,sha256=FNY6E26iTKqe7LP9LO72By4oV4g9hBIZYTU9BAc_w7I,77060
+ tests/test_elements_worker/test_worker.py,sha256=AwdP8uSXNQ_SJavXxJV2s3_J3OiCafShVjMV1dgt4xo,17162
+ worker-demo/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ worker-demo/tests/conftest.py,sha256=XzNMNeg6pmABUAH8jN6eZTlZSFGLYjS3-DTXjiRN6Yc,1002
+ worker-demo/tests/test_worker.py,sha256=3DLd4NRK4bfyatG5P_PK4k9P9tJHx9XQq5_ryFEEFVg,304
+ worker-demo/worker_demo/__init__.py,sha256=2BPomV8ZMNf3YXJgloatKeHQCE6QOkwmsHGkO6MkQuM,125
+ worker-demo/worker_demo/worker.py,sha256=Rt-DjWa5iBP08k58NDZMfeyPuFbtNcbX6nc5jFX7GNo,440
+ arkindex_base_worker-0.4.0b2.dist-info/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
+ arkindex_base_worker-0.4.0b2.dist-info/METADATA,sha256=wvefQTllKMq-jkbjsG1TMyuvF06h-XjBLgc79_j8MTU,3270
+ arkindex_base_worker-0.4.0b2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+ arkindex_base_worker-0.4.0b2.dist-info/top_level.txt,sha256=58NuslgxQC2vT4DiqZEgO4JqJRrYa2yeNI9QvkbfGQU,40
+ arkindex_base_worker-0.4.0b2.dist-info/RECORD,,
{arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.43.0)
+ Generator: setuptools (72.1.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

arkindex_worker/image.py CHANGED
@@ -21,6 +21,7 @@ from tenacity import (
  )

  from arkindex_worker import logger
+ from arkindex_worker.utils import pluralize
  from teklia_toolbox.requests import should_verify_cert

  # Avoid circular imports error when type checking
@@ -164,7 +165,7 @@ def polygon_bounding_box(polygon: list[list[int | float]]) -> BoundingBox:
  def _retry_log(retry_state, *args, **kwargs):
      logger.warning(
          f"Request to {retry_state.args[0]} failed ({repr(retry_state.outcome.exception())}), "
-         f"retrying in {retry_state.idle_for} seconds"
+         f'retrying in {retry_state.idle_for} {pluralize("second", retry_state.idle_for)}'
      )

arkindex_worker/utils.py CHANGED
@@ -1,14 +1,36 @@
  import hashlib
+ import inspect
  import logging
  import os
  import tarfile
  import tempfile
+ from collections.abc import Callable, Generator
+ from itertools import islice
  from pathlib import Path
+ from typing import Any

  import zstandard as zstd

  logger = logging.getLogger(__name__)

+
+ def pluralize(singular: str, count: int) -> str:
+     """Pluralize a noun, if necessary, using simplified rules of English pluralization and a list of exceptions.
+
+     :param str singular: A singular noun describing an object
+     :param int count: The object count, to determine whether to pluralize or not
+     :return str: The noun in its singular or plural form
+     """
+     if count == 1:
+         return singular
+
+     some_exceptions = {"entity": "entities", "metadata": "metadata", "class": "classes"}
+     if singular in some_exceptions:
+         return some_exceptions[singular]
+
+     return singular + "s"
+
+
  MANUAL_SOURCE = "manual"

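The `pluralize` helper added here backs all the reworded log messages in this release. Its behaviour follows directly from the definition above; a quick illustrative check:

```python
from arkindex_worker.utils import pluralize

assert pluralize("element", 1) == "element"    # a count of 1 keeps the singular
assert pluralize("element", 4) == "elements"   # default rule appends "s"
assert pluralize("class", 2) == "classes"      # listed exception
assert pluralize("entity", 2) == "entities"    # listed exception
assert pluralize("metadata", 3) == "metadata"  # invariant noun
```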
@@ -196,3 +218,57 @@ def create_tar_zst_archive(
      close_delete_file(tar_fd, tar_archive)

      return zst_fd, zst_archive, zst_hash, tar_hash
+
+
+ DEFAULT_BATCH_SIZE = 50
+ """Batch size used for bulk publication to Arkindex"""
+
+
+ def batch_publication(func: Callable) -> Callable:
+     """
+     Decorator for functions that should raise an error when the value passed through the ``batch_size`` parameter is **not** a strictly positive integer.
+
+     :param func: The function to wrap with the ``batch_size`` check
+     :return: The function passing the ``batch_size`` check
+     """
+     signature = inspect.signature(func)
+
+     def wrapper(self, *args, **kwargs):
+         bound_func = signature.bind(self, *args, **kwargs)
+         bound_func.apply_defaults()
+         batch_size = bound_func.arguments.get("batch_size")
+         assert (
+             batch_size and isinstance(batch_size, int) and batch_size > 0
+         ), "batch_size shouldn't be null and should be a strictly positive integer"
+
+         return func(self, *args, **kwargs)
+
+     return wrapper
+
+
+ def make_batches(
+     objects: list, singular_name: str, batch_size: int
+ ) -> Generator[list[Any]]:
+     """Split an object list in successive batches of maximum size ``batch_size``.
+
+     :param objects: The object list to divide in batches of ``batch_size`` size
+     :param singular_name: The singular form of the noun associated with the object list
+     :param batch_size: The maximum size of each batch to split the object list
+     :return: A generator of successive batches containing ``batch_size`` items from ``objects``
+     """
+     count = len(objects)
+     logger.info(
+         f"Creating batches of size {batch_size} to process {count} {pluralize(singular_name, count)}"
+     )
+
+     index = 1
+     iterator = iter(objects)
+     while batch := list(islice(iterator, batch_size)):
+         count = len(batch)
+         logger.info(
+             f"Processing batch {index} containing {count} {pluralize(singular_name, count)}..."
+         )
+
+         yield batch
+
+         index += 1
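Together these two helpers define the bulk-publication pattern used by the mixins below: the decorator validates the `batch_size` keyword against the wrapped method's signature before the body runs, and the generator slices the payload. A minimal sketch of the pattern, using a hypothetical `DemoPublisher` class that is not part of the package:

```python
from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches


class DemoPublisher:  # illustrative only
    @batch_publication
    def publish(self, items: list[dict], batch_size: int = DEFAULT_BATCH_SIZE):
        # batch_publication has already asserted that batch_size is a
        # strictly positive integer by the time this body runs.
        return list(make_batches(items, "item", batch_size))


batches = DemoPublisher().publish([{"n": i} for i in range(120)], batch_size=50)
assert [len(b) for b in batches] == [50, 50, 20]  # 120 items -> 3 batches
```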
arkindex_worker/worker/__init__.py CHANGED
@@ -17,6 +17,7 @@ from apistar.exceptions import ErrorResponse
  from arkindex_worker import logger
  from arkindex_worker.cache import CachedElement
  from arkindex_worker.models import Dataset, Element, Set
+ from arkindex_worker.utils import pluralize
  from arkindex_worker.worker.base import BaseWorker
  from arkindex_worker.worker.classification import ClassificationMixin
  from arkindex_worker.worker.corpus import CorpusMixin
@@ -83,7 +84,20 @@ class ElementsWorker(
          """
          super().__init__(description, support_cache)

-         # Add mandatory argument to process elements
+         self.classes = {}
+
+         self.entity_types = {}
+         """Known and available entity types in processed corpus
+         """
+
+         self.corpus_types = {}
+         """Known and available element types in processed corpus
+         """
+
+         self._worker_version_cache = {}
+
+     def add_arguments(self):
+         """Define specific ``argparse`` arguments for this worker"""
          self.parser.add_argument(
              "--elements-list",
              help="JSON elements list to use",
@@ -97,14 +111,6 @@ class ElementsWorker(
              help="One or more Arkindex element ID",
          )

-         self.classes = {}
-
-         self.entity_types = {}
-         """Known and available entity types in processed corpus
-         """
-
-         self._worker_version_cache = {}
-
      def list_elements(self) -> Iterable[CachedElement] | list[str]:
          """
          List the elements to be processed, either from the CLI arguments or
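The `--elements-list` and `--element` arguments now live in an `add_arguments` override rather than in `__init__`, matching the hook that `BaseWorker` already exposes ("Override this method to add ``argparse`` arguments to this worker", further down). A subclass adding its own flags could presumably chain the override as sketched below; the `--confidence-threshold` flag is illustrative, not part of the package:

```python
from arkindex_worker.worker import ElementsWorker


class MyWorker(ElementsWorker):  # illustrative only
    def add_arguments(self):
        # Keep --elements-list and --element from ElementsWorker,
        # then register worker-specific flags.
        super().add_arguments()
        self.parser.add_argument(
            "--confidence-threshold",
            type=float,
            default=0.5,
            help="Minimum confidence needed to publish a prediction",
        )
```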
@@ -222,7 +228,9 @@
                  element = item
              else:
                  # Load element using the Arkindex API
-                 element = Element(**self.request("RetrieveElement", id=item))
+                 element = Element(
+                     **self.api_client.request("RetrieveElement", id=item)
+                 )

              logger.info(f"Processing {element} ({i}/{count})")

@@ -260,7 +268,7 @@
                  with contextlib.suppress(Exception):
                      self.update_activity(element.id, ActivityState.Error)

-         message = f'Ran on {count} element{"s"[:count>1]}: {count - failed} completed, {failed} failed'
+         message = f'Ran on {count} {pluralize("element", count)}: {count - failed} completed, {failed} failed'
          if failed:
              logger.error(message)
              if failed >= count:  # Everything failed!
@@ -301,7 +309,7 @@
          assert isinstance(state, ActivityState), "state should be an ActivityState"

          try:
-             self.request(
+             self.api_client.request(
                  "UpdateWorkerActivity",
                  id=self.worker_run_id,
                  body={
@@ -376,6 +384,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
          # Set as an instance variable as dataset workers might use it to easily extract its content
          self.downloaded_dataset_artifact: Path | None = None

+     def add_arguments(self):
+         """Define specific ``argparse`` arguments for this worker"""
          self.parser.add_argument(
              "--set",
              type=check_dataset_set,
@@ -472,7 +482,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
              # Retrieving dataset information is not already cached
              if dataset_id not in datasets:
                  datasets[dataset_id] = Dataset(
-                     **self.request("RetrieveDataset", id=dataset_id)
+                     **self.api_client.request("RetrieveDataset", id=dataset_id)
                  )

              yield Set(name=set_name, dataset=datasets[dataset_id])
@@ -520,7 +530,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
              # Cleanup the latest downloaded dataset artifact
              self.cleanup_downloaded_artifact()

-         message = f'Ran on {count} set{"s"[:count>1]}: {count - failed} completed, {failed} failed'
+         message = f'Ran on {count} {pluralize("set", count)}: {count - failed} completed, {failed} failed'
          if failed:
              logger.error(message)
              if failed >= count:  # Everything failed!
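The run summaries above drop the terse `"s"[:count>1]` slicing trick in favour of the new helper. For regular nouns the two spellings agree; the helper is simply easier to read and also covers the irregular nouns used in other messages. A quick check, for illustration:

```python
from arkindex_worker.utils import pluralize

for count in (1, 2, 10):
    # "s"[:count > 1] slices to "" when count == 1 and to "s" otherwise
    assert f'set{"s"[:count > 1]}' == pluralize("set", count)
```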
arkindex_worker/worker/base.py CHANGED
@@ -231,7 +231,7 @@ class BaseWorker:
          logger.debug("Debug output enabled")

          # Load worker run information
-         worker_run = self.request("RetrieveWorkerRun", id=self.worker_run_id)
+         worker_run = self.api_client.request("RetrieveWorkerRun", id=self.worker_run_id)

          # Load process information
          self.process_information = worker_run["process"]
@@ -290,7 +290,7 @@
          if self.support_cache and self.args.database is not None:
              self.use_cache = True
          elif self.support_cache and self.task_id:
-             task = self.request("RetrieveTaskFromAgent", id=self.task_id)
+             task = self.api_client.request("RetrieveTask", id=self.task_id)
              self.task_parents = task["parents"]
              paths = self.find_parents_file_paths(Path("db.sqlite"))
              self.use_cache = len(paths) > 0
@@ -331,7 +331,7 @@

          # Load from the backend
          try:
-             resp = self.request("RetrieveSecret", name=str(name))
+             resp = self.api_client.request("RetrieveSecret", name=str(name))
              secret = resp["content"]
              logging.info(f"Loaded API secret {name}")
          except ErrorResponse as e:
@@ -471,12 +471,6 @@
          # Clean up
          shutil.rmtree(base_extracted_path)

-     def request(self, *args, **kwargs):
-         """
-         Wrapper around the ``ArkindexClient.request`` method.
-         """
-         return self.api_client.request(*args, **kwargs)
-
      def add_arguments(self):
          """Override this method to add ``argparse`` arguments to this worker"""

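The removed `request` method was a one-line passthrough, which is why every other file in this diff swaps `self.request(...)` for `self.api_client.request(...)` mechanically. Downstream workers still calling the wrapper would presumably need the same one-line change; `MigratedWorker.load_run` below is an illustrative sketch, not package code:

```python
from arkindex_worker.worker.base import BaseWorker


class MigratedWorker(BaseWorker):  # illustrative only
    def load_run(self):
        # 0.4.0a2 spelling, removed in 0.4.0b2:
        #     return self.request("RetrieveWorkerRun", id=self.worker_run_id)
        # 0.4.0b2 spelling, calling the ArkindexClient directly:
        return self.api_client.request("RetrieveWorkerRun", id=self.worker_run_id)
```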
arkindex_worker/worker/classification.py CHANGED
@@ -8,6 +8,12 @@ from peewee import IntegrityError
  from arkindex_worker import logger
  from arkindex_worker.cache import CachedClassification, CachedElement
  from arkindex_worker.models import Element
+ from arkindex_worker.utils import (
+     DEFAULT_BATCH_SIZE,
+     batch_publication,
+     make_batches,
+     pluralize,
+ )


  class ClassificationMixin:
@@ -21,7 +27,7 @@
          )
          self.classes = {ml_class["name"]: ml_class["id"] for ml_class in corpus_classes}
          logger.info(
-             f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
+             f'Loaded {len(self.classes)} ML {pluralize("class", len(self.classes))} in corpus ({self.corpus_id})'
          )

      def get_ml_class_id(self, ml_class: str) -> str:
@@ -39,7 +45,7 @@
          if ml_class_id is None:
              logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
              try:
-                 response = self.request(
+                 response = self.api_client.request(
                      "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
                  )
                  ml_class_id = self.classes[ml_class] = response["id"]
@@ -119,7 +125,7 @@
              )
              return
          try:
-             created = self.request(
+             created = self.api_client.request(
                  "CreateClassification",
                  body={
                      "element": str(element.id),
@@ -167,10 +173,12 @@

          return created

+     @batch_publication
      def create_classifications(
          self,
          element: Element | CachedElement,
          classifications: list[dict[str, str | float | bool]],
+         batch_size: int = DEFAULT_BATCH_SIZE,
      ) -> list[dict[str, str | float | bool]]:
          """
          Create multiple classifications at once on the given element through the API.
@@ -185,6 +193,8 @@
              high_confidence (bool)
                  Optional. Whether or not the classification is of high confidence.

+         :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
          :returns: List of created classifications, as returned in the ``classifications`` field by
              the ``CreateClassifications`` API endpoint.
          """
@@ -220,20 +230,26 @@
              )
              return

-         created_cls = self.request(
-             "CreateClassifications",
-             body={
-                 "parent": str(element.id),
-                 "worker_run_id": self.worker_run_id,
-                 "classifications": [
-                     {
-                         **classification,
-                         "ml_class": self.get_ml_class_id(classification["ml_class"]),
-                     }
-                     for classification in classifications
-                 ],
-             },
-         )["classifications"]
+         created_cls = [
+             created_cl
+             for batch in make_batches(classifications, "classification", batch_size)
+             for created_cl in self.api_client.request(
+                 "CreateClassifications",
+                 body={
+                     "parent": str(element.id),
+                     "worker_run_id": self.worker_run_id,
+                     "classifications": [
+                         {
+                             **classification,
+                             "ml_class": self.get_ml_class_id(
+                                 classification["ml_class"]
+                             ),
+                         }
+                         for classification in batch
+                     ],
+                 },
+             )["classifications"]
+         ]

          for created_cl in created_cls:
              created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
arkindex_worker/worker/corpus.py CHANGED
@@ -63,7 +63,9 @@ class CorpusMixin:
          # Download latest export
          export_id: str = exports[0]["id"]
          logger.info(f"Downloading export ({export_id})...")
-         export: _TemporaryFileWrapper = self.request("DownloadExport", id=export_id)
+         export: _TemporaryFileWrapper = self.api_client.request(
+             "DownloadExport", id=export_id
+         )
          logger.info(f"Downloaded export ({export_id}) @ `{export.name}`")

          return export
arkindex_worker/worker/dataset.py CHANGED
@@ -93,7 +93,7 @@ class DatasetMixin:
              logger.warning("Cannot update dataset as this worker is in read-only mode")
              return

-         updated_dataset = self.request(
+         updated_dataset = self.api_client.request(
              "PartialUpdateDataset",
              id=dataset.id,
              body={"state": state.value},
arkindex_worker/worker/element.py CHANGED
@@ -12,6 +12,12 @@ from peewee import IntegrityError
  from arkindex_worker import logger
  from arkindex_worker.cache import CachedElement, CachedImage, unsupported_cache
  from arkindex_worker.models import Element
+ from arkindex_worker.utils import (
+     DEFAULT_BATCH_SIZE,
+     batch_publication,
+     make_batches,
+     pluralize,
+ )


  class ElementType(NamedTuple):
@@ -31,6 +37,21 @@ class MissingTypeError(Exception):


  class ElementMixin:
+     def list_corpus_types(self):
+         """
+         Loads available element types in corpus.
+         """
+         self.corpus_types = {
+             element_type["slug"]: element_type
+             for element_type in self.api_client.request(
+                 "RetrieveCorpus", id=self.corpus_id
+             )["types"]
+         }
+         count = len(self.corpus_types)
+         logger.info(
+             f'Loaded {count} element {pluralize("type", count)} in corpus ({self.corpus_id}).'
+         )
+
      @unsupported_cache
      def create_required_types(self, element_types: list[ElementType]):
          """Creates given element types in the corpus.
@@ -38,7 +59,7 @@
          :param element_types: The missing element types to create.
          """
          for element_type in element_types:
-             self.request(
+             self.api_client.request(
                  "CreateElementType",
                  body={
                      "slug": element_type.slug,
@@ -66,10 +87,10 @@
              isinstance(slug, str) for slug in type_slugs
          ), "Element type slugs must be strings."

-         corpus = self.request("RetrieveCorpus", id=self.corpus_id)
-         available_slugs = {element_type["slug"] for element_type in corpus["types"]}
-         missing_slugs = set(type_slugs) - available_slugs
+         if not self.corpus_types:
+             self.list_corpus_types()

+         missing_slugs = set(type_slugs) - set(self.corpus_types)
          if missing_slugs:
              if create_missing:
                  self.create_required_types(
@@ -79,7 +100,7 @@
                  )
              else:
                  raise MissingTypeError(
-                     f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus["name"]} corpus ({corpus["id"]}).'
+                     f'Element {pluralize("type", len(missing_slugs))} {", ".join(sorted(missing_slugs))} were not found in corpus ({self.corpus_id}).'
                  )

          return True
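`check_required_types` now reads from the `corpus_types` cache that `list_corpus_types` populates, so a worker checking types repeatedly only calls `RetrieveCorpus` once per run. A sketch of the behaviour, assuming an illustrative worker and that `check_required_types` accepts slugs as positional arguments, as the assertions above suggest:

```python
from arkindex_worker.worker import ElementsWorker


class TypedWorker(ElementsWorker):  # illustrative only
    def process_element(self, element):
        # First call populates self.corpus_types through RetrieveCorpus;
        # subsequent calls reuse the cached slugs without another request.
        self.check_required_types("paragraph", "text_line", create_missing=True)
```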
@@ -145,7 +166,7 @@
              logger.warning("Cannot create element as this worker is in read-only mode")
              return

-         sub_element = self.request(
+         sub_element = self.api_client.request(
              "CreateElement",
              body={
                  "type": type,
@@ -161,10 +182,12 @@

          return sub_element["id"] if slim_output else sub_element

+     @batch_publication
      def create_elements(
          self,
          parent: Element | CachedElement,
          elements: list[dict[str, str | list[list[int | float]] | float | None]],
+         batch_size: int = DEFAULT_BATCH_SIZE,
      ) -> list[dict[str, str]]:
          """
          Create child elements on the given element in a single API request.
@@ -185,6 +208,8 @@
              confidence (float or None)
                  Optional confidence score, between 0.0 and 1.0.

+         :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
          :return: List of dicts, with each dict having a single key, ``id``, holding the UUID of each created element.
          """
          if isinstance(parent, Element):
@@ -243,14 +268,18 @@
              logger.warning("Cannot create elements as this worker is in read-only mode")
              return

-         created_ids = self.request(
-             "CreateElements",
-             id=parent.id,
-             body={
-                 "worker_run_id": self.worker_run_id,
-                 "elements": elements,
-             },
-         )
+         created_ids = [
+             created_id
+             for batch in make_batches(elements, "element", batch_size)
+             for created_id in self.api_client.request(
+                 "CreateElements",
+                 id=parent.id,
+                 body={
+                     "worker_run_id": self.worker_run_id,
+                     "elements": batch,
+                 },
+             )
+         ]

          if self.use_cache:
              # Create the image as needed and handle both an Element and a CachedElement
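`create_elements` receives the same batching treatment: children are published through one `CreateElements` call per batch and the returned IDs are flattened back into a single list, so callers see no interface change beyond the new `batch_size` parameter. A sketch with an illustrative worker and made-up geometry:

```python
from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


class SegmenterWorker(ElementsWorker):  # illustrative only
    def process_element(self, element: Element):
        lines = [
            {
                "name": str(i),
                "type": "text_line",
                "polygon": [[0, 10 * i], [100, 10 * i], [100, 10 * i + 10], [0, 10 * i + 10]],
            }
            for i in range(120)
        ]
        # Published as three CreateElements requests (50, 50 and 20 children);
        # the result is still one flat list of {"id": ...} dicts.
        ids = self.create_elements(element, lines, batch_size=50)
```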
@@ -311,7 +340,7 @@
              logger.warning("Cannot link elements as this worker is in read-only mode")
              return

-         return self.request(
+         return self.api_client.request(
              "CreateElementParent",
              parent=parent.id,
              child=child.id,
@@ -383,7 +412,7 @@
              logger.warning("Cannot update element as this worker is in read-only mode")
              return

-         updated_element = self.request(
+         updated_element = self.api_client.request(
              "PartialUpdateElement",
              id=element.id,
              body=kwargs,