arkindex-base-worker 0.3.6rc5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/METADATA +14 -13
  2. arkindex_base_worker-0.3.7.dist-info/RECORD +47 -0
  3. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +14 -0
  6. arkindex_worker/image.py +29 -19
  7. arkindex_worker/models.py +14 -2
  8. arkindex_worker/utils.py +17 -3
  9. arkindex_worker/worker/__init__.py +122 -125
  10. arkindex_worker/worker/base.py +24 -24
  11. arkindex_worker/worker/classification.py +18 -25
  12. arkindex_worker/worker/dataset.py +24 -18
  13. arkindex_worker/worker/element.py +45 -6
  14. arkindex_worker/worker/entity.py +35 -4
  15. arkindex_worker/worker/metadata.py +21 -11
  16. arkindex_worker/worker/training.py +13 -0
  17. arkindex_worker/worker/transcription.py +45 -5
  18. arkindex_worker/worker/version.py +22 -0
  19. hooks/pre_gen_project.py +3 -0
  20. tests/conftest.py +14 -6
  21. tests/test_base_worker.py +0 -6
  22. tests/test_dataset_worker.py +291 -409
  23. tests/test_elements_worker/test_classifications.py +365 -539
  24. tests/test_elements_worker/test_cli.py +1 -1
  25. tests/test_elements_worker/test_dataset.py +97 -116
  26. tests/test_elements_worker/test_elements.py +227 -61
  27. tests/test_elements_worker/test_entities.py +22 -2
  28. tests/test_elements_worker/test_metadata.py +53 -27
  29. tests/test_elements_worker/test_training.py +35 -0
  30. tests/test_elements_worker/test_transcriptions.py +149 -16
  31. tests/test_elements_worker/test_worker.py +19 -6
  32. tests/test_image.py +37 -0
  33. tests/test_utils.py +23 -1
  34. worker-demo/tests/__init__.py +0 -0
  35. worker-demo/tests/conftest.py +32 -0
  36. worker-demo/tests/test_worker.py +12 -0
  37. worker-demo/worker_demo/__init__.py +6 -0
  38. worker-demo/worker_demo/worker.py +19 -0
  39. arkindex_base_worker-0.3.6rc5.dist-info/RECORD +0 -41
  40. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/LICENSE +0 -0
{arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: arkindex-base-worker
-Version: 0.3.6rc5
+Version: 0.3.7
 Summary: Base Worker to easily build Arkindex ML workflows
 Author-email: Teklia <contact@teklia.com>
 Maintainer-email: Teklia <contact@teklia.com>
@@ -41,22 +41,23 @@ Classifier: Topic :: Text Processing :: Linguistic
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: arkindex-client ==1.0.14
-Requires-Dist: peewee ==3.17.0
-Requires-Dist: Pillow ==10.1.0
-Requires-Dist: pymdown-extensions ==10.3.1
-Requires-Dist: python-gnupg ==0.5.1
-Requires-Dist: shapely ==2.0.2
-Requires-Dist: tenacity ==8.2.3
+Requires-Dist: peewee ==3.17.1
+Requires-Dist: Pillow ==10.3.0
+Requires-Dist: pymdown-extensions ==10.7.1
+Requires-Dist: python-gnupg ==0.5.2
+Requires-Dist: shapely ==2.0.3
+Requires-Dist: teklia-toolbox ==0.1.4
 Requires-Dist: zstandard ==0.22.0
 Provides-Extra: docs
-Requires-Dist: black ==23.11.0 ; extra == 'docs'
+Requires-Dist: black ==24.4.0 ; extra == 'docs'
 Requires-Dist: doc8 ==1.1.1 ; extra == 'docs'
-Requires-Dist: mkdocs ==1.5.3 ; extra == 'docs'
-Requires-Dist: mkdocs-material ==9.4.8 ; extra == 'docs'
-Requires-Dist: mkdocstrings ==0.23.0 ; extra == 'docs'
-Requires-Dist: mkdocstrings-python ==1.7.3 ; extra == 'docs'
+Requires-Dist: mkdocs-material ==9.5.17 ; extra == 'docs'
+Requires-Dist: mkdocstrings-python ==1.9.2 ; extra == 'docs'
 Requires-Dist: recommonmark ==0.7.1 ; extra == 'docs'
+Provides-Extra: tests
+Requires-Dist: pytest ==8.1.1 ; extra == 'tests'
+Requires-Dist: pytest-mock ==3.14.0 ; extra == 'tests'
+Requires-Dist: pytest-responses ==0.5.1 ; extra == 'tests'
 
 # Arkindex base Worker
 
arkindex_base_worker-0.3.7.dist-info/RECORD ADDED
@@ -0,0 +1,47 @@
+arkindex_worker/__init__.py,sha256=OlgCtTC9MaWeejviY0a3iQpALcRQGMVArFVVYwTF6I8,162
+arkindex_worker/cache.py,sha256=FTlB0coXofn5zTNRTcVIvh709mcw4a1bPGqkwWjKs3w,11248
+arkindex_worker/image.py,sha256=5ymIGaTm2D7Sp2YYQkbuheuGnx5VJo0_AzYAEIvNGhs,14267
+arkindex_worker/models.py,sha256=xSvOadkNg3rgccic1xLgonzP28ugzmcGw0IUqXn51Cc,9844
+arkindex_worker/utils.py,sha256=0Mu7Fa8DVcHn19pg-FIXqMDpfgzQkb7QR9IAlAi-x_k,7243
+arkindex_worker/worker/__init__.py,sha256=U-_zOrQ09xmpBF9SmrTVj_UwnsCjFueV5G2hJAFEwv0,18806
+arkindex_worker/worker/base.py,sha256=qtkCGfpGn7SWsQZRJ5cpW0gQ4tV_cyR_AHbuHZr53z4,19585
+arkindex_worker/worker/classification.py,sha256=JVz-6YEeuavOy7zGfQi4nE_wpj9hwMUZDXTem-hXQY8,10328
+arkindex_worker/worker/dataset.py,sha256=roX2IMMNA-icteTtRADiFSZiZSRPClqS62ZPJm9s2JI,2923
+arkindex_worker/worker/element.py,sha256=AWK3YJSHWy3j4ajntJloi_2X4zxsgXZ6c6dzphgq3OI,33848
+arkindex_worker/worker/entity.py,sha256=suhycfikC9oTPEWmX48_cnvFEw-Wu5zBA8n_00K4KUk,14714
+arkindex_worker/worker/metadata.py,sha256=Bouuc_JaXogKykVXOTKDVP3tX--OUQeHoazxIGrGrJI,6702
+arkindex_worker/worker/task.py,sha256=cz3wJNPgogZv1lm_3lm7WScitQtYQtL6H6I7Xokq208,1475
+arkindex_worker/worker/training.py,sha256=YYnLNi4lsB0fEDj8Xh73z2Amt1LIfPdpuGzagOEtgDE,10648
+arkindex_worker/worker/transcription.py,sha256=6R7ofcGnNqX4rjT0kRKIE-G9FHq2TJ1tfztNM5sTqYE,20464
+arkindex_worker/worker/version.py,sha256=cs2pdlDxpKRO2Oldvcu54w-D_DQhf1cdeEt4tKX_QYs,1927
+hooks/pre_gen_project.py,sha256=xQJERv3vv9VzIqcBHI281eeWLWREXUF4mMw7PvJHHXM,269
+tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tests/conftest.py,sha256=Oi5SJic4TNwDj8Pm0WHgg657yB7_JKxbLC0HYPI3RUc,22134
+tests/test_base_worker.py,sha256=Uq6_MpLW23gmKFXkU-SyDUaA_4dlViLBGG4e3gpBBz0,24512
+tests/test_cache.py,sha256=ii0gyr0DrG7ChEs7pmT8hMdSguAOAcCze4bRMiFQxuk,10640
+tests/test_dataset_worker.py,sha256=1joFRFmkL6XfPL9y1NYB_5QO-5FF56rwigAHrqtJMMA,23848
+tests/test_element.py,sha256=2G9M15TLxQRmvrWM9Kw2ucnElh4kSv_oF_5FYwwAxTY,13181
+tests/test_image.py,sha256=FZv8njLxh45sVgmY71UFHt0lv1cHr0cK4rrtPhQleX8,16262
+tests/test_merge.py,sha256=Q4zCbtZbe0wBfqE56gvAD06c6pDuhqnjKaioFqIgAQw,8331
+tests/test_utils.py,sha256=vpeHMeL7bJQonv5ZEbJmlJikqVKn5VWlVEbvmYFzDYA,1650
+tests/test_elements_worker/__init__.py,sha256=Fh4nkbbyJSMv_VtjQxnWrOqTnxXaaWI8S9WU0VrzCHs,179
+tests/test_elements_worker/test_classifications.py,sha256=vU6al1THtDSmERyVscMXaqiRPwTllcpRUHyeyBQ8M9U,26417
+tests/test_elements_worker/test_cli.py,sha256=BsFTswLti63WAZ2pf6ipiZKWJJyCQuSfuKnSlESuK8g,2878
+tests/test_elements_worker/test_dataset.py,sha256=hityecntzrldkuBHBWApYDkXSzSySdG3AZXJlM_sCOM,11777
+tests/test_elements_worker/test_elements.py,sha256=6XKtgXSVQJnTSgTHWwEVsAtIwLBapjYjUYPUdjxcHsY,84971
+tests/test_elements_worker/test_entities.py,sha256=yi1mXzvKvNwUNMzo0UZ56YOIJstYHcLyeepPJ8f10MQ,34557
+tests/test_elements_worker/test_metadata.py,sha256=YMYmkUSEp4WKNBm3QLcrg4yn6qVTWQ_aZzSu9Xygr80,18756
+tests/test_elements_worker/test_task.py,sha256=FCpxE9UpouKXgjGvWgNHEai_Hiy2d1YmqRG-_v2s27s,6312
+tests/test_elements_worker/test_training.py,sha256=3PGH6dAc2eSBD7w6ivrt1yAh6sCoici4nuIS9zdw6S8,9476
+tests/test_elements_worker/test_transcriptions.py,sha256=WVJG26sZyY66fu-Eka9A1_WWIeNI2scogjypzURnp8A,73468
+tests/test_elements_worker/test_worker.py,sha256=7-jGJVT3yMGpIyN96Uafz5eIUrO4ieNLgw0k1D8BhGc,17163
+worker-demo/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+worker-demo/tests/conftest.py,sha256=XzNMNeg6pmABUAH8jN6eZTlZSFGLYjS3-DTXjiRN6Yc,1002
+worker-demo/tests/test_worker.py,sha256=3DLd4NRK4bfyatG5P_PK4k9P9tJHx9XQq5_ryFEEFVg,304
+worker-demo/worker_demo/__init__.py,sha256=2BPomV8ZMNf3YXJgloatKeHQCE6QOkwmsHGkO6MkQuM,125
+worker-demo/worker_demo/worker.py,sha256=Rt-DjWa5iBP08k58NDZMfeyPuFbtNcbX6nc5jFX7GNo,440
+arkindex_base_worker-0.3.7.dist-info/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
+arkindex_base_worker-0.3.7.dist-info/METADATA,sha256=AH2_i5Ne_vAPAYdQhlFhJQogSzDuLFtxueFsDMpkbMw,3458
+arkindex_base_worker-0.3.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+arkindex_base_worker-0.3.7.dist-info/top_level.txt,sha256=58NuslgxQC2vT4DiqZEgO4JqJRrYa2yeNI9QvkbfGQU,40
+arkindex_base_worker-0.3.7.dist-info/RECORD,,
{arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.42.0)
+Generator: bdist_wheel (0.43.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
{arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.dist-info}/top_level.txt CHANGED
@@ -1,2 +1,4 @@
 arkindex_worker
+hooks
 tests
+worker-demo
arkindex_worker/cache.py CHANGED
@@ -374,3 +374,17 @@ def merge_parents_cache(paths: list, current_database: Path):
     for statement in statements:
         cursor.execute(statement)
     connection.commit()
+
+
+def unsupported_cache(func):
+    def wrapper(self, *args, **kwargs):
+        results = func(self, *args, **kwargs)
+
+        if not (self.is_read_only or self.use_cache):
+            logger.warning(
+                f"This API helper `{func.__name__}` did not update the cache database"
+            )
+
+        return results
+
+    return wrapper
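
The new `unsupported_cache` decorator only wraps an API helper and logs a warning that the helper does not write to the local cache database. A minimal sketch of how a worker mixin might use it (the `CustomMixin` class and the `DoSomething` endpoint below are hypothetical, not part of this release):

    from arkindex_worker.cache import unsupported_cache

    class CustomMixin:
        @unsupported_cache
        def do_something(self, element_id: str) -> dict:
            # The API call goes through unchanged; the decorator only emits a
            # warning that this helper did not update the cache database.
            return self.request("DoSomething", id=element_id)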
arkindex_worker/image.py CHANGED
@@ -1,6 +1,7 @@
 """
 Helper methods to download and open IIIF images, and manage polygons.
 """
+
 import re
 from collections import namedtuple
 from io import BytesIO
@@ -20,6 +21,7 @@ from tenacity import (
 )
 
 from arkindex_worker import logger
+from teklia_toolbox.requests import should_verify_cert
 
 # Avoid circular imports error when type checking
 if TYPE_CHECKING:
@@ -114,32 +116,38 @@ def download_image(url: str) -> Image:
             )
         else:
             raise e
-    except requests.exceptions.SSLError:
-        logger.warning(
-            "An SSLError occurred during image download, retrying with a weaker and unsafe SSL configuration"
-        )
-
-        # Saving current ciphers
-        previous_ciphers = requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS
-
-        # Downgrading ciphers to download the image
-        requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL:@SECLEVEL=1"
-        resp = _retried_request(url)
-
-        # Restoring previous ciphers
-        requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = previous_ciphers
 
     # Preprocess the image and prepare it for classification
     image = Image.open(BytesIO(resp.content))
     logger.info(
-        "Downloaded image {} - size={}x{} in {}".format(
-            url, image.size[0], image.size[1], resp.elapsed
-        )
+        f"Downloaded image {url} - size={image.size[0]}x{image.size[1]} in {resp.elapsed}"
     )
 
     return image
 
 
+def upload_image(image: Image, url: str) -> requests.Response:
+    """
+    Upload a Pillow image to a URL.
+
+    :param image: Pillow image to upload.
+    :param url: Destination URL.
+    :returns: The upload response.
+    """
+    assert url.startswith("http"), "Destination URL for the image must be HTTP(S)"
+
+    # Retrieve a binarized version of the image
+    image_bytes = BytesIO()
+    image.save(image_bytes, format="jpeg")
+    image_bytes.seek(0)
+
+    # Upload the image
+    resp = _retried_request(url, method=requests.put, data=image_bytes)
+    logger.info(f"Uploaded image to {url} in {resp.elapsed}")
+
+    return resp
+
+
 def polygon_bounding_box(polygon: list[list[int | float]]) -> BoundingBox:
     """
     Compute the rectangle bounding box of a polygon.
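
The new `upload_image` helper and the reworked `_retried_request` (shown in the next hunk) share the same retry, timeout and certificate-verification logic, with the HTTP method now passed as a keyword. A minimal usage sketch, with placeholder URLs:

    from arkindex_worker.image import download_image, upload_image

    # Download an image with retries and certificate checks
    image = download_image("https://iiif.example.com/image/full/full/0/default.jpg")

    # Re-encode it as JPEG and PUT it to a destination URL
    response = upload_image(image, "https://storage.example.com/images/output.jpg")
    print(response.status_code)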
@@ -167,8 +175,10 @@ def _retry_log(retry_state, *args, **kwargs):
     before_sleep=_retry_log,
     reraise=True,
 )
-def _retried_request(url):
-    resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
+def _retried_request(url, *args, method=requests.get, **kwargs):
+    resp = method(
+        url, *args, timeout=DOWNLOAD_TIMEOUT, verify=should_verify_cert(url), **kwargs
+    )
     resp.raise_for_status()
     return resp
 
arkindex_worker/models.py CHANGED
@@ -20,6 +20,8 @@ class MagicDict(dict):
         Automagically convert lists and dicts to MagicDicts and lists of MagicDicts
         Allows for nested access: foo.bar.baz
         """
+        if isinstance(item, Dataset):
+            return item
         if isinstance(item, list):
             return list(map(self._magify, item))
         if isinstance(item, dict):
@@ -75,10 +77,10 @@ class Element(MagicDict):
 
     def image_url(self, size: str = "full") -> str | None:
         """
-        Build an URL to access the image.
+        Build a URL to access the image.
         When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers.
         :param size: Subresolution of the image, following the syntax of the IIIF resize parameter.
-        :returns: An URL to the image, or None if the element does not have an image.
+        :returns: A URL to the image, or None if the element does not have an image.
         """
         if not self.get("zone"):
             return
@@ -272,6 +274,16 @@ class Dataset(ArkindexModel):
         return f"{self.id}.tar.zst"
 
 
+class Set(MagicDict):
+    """
+    Describes an Arkindex dataset set.
+    """
+
+    def __str__(self):
+        # Not using ArkindexModel.__str__ as we do not retrieve the Set ID
+        return f"{self.__class__.__name__} ({self.name}) from {self.dataset}"
+
+
 class Artifact(ArkindexModel):
     """
     Describes an Arkindex artifact.
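
The new `Set` model is a plain `MagicDict` pairing a set name with its parent `Dataset`; its `__str__` is what appears in worker logs, and `MagicDict._magify` now leaves nested `Dataset` instances untouched. A small illustration with made-up values:

    from arkindex_worker.models import Dataset, Set

    dataset = Dataset({"id": "12341234-1234-1234-1234-123412341234", "name": "My dataset"})
    train_set = Set(name="train", dataset=dataset)

    # Prints something like: Set (train) from Dataset (...)
    print(train_set)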
arkindex_worker/utils.py CHANGED
@@ -10,6 +10,19 @@ import zstandard as zstd
 
 logger = logging.getLogger(__name__)
 
+MANUAL_SOURCE = "manual"
+
+
+def parse_source_id(value: str) -> bool | str | None:
+    """
+    Parse a UUID argument (Worker Version, Worker Run, ...) to use it directly in the API.
+    Arkindex API filters generally expect `False` to filter manual sources.
+    """
+    if value == MANUAL_SOURCE:
+        return False
+    return value or None
+
+
 CHUNK_SIZE = 1024
 """Chunk Size used for ZSTD compression"""
 
@@ -31,9 +44,10 @@ def decompress_zst_archive(compressed_archive: Path) -> tuple[int, Path]:
 
     logger.debug(f"Uncompressing file to {archive_path}")
     try:
-        with compressed_archive.open("rb") as compressed, archive_path.open(
-            "wb"
-        ) as decompressed:
+        with (
+            compressed_archive.open("rb") as compressed,
+            archive_path.open("wb") as decompressed,
+        ):
             dctx.copy_stream(compressed, decompressed)
         logger.debug(f"Successfully uncompressed archive {compressed_archive}")
     except zstandard.ZstdError as e:
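
`parse_source_id` maps the special `manual` source to the `False` value expected by Arkindex API filters, empty values to `None`, and passes any other identifier through unchanged:

    from arkindex_worker.utils import parse_source_id

    assert parse_source_id("manual") is False
    assert parse_source_id("") is None
    assert parse_source_id("12341234-1234-1234-1234-123412341234") == (
        "12341234-1234-1234-1234-123412341234"
    )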
arkindex_worker/worker/__init__.py CHANGED
@@ -1,31 +1,31 @@
 """
 Base classes to implement Arkindex workers.
 """
+
 import contextlib
 import json
 import os
 import sys
 import uuid
+from argparse import ArgumentTypeError
 from collections.abc import Iterable, Iterator
 from enum import Enum
-from itertools import groupby
-from operator import itemgetter
 from pathlib import Path
 
 from apistar.exceptions import ErrorResponse
 
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedElement
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.models import Dataset, Element, Set
 from arkindex_worker.worker.base import BaseWorker
 from arkindex_worker.worker.classification import ClassificationMixin
 from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
 from arkindex_worker.worker.element import ElementMixin
-from arkindex_worker.worker.entity import EntityMixin  # noqa: F401
+from arkindex_worker.worker.entity import EntityMixin
 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType  # noqa: F401
 from arkindex_worker.worker.task import TaskMixin
 from arkindex_worker.worker.transcription import TranscriptionMixin
-from arkindex_worker.worker.version import WorkerVersionMixin  # noqa: F401
+from arkindex_worker.worker.version import WorkerVersionMixin
 
 
 class ActivityState(Enum):
@@ -159,6 +159,16 @@ class ElementsWorker(
         super().configure()
         super().configure_cache()
 
+        # Retrieve the model configuration
+        if self.model_configuration:
+            self.config.update(self.model_configuration)
+            logger.info("Model version configuration retrieved")
+
+        # Retrieve the user configuration
+        if self.user_configuration:
+            self.config.update(self.user_configuration)
+            logger.info("User configuration retrieved")
+
     def run(self):
         """
         Implements an Arkindex worker that goes through each element returned by
@@ -229,12 +239,13 @@ class ElementsWorker(
                 with contextlib.suppress(Exception):
                     self.update_activity(element.id, ActivityState.Error)
 
+        message = f'Ran on {count} element{"s"[:count>1]}: {count - failed} completed, {failed} failed'
         if failed:
-            logger.error(
-                f"Ran on {count} elements: {count - failed} completed, {failed} failed"
-            )
+            logger.error(message)
             if failed >= count:  # Everything failed!
                 sys.exit(1)
+        else:
+            logger.info(message)
 
     def process_element(self, element: Element | CachedElement):
         """
@@ -299,6 +310,21 @@ class ElementsWorker(
         return True
 
 
+def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
+    values = value.split(":")
+    if len(values) != 2:
+        raise ArgumentTypeError(
+            f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
+        )
+
+    dataset_id, set_name = values
+    try:
+        dataset_id = uuid.UUID(dataset_id)
+        return (dataset_id, set_name)
+    except (TypeError, ValueError) as e:
+        raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
+
+
 class MissingDatasetArchive(Exception):
     """
     Exception raised when the compressed archive associated to
@@ -308,7 +334,7 @@ class MissingDatasetArchive(Exception):
 
 class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
     """
-    Base class for ML workers that operate on Arkindex datasets.
+    Base class for ML workers that operate on Arkindex dataset sets.
 
     This class inherits from numerous mixin classes found in other modules of
     ``arkindex.worker``, which provide helpers to read and write to the Arkindex API.
@@ -318,24 +344,28 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         self,
         description: str = "Arkindex Dataset Worker",
         support_cache: bool = False,
-        generator: bool = False,
     ):
         """
         :param description: The worker's description.
         :param support_cache: Whether the worker supports cache.
-        :param generator: Whether the worker generates the dataset archive artifact.
         """
         super().__init__(description, support_cache)
 
+        # Path to the dataset compressed archive (containing images and a SQLite database)
+        # Set as an instance variable as dataset workers might use it to easily extract its content
+        self.downloaded_dataset_artifact: Path | None = None
+
         self.parser.add_argument(
-            "--dataset",
-            type=uuid.UUID,
+            "--set",
+            type=check_dataset_set,
             nargs="+",
-            help="One or more Arkindex dataset ID",
+            help="""
+            One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
+            (e.g.: "12341234-1234-1234-1234-123412341234:train")
+            """,
+            default=[],
         )
 
-        self.generator = generator
-
     def configure(self):
         """
         Setup the worker using CLI arguments and environment variables.
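
`check_dataset_set` is the argparse `type` behind the new `--set` option, so every value must be a `<dataset_id>:<set_name>` pair with a valid UUID. A quick sketch of the expected behaviour (the UUID is a placeholder):

    import uuid
    from argparse import ArgumentTypeError

    from arkindex_worker.worker import check_dataset_set

    dataset_id, set_name = check_dataset_set("12341234-1234-1234-1234-123412341234:train")
    assert isinstance(dataset_id, uuid.UUID) and set_name == "train"

    try:
        check_dataset_set("not-a-uuid:train")
    except ArgumentTypeError as error:
        print(error)  # 'not-a-uuid' should be a valid UUID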
@@ -349,163 +379,130 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         super().configure()
         super().configure_cache()
 
-    def download_dataset_artifact(self, dataset: Dataset) -> Path:
+        # Retrieve the model configuration
+        if self.model_configuration:
+            self.config.update(self.model_configuration)
+            logger.info("Model version configuration retrieved")
+
+        # Retrieve the user configuration
+        if self.user_configuration:
+            self.config.update(self.user_configuration)
+            logger.info("User configuration retrieved")
+
+    def cleanup_downloaded_artifact(self) -> None:
+        """
+        Cleanup the downloaded dataset artifact if any
+        """
+        if not self.downloaded_dataset_artifact:
+            return
+
+        self.downloaded_dataset_artifact.unlink(missing_ok=True)
+
+    def download_dataset_artifact(self, dataset: Dataset) -> None:
         """
         Find and download the compressed archive artifact describing a dataset using
         the [list_artifacts][arkindex_worker.worker.task.TaskMixin.list_artifacts] and
         [download_artifact][arkindex_worker.worker.task.TaskMixin.download_artifact] methods.
 
         :param dataset: The dataset to retrieve the compressed archive artifact for.
-        :returns: A path to the downloaded artifact.
         :raises MissingDatasetArchive: When the dataset artifact is not found.
         """
+        extra_dir = self.find_extras_directory()
+        archive = extra_dir / dataset.filepath
+        if archive.exists():
+            return
 
-        task_id = uuid.UUID(dataset.task_id)
+        # Cleanup the dataset artifact that was downloaded previously
+        self.cleanup_downloaded_artifact()
 
+        logger.info(f"Downloading artifact for {dataset}")
+        task_id = uuid.UUID(dataset.task_id)
         for artifact in self.list_artifacts(task_id):
             if artifact.path != dataset.filepath:
                 continue
 
-            extra_dir = self.find_extras_directory()
-            archive = extra_dir / dataset.filepath
             archive.write_bytes(self.download_artifact(task_id, artifact).read())
-            return archive
+            self.downloaded_dataset_artifact = archive
+            return
 
         raise MissingDatasetArchive(
             "The dataset compressed archive artifact was not found."
         )
 
-    def list_dataset_elements_per_split(
-        self, dataset: Dataset
-    ) -> Iterator[tuple[str, list[Element]]]:
-        """
-        List the elements in the dataset, grouped by split, using the
-        [list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
-
-        :param dataset: The dataset to retrieve elements from.
-        :returns: An iterator of tuples containing the split name and the list of its elements.
-        """
-
-        def format_split(
-            split: tuple[str, Iterator[tuple[str, Element]]],
-        ) -> tuple[str, list[Element]]:
-            return (split[0], list(map(itemgetter(1), list(split[1]))))
-
-        return map(
-            format_split,
-            groupby(
-                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                key=itemgetter(0),
-            ),
-        )
-
-    def process_dataset(self, dataset: Dataset):
+    def process_set(self, set: Set):
         """
-        Override this method to implement your worker and process a single Arkindex dataset at once.
+        Override this method to implement your worker and process a single Arkindex dataset set at once.
 
-        :param dataset: The dataset to process.
+        :param set: The set to process.
         """
 
-    def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
+    def list_sets(self) -> Iterator[Set]:
         """
-        List the datasets to be processed, either from the CLI arguments or using the
-        [list_process_datasets][arkindex_worker.worker.dataset.DatasetMixin.list_process_datasets] method.
+        List the sets to be processed, either from the CLI arguments or using the
+        [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.
 
-        :returns: An iterator of strings if the worker is in read-only mode,
-        else an iterator of ``Dataset`` objects.
+        :returns: An iterator of ``Set`` objects.
         """
-        if self.is_read_only:
-            return map(str, self.args.dataset)
+        if not self.is_read_only:
+            yield from self.list_process_sets()
+
+        datasets: dict[uuid.UUID, Dataset] = {}
+        for dataset_id, set_name in self.args.set:
+            # Retrieving dataset information is not already cached
+            if dataset_id not in datasets:
+                datasets[dataset_id] = Dataset(
+                    **self.request("RetrieveDataset", id=dataset_id)
+                )
 
-        return self.list_process_datasets()
+            yield Set(name=set_name, dataset=datasets[dataset_id])
 
     def run(self):
         """
-        Implements an Arkindex worker that goes through each dataset returned by
-        [list_datasets][arkindex_worker.worker.DatasetWorker.list_datasets].
+        Implements an Arkindex worker that goes through each dataset set returned by
+        [list_sets][arkindex_worker.worker.DatasetWorker.list_sets].
 
-        It calls [process_dataset][arkindex_worker.worker.DatasetWorker.process_dataset],
-        catching exceptions, and handles updating the [DatasetState][arkindex_worker.worker.dataset.DatasetState]
-        when the worker is a generator.
+        It calls [process_set][arkindex_worker.worker.DatasetWorker.process_set],
+        catching exceptions.
         """
         self.configure()
 
-        datasets: list[Dataset] | list[str] = list(self.list_datasets())
-        if not datasets:
-            logger.warning("No datasets to process, stopping.")
+        dataset_sets: list[Set] = list(self.list_sets())
+        if not dataset_sets:
+            logger.warning("No sets to process, stopping.")
             sys.exit(1)
 
-        # Process every dataset
-        count = len(datasets)
+        # Process every set
+        count = len(dataset_sets)
         failed = 0
-        for i, item in enumerate(datasets, start=1):
-            dataset = None
-            dataset_artifact = None
-
+        for i, dataset_set in enumerate(dataset_sets, start=1):
             try:
-                if not self.is_read_only:
-                    # Just use the result of list_datasets as the dataset
-                    dataset = item
-                else:
-                    # Load dataset using the Arkindex API
-                    dataset = Dataset(**self.request("RetrieveDataset", id=item))
-
-                if self.generator:
-                    assert (
-                        dataset.state == DatasetState.Open.value
-                    ), "When generating a new dataset, its state should be Open."
-                else:
-                    assert (
-                        dataset.state == DatasetState.Complete.value
-                    ), "When processing an existing dataset, its state should be Complete."
-
-                logger.info(f"Processing {dataset} ({i}/{count})")
-
-                if self.generator:
-                    # Update the dataset state to Building
-                    logger.info(f"Building {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Building)
-                else:
-                    logger.info(f"Downloading data for {dataset} ({i}/{count})")
-                    dataset_artifact = self.download_dataset_artifact(dataset)
+                assert (
+                    dataset_set.dataset.state == DatasetState.Complete.value
+                ), "When processing a set, its dataset state should be Complete."
 
-                # Process the dataset
-                self.process_dataset(dataset)
+                logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
+                self.download_dataset_artifact(dataset_set.dataset)
 
-                if self.generator:
-                    # Update the dataset state to Complete
-                    logger.info(f"Completed {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Complete)
+                logger.info(f"Processing {dataset_set} ({i}/{count})")
+                self.process_set(dataset_set)
             except Exception as e:
-                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
+                # Handle errors occurring while retrieving or processing this dataset set
                 failed += 1
 
-                # Handle the case where we failed retrieving the dataset
-                dataset_id = dataset.id if dataset else item
-
                 if isinstance(e, ErrorResponse):
-                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
+                    message = f"An API error occurred while processing {dataset_set}: {e.title} - {e.content}"
                 else:
-                    message = (
-                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
-                    )
+                    message = f"Failed running worker on {dataset_set}: {repr(e)}"
 
-                logger.warning(
-                    message,
-                    exc_info=e if self.args.verbose else None,
-                )
-                if dataset and self.generator:
-                    # Try to update the state to Error regardless of the response
-                    with contextlib.suppress(Exception):
-                        self.update_dataset_state(dataset, DatasetState.Error)
-            finally:
-                # Cleanup the dataset artifact if it was downloaded, no matter what
-                if dataset_artifact:
-                    dataset_artifact.unlink(missing_ok=True)
+                logger.warning(message, exc_info=e if self.args.verbose else None)
+
+        # Cleanup the latest downloaded dataset artifact
+        self.cleanup_downloaded_artifact()
 
+        message = f'Ran on {count} set{"s"[:count>1]}: {count - failed} completed, {failed} failed'
         if failed:
-            logger.error(
-                f"Ran on {count} datasets: {count - failed} completed, {failed} failed"
-            )
+            logger.error(message)
             if failed >= count:  # Everything failed!
                 sys.exit(1)
+        else:
+            logger.info(message)
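
Taken together, these changes mean a dataset worker now overrides `process_set` instead of `process_dataset` and is pointed at dataset sets rather than whole datasets. A hedged sketch of what a minimal 0.3.7-style worker could look like (the `MySetWorker` class and its body are illustrative only):

    from arkindex_worker import logger
    from arkindex_worker.models import Set
    from arkindex_worker.worker import DatasetWorker


    class MySetWorker(DatasetWorker):
        def process_set(self, set: Set):
            # run() has already downloaded the dataset archive; its path is
            # available as self.downloaded_dataset_artifact
            logger.info(f"Working on {set} using {self.downloaded_dataset_artifact}")


    if __name__ == "__main__":
        # Typical invocation (placeholder UUID):
        #   python worker.py --set 12341234-1234-1234-1234-123412341234:train
        MySetWorker(description="Demo set worker").run()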