arkindex-base-worker 0.5.2a1__py3-none-any.whl → 0.5.2b1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in their public registry.
{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: arkindex-base-worker
- Version: 0.5.2a1
+ Version: 0.5.2b1
  Summary: Base Worker to easily build Arkindex ML workflows
  Author-email: Teklia <contact@teklia.com>
  Maintainer-email: Teklia <contact@teklia.com>
@@ -23,8 +23,9 @@ Requires-Dist: humanize==4.15.0
  Requires-Dist: peewee~=3.17
  Requires-Dist: Pillow==11.3.0
  Requires-Dist: python-gnupg==0.5.6
+ Requires-Dist: python-magic==0.4.27
  Requires-Dist: shapely==2.0.6
- Requires-Dist: teklia-toolbox==0.1.12
+ Requires-Dist: teklia-toolbox==0.1.13
  Requires-Dist: zstandard==0.25.0
  Provides-Extra: tests
  Requires-Dist: pytest-mock==3.15.1; extra == "tests"
{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/RECORD RENAMED
@@ -1,11 +1,11 @@
- arkindex_base_worker-0.5.2a1.dist-info/licenses/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
+ arkindex_base_worker-0.5.2b1.dist-info/licenses/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
  arkindex_worker/__init__.py,sha256=Sdt5KXn8EgURb2MurYVrUWaHbH3iFA1XLRo0Lc5AJ44,250
  arkindex_worker/cache.py,sha256=XpEXMSnbhYCvrJquwA9XXqZo-ajMLpaCxKG5wH3Gp6Y,10959
- arkindex_worker/image.py,sha256=sGE8to5iykXv25bpkftOEWzlh5NzBZSKy4lSRoHYHPU,20929
+ arkindex_worker/image.py,sha256=9KeZHWNIDkwNJZR0y-mbyD_pvKfrgdktMB32jZqSMYk,20927
  arkindex_worker/models.py,sha256=DgKvAB_2e1cPcuUavZkyTkV10jBK8y083oVklB9idSk,10855
- arkindex_worker/utils.py,sha256=Eqg5pGAuOmuwMT3EhKTQDMek7wHC1KzZL7XXqYVVfHY,10977
- arkindex_worker/worker/__init__.py,sha256=SzD0s1_m6gMV02EUF-NeciqZdVPA4dpXI84tSj-g494,17869
- arkindex_worker/worker/base.py,sha256=-R_aLMJHbR6X1uM-U0zExsF_KLy5Wl3WJ_YMGO9We0I,22153
+ arkindex_worker/utils.py,sha256=kqOTVLBh-0krD2ukTkroiMZ2820wNYxeR8Cf1AyoqNA,10859
+ arkindex_worker/worker/__init__.py,sha256=tM_ynAARmtuJw5YWb_jI0AD5KNXbWN1K-VDiixIp7O4,18009
+ arkindex_worker/worker/base.py,sha256=2nQdPGh2qQOUNmvV2Mc1KZeqE8d4Fhy9tCo6Q2nNdNQ,22214
  arkindex_worker/worker/classification.py,sha256=qvykymkgd4nGywHCxL8obo4egstoGsmWNS4Ztc1qNWQ,11024
  arkindex_worker/worker/corpus.py,sha256=MeIMod7jkWyX0frtD0a37rhumnMV3p9ZOC1xwAoXrAA,2291
  arkindex_worker/worker/dataset.py,sha256=tVaPx43vaH-KTtx4w5V06e26ha8XPfiJTRzBXlu928Y,5273
@@ -14,8 +14,8 @@ arkindex_worker/worker/entity.py,sha256=Aj6EOfzHEm7qQV-Egm0YKLZgCrLS_3ggOKTY81M2
  arkindex_worker/worker/image.py,sha256=L6Ikuf0Z0RxJk7JarY5PggJGrYSHLaPK0vn0dy0CIaQ,623
  arkindex_worker/worker/metadata.py,sha256=keZdOdUthSH2hAw9iet5pN7rzWihTUYjZHRGTEjaltw,6843
  arkindex_worker/worker/process.py,sha256=9TEHpMcBax1wc6PrWMMrdXe2uNfqyVj7n_dAYZRBGnY,1854
- arkindex_worker/worker/task.py,sha256=nYfMSFm_d-4t8y4PO4HjFBnLsZf7IsDjkS7-A2Pgnac,1525
- arkindex_worker/worker/training.py,sha256=tyQOHcwv--_wdYz6CgLEe1YM7kwwwKN30LvGTsnWd78,10923
+ arkindex_worker/worker/task.py,sha256=HASQU5LYVtgvCnRCLFC6iH7h7v6q_usZNZ-r_Wkv9A8,3306
+ arkindex_worker/worker/training.py,sha256=b1YGeUiOWob_DacS4fphGkErJGsx84YVgr5NnsukoEQ,8420
  arkindex_worker/worker/transcription.py,sha256=sw718R119tsLNY8inPMVeIilvFJo94fMbMtYgH0zTM8,21250
  examples/standalone/python/worker.py,sha256=Zr4s4pHvgexEjlkixLFYZp1UuwMLeoTxjyNG5_S2iYE,6672
  examples/tooled/python/worker.py,sha256=kIYlHLsO5UpwX4XtERRq4tf2qTsvqKK30C-w8t0yyhA,1821
@@ -24,11 +24,11 @@ tests/__init__.py,sha256=DG--S6IpGl399rzSAjDdHL76CkOIeZIjajCcyUSDhOQ,241
  tests/conftest.py,sha256=Tp7YFK17NATwF2yAcBwi0QFNyKSXtLS0VhZ-zZngsQI,24343
  tests/test_base_worker.py,sha256=lwS4X3atS2ktEKd1XdogmN3mbzq-tO206-k_0EDITlw,29302
  tests/test_cache.py,sha256=_wztzh94EwVrb8UvpFqgl2aa2_FLaCcJKaqunCYR5Dw,10435
- tests/test_dataset_worker.py,sha256=iDJM2C4PfQNH0r4_QqSWoPt8BcM0geUUdODtWY0Z9PA,22412
+ tests/test_dataset_worker.py,sha256=LmL3ERF1__PUPkTLiAFC0IYglZTv5WQYA42Vm-uhe2w,22023
  tests/test_element.py,sha256=hlj5VSF4plwC7uz9R4LGOOXZJQcHZiYCIDZT5V6EIB8,14334
  tests/test_image.py,sha256=yAM5mMfpQcIurT1KLHmu0AhSX2Qm3YvCu7afyZ3XUdU,28314
  tests/test_merge.py,sha256=REpZ13jkq_qm_4L5URQgFy5lxvPZtXxQEiWfYLMdmF0,7956
- tests/test_modern_config.py,sha256=Bm-a4LYQXgLZWQX7AmVyfJW0LNoLy1wj2d2GjzDkcBk,2683
+ tests/test_modern_config.py,sha256=ZbMHT5b5RG3ZPX4MoqI8zitRg2y5fV1C6ynfyRkq828,4008
  tests/test_utils.py,sha256=tgzNqyJMpddpeFWEjgsew_yDzmqnCA9HDaA5IpevAcM,5353
  tests/test_elements_worker/__init__.py,sha256=2t3NciCIOun_N-Wv63FWGsTm5W9N3mbwAWVuFORlMg8,308
  tests/test_elements_worker/test_classification.py,sha256=nya7veSPR_O9G41Enodp2-o6AifMBcaSTWJP2vXSSJ4,30133
@@ -44,8 +44,8 @@ tests/test_elements_worker/test_entity.py,sha256=SNAZEsVVLnqlliOmjkgv_cZhw0bAuJU
  tests/test_elements_worker/test_image.py,sha256=BljMNKgec_9a5bzNzFpYZIvSbuvwsWDfdqLHVJaTa7M,2079
  tests/test_elements_worker/test_metadata.py,sha256=qtTDtlp3VnBkfck7PAguK2dEgTLlr1i1EVnmNTeNf3A,20515
  tests/test_elements_worker/test_process.py,sha256=y4RoVhPfyHzR795fw7-_FXElBcKo3fy4Ew_HI-kxJic,3088
- tests/test_elements_worker/test_task.py,sha256=wTUWqN9UhfKmJn3IcFY75EW4I1ulRhisflmY1kmP47s,5574
- tests/test_elements_worker/test_training.py,sha256=qgK7BLucddRzc8ePbQtY75x17QvGDEq5XCwgyyvmAJE,8717
+ tests/test_elements_worker/test_task.py,sha256=oHwP1fbJftXFA2U4qA3Gb4vX-iJoV-sBvPHnfBBpRrc,8906
+ tests/test_elements_worker/test_training.py,sha256=VY3YKYAm8IijAD6gWY0g06I27gXvMxa7SnAsCWm7G-8,4896
  tests/test_elements_worker/test_transcription_create.py,sha256=yznO9B_BVsOR0Z_VY5ZL8gJp0ZPCz_4sPUs5dXtixAg,29281
  tests/test_elements_worker/test_transcription_create_with_elements.py,sha256=tmcyglgssEqMnt1Mdy_u6X1m2wgLWTo_HdWst3GrK2k,33056
  tests/test_elements_worker/test_transcription_list.py,sha256=ikz7HYPCoQWTdTRCd382SB-y-T2BbigPLlIcx5Eow-I,15324
@@ -55,7 +55,7 @@ worker-demo/tests/conftest.py,sha256=XzNMNeg6pmABUAH8jN6eZTlZSFGLYjS3-DTXjiRN6Yc
  worker-demo/tests/test_worker.py,sha256=3DLd4NRK4bfyatG5P_PK4k9P9tJHx9XQq5_ryFEEFVg,304
  worker-demo/worker_demo/__init__.py,sha256=2BPomV8ZMNf3YXJgloatKeHQCE6QOkwmsHGkO6MkQuM,125
  worker-demo/worker_demo/worker.py,sha256=Rt-DjWa5iBP08k58NDZMfeyPuFbtNcbX6nc5jFX7GNo,440
- arkindex_base_worker-0.5.2a1.dist-info/METADATA,sha256=AwYp_xJZzu6zAtvnvZjeK_W29tzqvRuwYnxwMYcKSIc,1849
- arkindex_base_worker-0.5.2a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- arkindex_base_worker-0.5.2a1.dist-info/top_level.txt,sha256=-vNjP2VfROx0j83mdi9aIqRZ88eoJjxeWz-R_gPgyXU,49
- arkindex_base_worker-0.5.2a1.dist-info/RECORD,,
+ arkindex_base_worker-0.5.2b1.dist-info/METADATA,sha256=3dxm9h6sl-bra7LqCF5_plfIrX9qW8zgzF4cnAiUcoQ,1885
+ arkindex_base_worker-0.5.2b1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ arkindex_base_worker-0.5.2b1.dist-info/top_level.txt,sha256=-vNjP2VfROx0j83mdi9aIqRZ88eoJjxeWz-R_gPgyXU,49
+ arkindex_base_worker-0.5.2b1.dist-info/RECORD,,
{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any
arkindex_worker/image.py CHANGED
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
  from arkindex_worker.models import Element

  # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
- DOWNLOAD_TIMEOUT = (30, 60)
+ REQUEST_TIMEOUT = (30, 60)

  BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])

@@ -346,7 +346,7 @@ def _retried_request(url, *args, method=requests.get, **kwargs):
  url,
  *args,
  headers={"User-Agent": IIIF_USER_AGENT},
- timeout=DOWNLOAD_TIMEOUT,
+ timeout=REQUEST_TIMEOUT,
  verify=should_verify_cert(url),
  **kwargs,
  )
arkindex_worker/utils.py CHANGED
@@ -163,12 +163,12 @@ def zstd_compress(

  def create_tar_archive(
  path: Path, destination: Path | None = None
- ) -> tuple[int | None, Path, str]:
+ ) -> tuple[int | None, Path]:
  """Create a tar archive using the content at specified location.

  :param path: Path to the file to archive
  :param destination: Optional path for the created TAR archive. A tempfile will be created if this is omitted.
- :return: The file descriptor (if one was created) and path to the TAR archive, hash of its content.
+ :return: The file descriptor (if one was created) and path to the TAR archive.
  """
  # Parse destination and create a tmpfile if none was specified
  file_d, destination = (
@@ -204,26 +204,26 @@
  with file_path.open("rb") as file_data:
  for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
  content_hasher.update(chunk)
- return file_d, destination, content_hasher.hexdigest()
+ return file_d, destination


  def create_tar_zst_archive(
  source: Path, destination: Path | None = None
- ) -> tuple[int | None, Path, str, str]:
+ ) -> tuple[int | None, Path, str]:
  """Helper to create a TAR+ZST archive from a source folder.

  :param source: Path to the folder whose content should be archived.
  :param destination: Path to the created archive, defaults to None. If unspecified, a temporary file will be created.
- :return: The file descriptor of the created tempfile (if one was created), path to the archive, its hash and the hash of the tar archive's content.
+ :return: The file descriptor of the created tempfile (if one was created), path to the archive and its hash.
  """
  # Create tar archive
- tar_fd, tar_archive, tar_hash = create_tar_archive(source)
+ tar_fd, tar_archive = create_tar_archive(source)

  zst_fd, zst_archive, zst_hash = zstd_compress(tar_archive, destination)

  close_delete_file(tar_fd, tar_archive)

- return zst_fd, zst_archive, zst_hash, tar_hash
+ return zst_fd, zst_archive, zst_hash


  def create_zip_archive(source: Path, destination: Path | None = None) -> Path:
arkindex_worker/worker/__init__.py CHANGED
@@ -424,12 +424,13 @@ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):
  failed = 0
  for i, dataset_set in enumerate(dataset_sets, start=1):
  try:
- assert dataset_set.dataset.state == DatasetState.Complete.value, (
- "When processing a set, its dataset state should be Complete."
- )
-
- logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
- self.download_dataset_artifact(dataset_set.dataset)
+ if dataset_set.dataset.state == DatasetState.Complete.value:
+ logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
+ self.download_dataset_artifact(dataset_set.dataset)
+ else:
+ logger.warning(
+ f"The dataset {dataset_set.dataset} has its state set to `{dataset_set.dataset.state}`, its archive will not be downloaded"
+ )

  logger.info(f"Processing {dataset_set} ({i}/{count})")
  self.process_set(dataset_set)
@@ -444,7 +445,7 @@ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):

  logger.warning(message, exc_info=e if self.args.verbose else None)

- # Cleanup the latest downloaded dataset artifact
+ # Cleanup the latest downloaded dataset artifact (if needed)
  self.cleanup_downloaded_artifact()

  message = f"Ran on {count} {pluralize('set', count)}: {count - failed} completed, {failed} failed"
arkindex_worker/worker/base.py CHANGED
@@ -15,7 +15,7 @@ import gnupg
  import yaml

  from arkindex import options_from_env
- from arkindex.exceptions import ClientError, ErrorResponse
+ from arkindex.exceptions import ErrorResponse
  from arkindex_worker import logger
  from arkindex_worker.cache import (
  check_version,
@@ -261,6 +261,10 @@ class BaseWorker:

  logger.info(f"Loaded {worker_run['summary']} from API")

+ # The `RetrieveSecret` endpoint is only available in Arkindex EE.
+ # In CE, the values of `secret` fields should be used directly without calling `RetrieveSecret`.
+ can_retrieve_secret = "RetrieveSecret" in self.api_client.document.links
+
  def _process_config_item(item: dict) -> tuple[str, Any]:
  if not item["secret"]:
  return (item["key"], item["value"])
@@ -270,16 +274,12 @@ class BaseWorker:
  logger.info(f"Optional secret `{item['key']}` is not set")
  return (item["key"], None)

- # Load secret, only available in Arkindex EE
- try:
- secret = self.load_secret(Path(item["value"]))
- except ClientError as e:
- logger.error(
- f"Failed to retrieve the secret {item['value']}, probably an Arkindex Community Edition: {e}"
- )
- return (item["key"], None)
+ value = item["value"]
+ # Load secret when `RetrieveSecret` is available
+ if can_retrieve_secret:
+ value = self.load_secret(Path(item["value"]))

- return (item["key"], secret)
+ return (item["key"], value)

  # Load model version configuration when available
  # Workers will use model version ID and details to download the model
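
In effect, BaseWorker now probes the API schema once (via the presence of the `RetrieveSecret` link) instead of attempting the call and catching a `ClientError` for every secret. A condensed sketch of the new per-item behaviour, assuming a worker object exposing `api_client` and `load_secret` as in the diff above (the standalone `resolve_secret` helper itself is hypothetical):

```python
from pathlib import Path


def resolve_secret(worker, item: dict) -> str:
    """Sketch: resolve the value of one `secret` configuration item."""
    # Arkindex EE exposes the RetrieveSecret endpoint in its API schema; CE does not.
    if "RetrieveSecret" in worker.api_client.document.links:
        # EE: the configured value is the secret's path, resolved through the API.
        return worker.load_secret(Path(item["value"]))
    # CE: the configured value is used directly.
    return item["value"]
```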
arkindex_worker/worker/task.py CHANGED
@@ -4,9 +4,16 @@ BaseWorker methods for tasks.

  import uuid
  from collections.abc import Iterator
+ from http.client import REQUEST_TIMEOUT
+ from pathlib import Path
+
+ import magic
+ import requests

  from arkindex.compat import DownloadedFile
+ from arkindex_worker import logger
  from arkindex_worker.models import Artifact
+ from teklia_toolbox.requests import should_verify_cert


  class TaskMixin:
@@ -45,3 +52,49 @@ class TaskMixin:
  return self.api_client.request(
  "DownloadArtifact", id=task_id, path=artifact.path
  )
+
+ def upload_artifact(self, path: Path) -> None:
+ """
+ Upload a single file as an Artifact of the current task.
+
+ :param path: Path of the single file to upload as an Artifact.
+ """
+ assert path and isinstance(path, Path) and path.exists(), (
+ "path shouldn't be null, should be a Path and should exist"
+ )
+
+ if self.is_read_only:
+ logger.warning("Cannot upload artifact as this worker is in read-only mode")
+ return
+
+ # Get path relative to task's data directory
+ relpath = str(path.relative_to(self.work_dir))
+
+ # Get file size
+ size = path.stat().st_size
+
+ # Detect content type
+ try:
+ content_type = magic.from_file(path, mime=True)
+ except Exception as e:
+ logger.warning(f"Failed to get a mime type for {path}: {e}")
+ content_type = "application/octet-stream"
+
+ # Create artifact on API to get an S3 url
+ artifact = self.api_client.request(
+ "CreateArtifact",
+ id=self.task_id,
+ body={"path": relpath, "content_type": content_type, "size": size},
+ )
+
+ # Upload the file content to S3
+ s3_put_url = artifact["s3_put_url"]
+ with path.open("rb") as content:
+ resp = requests.put(
+ s3_put_url,
+ data=content,
+ headers={"Content-Type": content_type},
+ timeout=REQUEST_TIMEOUT,
+ verify=should_verify_cert(s3_put_url),
+ )
+ resp.raise_for_status()
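
For illustration, a minimal sketch of a worker using the new `upload_artifact` helper above; the worker class and output file are hypothetical, while `DatasetWorker`, `process_set`, `work_dir` and `upload_artifact(path)` come from the diffs in this release:

```python
from pathlib import Path

from arkindex_worker.worker import DatasetWorker


class DemoDatasetWorker(DatasetWorker):
    """Hypothetical worker publishing one artifact per processed set."""

    def process_set(self, dataset_set):
        # The file must live under the task's data directory, since
        # upload_artifact() resolves the artifact path relative to self.work_dir.
        output = Path(self.work_dir) / f"{dataset_set.name}.txt"
        output.write_text(f"Processed {dataset_set}\n")

        # New in 0.5.2b1: creates the artifact through CreateArtifact,
        # then PUTs the file content to the returned s3_put_url.
        self.upload_artifact(output)
```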
arkindex_worker/worker/training.py CHANGED
@@ -3,16 +3,15 @@ BaseWorker methods for training.
  """

  import functools
+ from collections.abc import Generator
  from contextlib import contextmanager
  from pathlib import Path
  from typing import NewType
  from uuid import UUID

- import requests
-
- from arkindex.exceptions import ErrorResponse
  from arkindex_worker import logger
  from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
+ from teklia_toolbox.uploads import MultipartUpload

  DirPath = NewType("DirPath", Path)
  """Path to a directory"""
@@ -25,23 +24,21 @@ FileSize = NewType("FileSize", int)


  @contextmanager
- def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
+ def create_archive(path: DirPath) -> Generator[tuple[Path, FileSize, Hash]]:
  """
  Create a tar archive from the files at the given location then compress it to a zst archive.

- Yield its location, its hash, its size and its content's hash.
+ Yield its location, its size and its hash.

  :param path: Create a compressed tar archive from the files
- :returns: The location of the created archive, its hash, its size and its content's hash
+ :returns: The location of the created archive, its size and its hash
  """
  assert path.is_dir(), "create_archive needs a directory"

- zst_descriptor, zst_archive, archive_hash, content_hash = create_tar_zst_archive(
- path
- )
+ zst_descriptor, zst_archive, archive_hash = create_tar_zst_archive(path)

  # Get content hash, archive size and hash
- yield zst_archive, content_hash, zst_archive.stat().st_size, archive_hash
+ yield zst_archive, zst_archive.stat().st_size, archive_hash

  # Remove the zst archive
  close_delete_file(zst_descriptor, zst_archive)
@@ -112,62 +109,48 @@ class TrainingMixin:
  """

  configuration = configuration or {}
- if not self.model_version:
- self.create_model_version(
- model_id=model_id,
- tag=tag,
- description=description,
- configuration=configuration,
- parent=parent,
- )
-
- elif tag or description or configuration or parent:
- assert self.model_version.get("model_id") == model_id, (
- "Given `model_id` does not match the current model version"
- )
- # If any attribute field has been defined, PATCH the current model version
- self.update_model_version(
- tag=tag,
- description=description,
- configuration=configuration,
- parent=parent,
- )

  # Create the zst archive, get its hash and size
- # Validate the model version
  with create_archive(path=model_path) as (
  path_to_archive,
- hash,
  size,
- archive_hash,
+ hash,
  ):
- # Create a new model version with hash and size
- self.upload_to_s3(archive_path=path_to_archive)
-
- current_version_id = self.model_version["id"]
- # Mark the model as valid
- self.validate_model_version(
- size=size,
- hash=hash,
- archive_hash=archive_hash,
- )
- if self.model_version["id"] != current_version_id and (
- tag or description or configuration or parent
- ):
- logger.warning(
- "Updating the existing available model version with the given attributes."
+ # Update an existing model version with hash, size and any other defined attribute
+ if self.model_version:
+ assert self.model_version.get("model_id") == model_id, (
+ "Given `model_id` does not match the current model version"
  )
  self.update_model_version(
+ size=size,
+ archive_hash=hash,
+ tag=tag,
+ description=description,
+ configuration=configuration,
+ parent=parent,
+ )
+
+ # Create a new model version with hash and size
+ else:
+ self.create_model_version(
+ model_id=model_id,
+ size=size,
+ archive_hash=hash,
  tag=tag,
  description=description,
  configuration=configuration,
  parent=parent,
  )

+ # Upload the archive in multiple parts (supports huge files)
+ self.upload_to_s3(path_to_archive)
+
  @skip_if_read_only
  def create_model_version(
  self,
  model_id: str,
+ size: FileSize,
+ archive_hash: Hash,
  tag: str | None = None,
  description: str | None = None,
  configuration: dict | None = None,
@@ -177,6 +160,8 @@ class TrainingMixin:
  Create a new version of the specified model with its base attributes.
  Once successfully created, the model version is accessible via `self.model_version`.

+ :param size: Size of uploaded archive
+ :param hash: MD5 hash of the uploaded archive
  :param tag: Tag of the model version
  :param description: Description of the model version
  :param configuration: Configuration of the model version
@@ -189,6 +174,8 @@ class TrainingMixin:
  "CreateModelVersion",
  id=model_id,
  body=build_clean_payload(
+ size=size,
+ archive_hash=archive_hash,
  tag=tag,
  description=description,
  configuration=configuration,
@@ -197,12 +184,14 @@ class TrainingMixin:
  )

  logger.info(
- f"Model version ({self.model_version['id']}) was successfully created"
+ f"Model version ({self.model_version['id']}) was successfully created."
  )

  @skip_if_read_only
  def update_model_version(
  self,
+ size: FileSize,
+ archive_hash: Hash,
  tag: str | None = None,
  description: str | None = None,
  configuration: dict | None = None,
@@ -211,6 +200,8 @@ class TrainingMixin:
  """
  Update the current model version with the given attributes.

+ :param size: Size of uploaded archive
+ :param hash: MD5 hash of the uploaded archive
  :param tag: Tag of the model version
  :param description: Description of the model version
  :param configuration: Configuration of the model version
@@ -221,6 +212,8 @@ class TrainingMixin:
  "UpdateModelVersion",
  id=self.model_version["id"],
  body=build_clean_payload(
+ size=size,
+ archive_hash=archive_hash,
  tag=tag,
  description=description,
  configuration=configuration,
@@ -228,93 +221,34 @@ class TrainingMixin:
  ),
  )
  logger.info(
- f"Model version ({self.model_version['id']}) was successfully updated"
+ f"Model version ({self.model_version['id']}) was successfully updated."
  )

  @skip_if_read_only
  def upload_to_s3(self, archive_path: Path) -> None:
  """
- Upload the archive of the model's files to an Amazon s3 compatible storage
+ Upload the archive of the model's files to an Amazon s3 compatible storage in multiple parts
  """
-
  assert self.model_version, (
  "You must create the model version before uploading an archive."
  )
  assert self.model_version["state"] != "Available", (
- "The model is already marked as available."
+ "The model version is already marked as available."
  )

- s3_put_url = self.model_version.get("s3_put_url")
- assert s3_put_url, (
- "S3 PUT URL is not set, please ensure you have the right to validate a model version."
- )
-
- logger.info("Uploading to s3...")
- # Upload the archive on s3
- with archive_path.open("rb") as archive:
- r = requests.put(
- url=s3_put_url,
- data=archive,
- headers={"Content-Type": "application/zstd"},
- )
- r.raise_for_status()
-
- @skip_if_read_only
- def validate_model_version(
- self,
- hash: str,
- size: int,
- archive_hash: str,
- ):
- """
- Sets the model version as `Available`, once its archive has been uploaded to S3.
-
- :param hash: MD5 hash of the files contained in the archive
- :param size: The size of the uploaded archive
- :param archive_hash: MD5 hash of the uploaded archive
- """
- assert self.model_version, (
- "You must create the model version and upload its archive before validating it."
+ multipart = MultipartUpload(
+ client=self.api_client,
+ file_path=archive_path,
+ object_type="model_version",
+ object_id=str(self.model_version["id"]),
  )
  try:
- self.model_version = self.api_client.request(
- "PartialUpdateModelVersion",
- id=self.model_version["id"],
- body={
- "state": "available",
- "size": size,
- "hash": hash,
- "archive_hash": archive_hash,
- },
+ multipart.upload()
+ multipart.complete()
+ except Exception:
+ multipart.abort()
+ raise
+ else:
+ logger.info(
+ f"Model version ({self.model_version['id']}) archive was successfully uploaded and is now available."
  )
- except ErrorResponse as e:
- model_version = e.content
- if not model_version or "id" not in model_version:
- raise e
-
- logger.warning(
- f"An available model version exists with hash {hash}, using it instead of the pending version."
- )
- pending_version_id = self.model_version["id"]
- logger.warning("Removing the pending model version.")
- try:
- self.api_client.request("DestroyModelVersion", id=pending_version_id)
- except ErrorResponse as e:
- msg = getattr(e, "content", str(e))
- logger.error(
- f"An error occurred removing the pending version {pending_version_id}: {msg}."
- )
-
- logger.info("Retrieving the existing model version.")
- existing_version_id = model_version["id"].pop()
- try:
- self.model_version = self.api_client.request(
- "RetrieveModelVersion", id=existing_version_id
- )
- except ErrorResponse as e:
- logger.error(
- f"An error occurred retrieving the existing version {existing_version_id}: {e.status_code} - {e.content}."
- )
- raise
-
- logger.info(f"Model version {self.model_version['id']} is now available.")
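
Taken together, the training changes drop `validate_model_version` and the single-request S3 PUT: `size` and `archive_hash` are now sent when the model version is created or updated, and the archive goes through `teklia_toolbox`'s `MultipartUpload`. A rough sketch of the resulting call sequence, assuming `worker` is an already configured instance of a `TrainingMixin`-based worker and the model ID and directory are placeholders:

```python
from pathlib import Path

from arkindex_worker.worker.training import create_archive

model_dir = Path("/tmp/my_model")  # placeholder: directory with the trained model files

# `worker` is assumed to be a configured worker instance using TrainingMixin
with create_archive(path=model_dir) as (archive_path, size, archive_hash):
    # Since 0.5.2b1, size and archive_hash are part of the creation payload
    # rather than of a separate validation call.
    worker.create_model_version(
        model_id="11111111-1111-1111-1111-111111111111",  # placeholder model UUID
        size=size,
        archive_hash=archive_hash,
        tag="demo",
    )
    # The archive is then uploaded in multiple parts through MultipartUpload,
    # replacing the previous single requests.put() to s3_put_url.
    worker.upload_to_s3(archive_path)
```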
tests/test_dataset_worker.py CHANGED
@@ -435,34 +435,6 @@ def test_run_no_sets(mocker, caplog, mock_dataset_worker):
  ]


- def test_run_initial_dataset_state_error(
- mocker, responses, caplog, mock_dataset_worker, default_dataset
- ):
- default_dataset.state = DatasetState.Building.value
- mocker.patch(
- "arkindex_worker.worker.DatasetWorker.list_sets",
- return_value=[Set(name="train", dataset=default_dataset)],
- )
-
- with pytest.raises(SystemExit):
- mock_dataset_worker.run()
-
- assert len(responses.calls) == len(BASE_API_CALLS) * 2
- assert [
- (call.request.method, call.request.url) for call in responses.calls
- ] == BASE_API_CALLS * 2
-
- assert [(level, message) for _, level, message in caplog.record_tuples] == [
- (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
- (logging.INFO, "Modern configuration is not available"),
- (
- logging.WARNING,
- "Failed running worker on Set (train) from Dataset (dataset_id): AssertionError('When processing a set, its dataset state should be Complete.')",
- ),
- (logging.ERROR, "Ran on 1 set: 0 completed, 1 failed"),
- ]
-
-

  def test_run_download_dataset_artifact_api_error(
  mocker,
@@ -570,16 +542,18 @@ def test_run_no_downloaded_dataset_artifact_error(
  ]


+ @pytest.mark.parametrize("dataset_state", DatasetState)
  def test_run(
  mocker,
  tmp_path,
  responses,
  caplog,
+ dataset_state,
  mock_dataset_worker,
  default_dataset,
  default_artifact,
  ):
- default_dataset.state = DatasetState.Complete.value
+ default_dataset.state = dataset_state.value
  mocker.patch(
  "arkindex_worker.worker.DatasetWorker.list_sets",
  return_value=[Set(name="train", dataset=default_dataset)],
@@ -590,55 +564,68 @@
  )
  mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")

- archive_path = (
- FIXTURES_DIR
- / "extract_parent_archives"
- / "first_parent"
- / "arkindex_data.tar.zst"
- )
- responses.add(
- responses.GET,
- f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
- status=200,
- json=[default_artifact],
- )
- responses.add(
- responses.GET,
- f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
- status=200,
- body=archive_path.read_bytes(),
- content_type="application/zstd",
- )
+ if dataset_state == DatasetState.Complete:
+ archive_path = (
+ FIXTURES_DIR
+ / "extract_parent_archives"
+ / "first_parent"
+ / "arkindex_data.tar.zst"
+ )
+ responses.add(
+ responses.GET,
+ f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+ status=200,
+ json=[default_artifact],
+ )
+ responses.add(
+ responses.GET,
+ f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+ status=200,
+ body=archive_path.read_bytes(),
+ content_type="application/zstd",
+ )

  mock_dataset_worker.run()

  assert mock_process.call_count == 1

- assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 2
+ # We only download the dataset archive when it is Complete
+ extra_calls = []
+ if dataset_state == DatasetState.Complete:
+ extra_calls = [
+ (
+ "GET",
+ f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+ ),
+ (
+ "GET",
+ f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+ ),
+ ]
+
+ assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
  assert [
  (call.request.method, call.request.url) for call in responses.calls
- ] == BASE_API_CALLS * 2 + [
- (
- "GET",
- f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
- ),
- (
- "GET",
- f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
- ),
- ]
+ ] == BASE_API_CALLS * 2 + extra_calls

- assert [(level, message) for _, level, message in caplog.record_tuples] == [
+ logs = [
  (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
  (logging.INFO, "Modern configuration is not available"),
  (
- logging.INFO,
- "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+ logging.WARNING,
+ f"The dataset Dataset (dataset_id) has its state set to `{dataset_state.value}`, its archive will not be downloaded",
  ),
- (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
  (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
  (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
  ]
+ if dataset_state == DatasetState.Complete:
+ logs[2] = (
+ logging.INFO,
+ "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+ )
+ logs.insert(3, (logging.INFO, "Downloading artifact for Dataset (dataset_id)"))
+
+ assert [(level, message) for _, level, message in caplog.record_tuples] == logs


  def test_run_read_only(
tests/test_elements_worker/test_task.py CHANGED
@@ -1,6 +1,9 @@
+ import tempfile
  import uuid
+ from pathlib import Path

  import pytest
+ from requests import HTTPError

  from arkindex.exceptions import ErrorResponse
  from arkindex_worker.models import Artifact
@@ -196,3 +199,112 @@ def test_download_artifact(
  ] == BASE_API_CALLS + [
  ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.tar.zst"),
  ]
+
+
+ @pytest.mark.parametrize(
+ ("payload", "error"),
+ [
+ # Path
+ (
+ {"path": None},
+ "path shouldn't be null, should be a Path and should exist",
+ ),
+ (
+ {"path": "not path type"},
+ "path shouldn't be null, should be a Path and should exist",
+ ),
+ (
+ {"path": Path("i_do_no_exist.oops")},
+ "path shouldn't be null, should be a Path and should exist",
+ ),
+ ],
+ )
+ def test_upload_artifact_wrong_param_path(mock_dataset_worker, payload, error):
+ with pytest.raises(AssertionError, match=error):
+ mock_dataset_worker.upload_artifact(**payload)
+
+
+ @pytest.fixture
+ def tmp_file(mock_dataset_worker):
+ with tempfile.NamedTemporaryFile(
+ mode="w", suffix=".txt", dir=mock_dataset_worker.work_dir
+ ) as file:
+ file.write("Some content...")
+ file.seek(0)
+
+ yield Path(file.name)
+
+
+ def test_upload_artifact_api_error(responses, mock_dataset_worker, tmp_file):
+ responses.add(
+ responses.POST,
+ "http://testserver/api/v1/task/my_task/artifacts/",
+ status=418,
+ )
+
+ with pytest.raises(ErrorResponse):
+ mock_dataset_worker.upload_artifact(path=tmp_file)
+
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
+ assert [
+ (call.request.method, call.request.url) for call in responses.calls
+ ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/task/my_task/artifacts/")]
+
+
+ def test_upload_artifact_s3_upload_error(
+ responses,
+ mock_dataset_worker,
+ tmp_file,
+ ):
+ responses.add(
+ responses.POST,
+ "http://testserver/api/v1/task/my_task/artifacts/",
+ json={
+ "id": "11111111-1111-1111-1111-111111111111",
+ "path": tmp_file.name,
+ "size": 15,
+ "content_type": "text/plain",
+ "s3_put_url": "http://example.com/oops.txt",
+ },
+ )
+ responses.add(responses.PUT, "http://example.com/oops.txt", status=500)
+
+ with pytest.raises(HTTPError):
+ mock_dataset_worker.upload_artifact(path=tmp_file)
+
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
+ assert [
+ (call.request.method, call.request.url) for call in responses.calls
+ ] == BASE_API_CALLS + [
+ ("POST", "http://testserver/api/v1/task/my_task/artifacts/"),
+ ("PUT", "http://example.com/oops.txt"),
+ ]
+
+
+ def test_upload_artifact(
+ responses,
+ mock_dataset_worker,
+ tmp_file,
+ ):
+ responses.add(
+ responses.POST,
+ "http://testserver/api/v1/task/my_task/artifacts/",
+ json={
+ "id": "11111111-1111-1111-1111-111111111111",
+ "path": tmp_file.name,
+ "size": 15,
+ "content_type": "text/plain",
+ "s3_put_url": "http://example.com/test.txt",
+ },
+ )
+ responses.add(responses.PUT, "http://example.com/test.txt")
+
+ mock_dataset_worker.upload_artifact(path=tmp_file)
+
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
+ assert [
+ (call.request.method, call.request.url) for call in responses.calls
+ ] == BASE_API_CALLS + [
+ ("POST", "http://testserver/api/v1/task/my_task/artifacts/"),
+ ("PUT", "http://example.com/test.txt"),
+ ]
tests/test_elements_worker/test_training.py CHANGED
@@ -27,16 +27,16 @@ def default_model_version():
  return {
  "id": "model_version_id",
  "model_id": "model_id",
- "state": "created",
  "parent": "42" * 16,
- "tag": "A simple tag",
  "description": "A description",
+ "tag": "A simple tag",
+ "state": "created",
+ "size": 42,
+ "archive_hash": "123456789",
  "configuration": {"test": "value"},
- "s3_url": None,
+ "s3_etag": None,
  "s3_put_url": "http://upload.archive",
- "hash": None,
- "archive_hash": None,
- "size": None,
+ "s3_url": None,
  "created": "2000-01-01T00:00:00Z",
  }

@@ -46,14 +46,11 @@ def test_create_archive(model_file_dir):

  with create_archive(path=model_file_dir) as (
  zst_archive_path,
- hash,
  size,
- archive_hash,
+ hash,
  ):
  assert zst_archive_path.exists(), "The archive was not created"
- assert hash == "c5aedde18a768757351068b840c8c8f9", (
- "Hash was not properly computed"
- )
+ assert len(hash) == 32
  assert 300 < size < 700

  assert not zst_archive_path.exists(), "Auto removal failed"
@@ -64,37 +61,16 @@ def test_create_archive_with_subfolder(model_file_dir_with_subfolder):

  with create_archive(path=model_file_dir_with_subfolder) as (
  zst_archive_path,
- hash,
  size,
- archive_hash,
+ hash,
  ):
  assert zst_archive_path.exists(), "The archive was not created"
- assert hash == "3e453881404689e6e125144d2db3e605", (
- "Hash was not properly computed"
- )
+ assert len(hash) == 32
  assert 300 < size < 1500

  assert not zst_archive_path.exists(), "Auto removal failed"


- def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_dir):
- s3_endpoint_url = "http://s3.localhost.com"
- responses.add_passthru(s3_endpoint_url)
- responses.add(responses.PUT, s3_endpoint_url, status=400)
-
- mock_training_worker.model_version = {
- "state": "Created",
- "s3_put_url": s3_endpoint_url,
- }
-
- file_path = model_file_dir / "model_file.pth"
- with pytest.raises(
- Exception,
- match="400 Client Error: Bad Request for url: http://s3.localhost.com/",
- ):
- mock_training_worker.upload_to_s3(file_path)
-
-
  @pytest.mark.parametrize(
  "method",
  [
@@ -102,7 +78,6 @@ def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_
  "create_model_version",
  "update_model_version",
  "upload_to_s3",
- "validate_model_version",
  ],
  )
  def test_training_mixin_read_only(mock_training_worker, method, caplog):
@@ -127,12 +102,16 @@ def test_create_model_version_already_created(mock_training_worker):
  with pytest.raises(
  AssertionError, match="A model version has already been created."
  ):
- mock_training_worker.create_model_version(model_id="model_id")
+ mock_training_worker.create_model_version(
+ model_id="model_id", size=42, archive_hash="123456789"
+ )


  @pytest.mark.parametrize("set_tag", [True, False])
  def test_create_model_version(mock_training_worker, default_model_version, set_tag):
  args = {
+ "size": 42,
+ "archive_hash": "123456789",
  "parent": "42" * 16,
  "tag": "A simple tag",
  "description": "A description",
@@ -154,12 +133,12 @@ def test_create_model_version(mock_training_worker, default_model_version, set_t

  def test_update_model_version_not_created(mock_training_worker):
  with pytest.raises(AssertionError, match="No model version has been created yet."):
- mock_training_worker.update_model_version()
+ mock_training_worker.update_model_version(size=42, archive_hash="123456789")


  def test_update_model_version(mock_training_worker, default_model_version):
  mock_training_worker.model_version = default_model_version
- args = {"tag": "A new tag"}
+ args = {"size": 42, "archive_hash": "123456789", "tag": "A new tag"}
  new_model_version = {**default_model_version, "tag": "A new tag"}
  mock_training_worker.api_client.add_response(
  "UpdateModelVersion",
@@ -169,101 +148,3 @@ def test_update_model_version(mock_training_worker, default_model_version):
  )
  mock_training_worker.update_model_version(**args)
  assert mock_training_worker.model_version == new_model_version
-
-
- def test_validate_model_version_not_created(mock_training_worker):
- with pytest.raises(
- AssertionError,
- match="You must create the model version and upload its archive before validating it.",
- ):
- mock_training_worker.validate_model_version(hash="a", size=1, archive_hash="b")
-
-
- @pytest.mark.parametrize("deletion_failed", [True, False])
- def test_validate_model_version_hash_conflict(
- mock_training_worker,
- default_model_version,
- caplog,
- deletion_failed,
- ):
- mock_training_worker.model_version = {"id": "another_id"}
- args = {
- "hash": "hash",
- "archive_hash": "archive_hash",
- "size": 30,
- }
- mock_training_worker.api_client.add_error_response(
- "PartialUpdateModelVersion",
- id="another_id",
- status_code=409,
- body={"state": "available", **args},
- content={"id": ["model_version_id"]},
- )
- if deletion_failed:
- mock_training_worker.api_client.add_error_response(
- "DestroyModelVersion",
- id="another_id",
- status_code=403,
- content="Not admin",
- )
- else:
- mock_training_worker.api_client.add_response(
- "DestroyModelVersion",
- id="another_id",
- response="No content",
- )
- mock_training_worker.api_client.add_response(
- "RetrieveModelVersion",
- id="model_version_id",
- response=default_model_version,
- )
-
- mock_training_worker.validate_model_version(**args)
- assert mock_training_worker.model_version == default_model_version
- error_msg = []
- if deletion_failed:
- error_msg = [
- (
- logging.ERROR,
- "An error occurred removing the pending version another_id: Not admin.",
- )
- ]
- assert [
- (level, message)
- for module, level, message in caplog.record_tuples
- if module == "arkindex_worker"
- ] == [
- (
- logging.WARNING,
- "An available model version exists with hash hash, using it instead of the pending version.",
- ),
- (logging.WARNING, "Removing the pending model version."),
- *error_msg,
- (logging.INFO, "Retrieving the existing model version."),
- (logging.INFO, "Model version model_version_id is now available."),
- ]
-
-
- def test_validate_model_version(mock_training_worker, default_model_version, caplog):
- mock_training_worker.model_version = {"id": "model_version_id"}
- args = {
- "hash": "hash",
- "archive_hash": "archive_hash",
- "size": 30,
- }
- mock_training_worker.api_client.add_response(
- "PartialUpdateModelVersion",
- id="model_version_id",
- body={"state": "available", **args},
- response=default_model_version,
- )
-
- mock_training_worker.validate_model_version(**args)
- assert mock_training_worker.model_version == default_model_version
- assert [
- (level, message)
- for module, level, message in caplog.record_tuples
- if module == "arkindex_worker"
- ] == [
- (logging.INFO, "Model version model_version_id is now available."),
- ]
tests/test_modern_config.py CHANGED
@@ -79,3 +79,42 @@ def test_with_secrets(mock_base_worker_modern_conf, responses):
  assert mock_base_worker_modern_conf.secrets == {
  "a_secret": "My super duper secret value"
  }
+
+
+ def test_with_secrets_ce(mock_base_worker_modern_conf, responses, monkeypatch):
+ # Provide the full configuration directly from the worker run
+ responses.add(
+ responses.GET,
+ "http://testserver/api/v1/workers/runs/56785678-5678-5678-5678-567856785678/configuration/",
+ status=200,
+ json={
+ "configuration": [
+ {"key": "some_key", "value": "test", "secret": False},
+ {
+ "key": "a_secret",
+ "value": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
+ "secret": True,
+ },
+ ]
+ },
+ )
+
+ # Remove the RetrieveSecret endpoint to simulate Arkindex CE
+ monkeypatch.delitem(
+ mock_base_worker_modern_conf.api_client.document.links, "RetrieveSecret"
+ )
+
+ mock_base_worker_modern_conf.configure()
+
+ assert mock_base_worker_modern_conf.config == {
+ "a_secret": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
+ "some_key": "test",
+ }
+ assert (
+ mock_base_worker_modern_conf.user_configuration
+ == mock_base_worker_modern_conf.config
+ )
+ assert mock_base_worker_modern_conf.secrets == {
+ # The value is used directly instead of treated as a secret name
+ "a_secret": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
+ }