arkindex-base-worker 0.5.2a2__tar.gz → 0.5.2b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/PKG-INFO +2 -2
  2. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/PKG-INFO +2 -2
  3. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/requires.txt +1 -1
  4. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/utils.py +7 -7
  5. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/base.py +10 -10
  6. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/training.py +58 -124
  7. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/pyproject.toml +2 -2
  8. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_training.py +17 -136
  9. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_modern_config.py +39 -0
  10. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/LICENSE +0 -0
  11. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/README.md +0 -0
  12. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/SOURCES.txt +0 -0
  13. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/dependency_links.txt +0 -0
  14. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/top_level.txt +0 -0
  15. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/__init__.py +0 -0
  16. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/cache.py +0 -0
  17. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/image.py +0 -0
  18. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/models.py +0 -0
  19. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/__init__.py +0 -0
  20. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/classification.py +0 -0
  21. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/corpus.py +0 -0
  22. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/dataset.py +0 -0
  23. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/element.py +0 -0
  24. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/entity.py +0 -0
  25. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/image.py +0 -0
  26. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/metadata.py +0 -0
  27. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/process.py +0 -0
  28. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/task.py +0 -0
  29. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/transcription.py +0 -0
  30. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/examples/standalone/python/worker.py +0 -0
  31. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/examples/tooled/python/worker.py +0 -0
  32. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/hooks/pre_gen_project.py +0 -0
  33. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/setup.cfg +0 -0
  34. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/__init__.py +0 -0
  35. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/conftest.py +0 -0
  36. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_base_worker.py +0 -0
  37. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_cache.py +0 -0
  38. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_dataset_worker.py +0 -0
  39. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_element.py +0 -0
  40. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/__init__.py +0 -0
  41. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_classification.py +0 -0
  42. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_cli.py +0 -0
  43. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_corpus.py +0 -0
  44. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_dataset.py +0 -0
  45. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element.py +0 -0
  46. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_create_multiple.py +0 -0
  47. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_create_single.py +0 -0
  48. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_list_children.py +0 -0
  49. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_list_parents.py +0 -0
  50. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_entity.py +0 -0
  51. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_image.py +0 -0
  52. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_metadata.py +0 -0
  53. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_process.py +0 -0
  54. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_task.py +0 -0
  55. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_transcription_create.py +0 -0
  56. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_transcription_create_with_elements.py +0 -0
  57. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_transcription_list.py +0 -0
  58. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_worker.py +0 -0
  59. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_image.py +0 -0
  60. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_merge.py +0 -0
  61. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/tests/test_utils.py +0 -0
  62. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/worker-demo/tests/__init__.py +0 -0
  63. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/worker-demo/tests/conftest.py +0 -0
  64. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/worker-demo/tests/test_worker.py +0 -0
  65. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/worker-demo/worker_demo/__init__.py +0 -0
  66. {arkindex_base_worker-0.5.2a2 → arkindex_base_worker-0.5.2b1}/worker-demo/worker_demo/worker.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arkindex-base-worker
3
- Version: 0.5.2a2
3
+ Version: 0.5.2b1
4
4
  Summary: Base Worker to easily build Arkindex ML workflows
5
5
  Author-email: Teklia <contact@teklia.com>
6
6
  Maintainer-email: Teklia <contact@teklia.com>
@@ -25,7 +25,7 @@ Requires-Dist: Pillow==11.3.0
25
25
  Requires-Dist: python-gnupg==0.5.6
26
26
  Requires-Dist: python-magic==0.4.27
27
27
  Requires-Dist: shapely==2.0.6
28
- Requires-Dist: teklia-toolbox==0.1.12
28
+ Requires-Dist: teklia-toolbox==0.1.13
29
29
  Requires-Dist: zstandard==0.25.0
30
30
  Provides-Extra: tests
31
31
  Requires-Dist: pytest-mock==3.15.1; extra == "tests"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arkindex-base-worker
3
- Version: 0.5.2a2
3
+ Version: 0.5.2b1
4
4
  Summary: Base Worker to easily build Arkindex ML workflows
5
5
  Author-email: Teklia <contact@teklia.com>
6
6
  Maintainer-email: Teklia <contact@teklia.com>
@@ -25,7 +25,7 @@ Requires-Dist: Pillow==11.3.0
25
25
  Requires-Dist: python-gnupg==0.5.6
26
26
  Requires-Dist: python-magic==0.4.27
27
27
  Requires-Dist: shapely==2.0.6
28
- Requires-Dist: teklia-toolbox==0.1.12
28
+ Requires-Dist: teklia-toolbox==0.1.13
29
29
  Requires-Dist: zstandard==0.25.0
30
30
  Provides-Extra: tests
31
31
  Requires-Dist: pytest-mock==3.15.1; extra == "tests"
@@ -4,7 +4,7 @@ Pillow==11.3.0
4
4
  python-gnupg==0.5.6
5
5
  python-magic==0.4.27
6
6
  shapely==2.0.6
7
- teklia-toolbox==0.1.12
7
+ teklia-toolbox==0.1.13
8
8
  zstandard==0.25.0
9
9
 
10
10
  [tests]
@@ -163,12 +163,12 @@ def zstd_compress(
163
163
 
164
164
  def create_tar_archive(
165
165
  path: Path, destination: Path | None = None
166
- ) -> tuple[int | None, Path, str]:
166
+ ) -> tuple[int | None, Path]:
167
167
  """Create a tar archive using the content at specified location.
168
168
 
169
169
  :param path: Path to the file to archive
170
170
  :param destination: Optional path for the created TAR archive. A tempfile will be created if this is omitted.
171
- :return: The file descriptor (if one was created) and path to the TAR archive, hash of its content.
171
+ :return: The file descriptor (if one was created) and path to the TAR archive.
172
172
  """
173
173
  # Parse destination and create a tmpfile if none was specified
174
174
  file_d, destination = (
@@ -204,26 +204,26 @@ def create_tar_archive(
204
204
  with file_path.open("rb") as file_data:
205
205
  for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
206
206
  content_hasher.update(chunk)
207
- return file_d, destination, content_hasher.hexdigest()
207
+ return file_d, destination
208
208
 
209
209
 
210
210
  def create_tar_zst_archive(
211
211
  source: Path, destination: Path | None = None
212
- ) -> tuple[int | None, Path, str, str]:
212
+ ) -> tuple[int | None, Path, str]:
213
213
  """Helper to create a TAR+ZST archive from a source folder.
214
214
 
215
215
  :param source: Path to the folder whose content should be archived.
216
216
  :param destination: Path to the created archive, defaults to None. If unspecified, a temporary file will be created.
217
- :return: The file descriptor of the created tempfile (if one was created), path to the archive, its hash and the hash of the tar archive's content.
217
+ :return: The file descriptor of the created tempfile (if one was created), path to the archive and its hash.
218
218
  """
219
219
  # Create tar archive
220
- tar_fd, tar_archive, tar_hash = create_tar_archive(source)
220
+ tar_fd, tar_archive = create_tar_archive(source)
221
221
 
222
222
  zst_fd, zst_archive, zst_hash = zstd_compress(tar_archive, destination)
223
223
 
224
224
  close_delete_file(tar_fd, tar_archive)
225
225
 
226
- return zst_fd, zst_archive, zst_hash, tar_hash
226
+ return zst_fd, zst_archive, zst_hash
227
227
 
228
228
 
229
229
  def create_zip_archive(source: Path, destination: Path | None = None) -> Path:
@@ -15,7 +15,7 @@ import gnupg
15
15
  import yaml
16
16
 
17
17
  from arkindex import options_from_env
18
- from arkindex.exceptions import ClientError, ErrorResponse
18
+ from arkindex.exceptions import ErrorResponse
19
19
  from arkindex_worker import logger
20
20
  from arkindex_worker.cache import (
21
21
  check_version,
@@ -261,6 +261,10 @@ class BaseWorker:
261
261
 
262
262
  logger.info(f"Loaded {worker_run['summary']} from API")
263
263
 
264
+ # The `RetrieveSecret` endpoint is only available in Arkindex EE.
265
+ # In CE, the values of `secret` fields should be used directly without calling `RetrieveSecret`.
266
+ can_retrieve_secret = "RetrieveSecret" in self.api_client.document.links
267
+
264
268
  def _process_config_item(item: dict) -> tuple[str, Any]:
265
269
  if not item["secret"]:
266
270
  return (item["key"], item["value"])
@@ -270,16 +274,12 @@ class BaseWorker:
270
274
  logger.info(f"Optional secret `{item['key']}` is not set")
271
275
  return (item["key"], None)
272
276
 
273
- # Load secret, only available in Arkindex EE
274
- try:
275
- secret = self.load_secret(Path(item["value"]))
276
- except ClientError as e:
277
- logger.error(
278
- f"Failed to retrieve the secret {item['value']}, probably an Arkindex Community Edition: {e}"
279
- )
280
- return (item["key"], None)
277
+ value = item["value"]
278
+ # Load secret when `RetrieveSecret` is available
279
+ if can_retrieve_secret:
280
+ value = self.load_secret(Path(item["value"]))
281
281
 
282
- return (item["key"], secret)
282
+ return (item["key"], value)
283
283
 
284
284
  # Load model version configuration when available
285
285
  # Workers will use model version ID and details to download the model
@@ -3,16 +3,15 @@ BaseWorker methods for training.
3
3
  """
4
4
 
5
5
  import functools
6
+ from collections.abc import Generator
6
7
  from contextlib import contextmanager
7
8
  from pathlib import Path
8
9
  from typing import NewType
9
10
  from uuid import UUID
10
11
 
11
- import requests
12
-
13
- from arkindex.exceptions import ErrorResponse
14
12
  from arkindex_worker import logger
15
13
  from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
14
+ from teklia_toolbox.uploads import MultipartUpload
16
15
 
17
16
  DirPath = NewType("DirPath", Path)
18
17
  """Path to a directory"""
@@ -25,23 +24,21 @@ FileSize = NewType("FileSize", int)
25
24
 
26
25
 
27
26
  @contextmanager
28
- def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
27
+ def create_archive(path: DirPath) -> Generator[tuple[Path, FileSize, Hash]]:
29
28
  """
30
29
  Create a tar archive from the files at the given location then compress it to a zst archive.
31
30
 
32
- Yield its location, its hash, its size and its content's hash.
31
+ Yield its location, its size and its hash.
33
32
 
34
33
  :param path: Create a compressed tar archive from the files
35
- :returns: The location of the created archive, its hash, its size and its content's hash
34
+ :returns: The location of the created archive, its size and its hash
36
35
  """
37
36
  assert path.is_dir(), "create_archive needs a directory"
38
37
 
39
- zst_descriptor, zst_archive, archive_hash, content_hash = create_tar_zst_archive(
40
- path
41
- )
38
+ zst_descriptor, zst_archive, archive_hash = create_tar_zst_archive(path)
42
39
 
43
40
  # Get content hash, archive size and hash
44
- yield zst_archive, content_hash, zst_archive.stat().st_size, archive_hash
41
+ yield zst_archive, zst_archive.stat().st_size, archive_hash
45
42
 
46
43
  # Remove the zst archive
47
44
  close_delete_file(zst_descriptor, zst_archive)
@@ -112,62 +109,48 @@ class TrainingMixin:
112
109
  """
113
110
 
114
111
  configuration = configuration or {}
115
- if not self.model_version:
116
- self.create_model_version(
117
- model_id=model_id,
118
- tag=tag,
119
- description=description,
120
- configuration=configuration,
121
- parent=parent,
122
- )
123
-
124
- elif tag or description or configuration or parent:
125
- assert self.model_version.get("model_id") == model_id, (
126
- "Given `model_id` does not match the current model version"
127
- )
128
- # If any attribute field has been defined, PATCH the current model version
129
- self.update_model_version(
130
- tag=tag,
131
- description=description,
132
- configuration=configuration,
133
- parent=parent,
134
- )
135
112
 
136
113
  # Create the zst archive, get its hash and size
137
- # Validate the model version
138
114
  with create_archive(path=model_path) as (
139
115
  path_to_archive,
140
- hash,
141
116
  size,
142
- archive_hash,
117
+ hash,
143
118
  ):
144
- # Create a new model version with hash and size
145
- self.upload_to_s3(archive_path=path_to_archive)
146
-
147
- current_version_id = self.model_version["id"]
148
- # Mark the model as valid
149
- self.validate_model_version(
150
- size=size,
151
- hash=hash,
152
- archive_hash=archive_hash,
153
- )
154
- if self.model_version["id"] != current_version_id and (
155
- tag or description or configuration or parent
156
- ):
157
- logger.warning(
158
- "Updating the existing available model version with the given attributes."
119
+ # Update an existing model version with hash, size and any other defined attribute
120
+ if self.model_version:
121
+ assert self.model_version.get("model_id") == model_id, (
122
+ "Given `model_id` does not match the current model version"
159
123
  )
160
124
  self.update_model_version(
125
+ size=size,
126
+ archive_hash=hash,
127
+ tag=tag,
128
+ description=description,
129
+ configuration=configuration,
130
+ parent=parent,
131
+ )
132
+
133
+ # Create a new model version with hash and size
134
+ else:
135
+ self.create_model_version(
136
+ model_id=model_id,
137
+ size=size,
138
+ archive_hash=hash,
161
139
  tag=tag,
162
140
  description=description,
163
141
  configuration=configuration,
164
142
  parent=parent,
165
143
  )
166
144
 
145
+ # Upload the archive in multiple parts (supports huge files)
146
+ self.upload_to_s3(path_to_archive)
147
+
167
148
  @skip_if_read_only
168
149
  def create_model_version(
169
150
  self,
170
151
  model_id: str,
152
+ size: FileSize,
153
+ archive_hash: Hash,
171
154
  tag: str | None = None,
172
155
  description: str | None = None,
173
156
  configuration: dict | None = None,
@@ -177,6 +160,8 @@ class TrainingMixin:
177
160
  Create a new version of the specified model with its base attributes.
178
161
  Once successfully created, the model version is accessible via `self.model_version`.
179
162
 
163
+ :param size: Size of uploaded archive
164
+ :param hash: MD5 hash of the uploaded archive
180
165
  :param tag: Tag of the model version
181
166
  :param description: Description of the model version
182
167
  :param configuration: Configuration of the model version
@@ -189,6 +174,8 @@ class TrainingMixin:
189
174
  "CreateModelVersion",
190
175
  id=model_id,
191
176
  body=build_clean_payload(
177
+ size=size,
178
+ archive_hash=archive_hash,
192
179
  tag=tag,
193
180
  description=description,
194
181
  configuration=configuration,
@@ -197,12 +184,14 @@ class TrainingMixin:
197
184
  )
198
185
 
199
186
  logger.info(
200
- f"Model version ({self.model_version['id']}) was successfully created"
187
+ f"Model version ({self.model_version['id']}) was successfully created."
201
188
  )
202
189
 
203
190
  @skip_if_read_only
204
191
  def update_model_version(
205
192
  self,
193
+ size: FileSize,
194
+ archive_hash: Hash,
206
195
  tag: str | None = None,
207
196
  description: str | None = None,
208
197
  configuration: dict | None = None,
@@ -211,6 +200,8 @@ class TrainingMixin:
211
200
  """
212
201
  Update the current model version with the given attributes.
213
202
 
203
+ :param size: Size of uploaded archive
204
+ :param hash: MD5 hash of the uploaded archive
214
205
  :param tag: Tag of the model version
215
206
  :param description: Description of the model version
216
207
  :param configuration: Configuration of the model version
@@ -221,6 +212,8 @@ class TrainingMixin:
221
212
  "UpdateModelVersion",
222
213
  id=self.model_version["id"],
223
214
  body=build_clean_payload(
215
+ size=size,
216
+ archive_hash=archive_hash,
224
217
  tag=tag,
225
218
  description=description,
226
219
  configuration=configuration,
@@ -228,93 +221,34 @@ class TrainingMixin:
228
221
  ),
229
222
  )
230
223
  logger.info(
231
- f"Model version ({self.model_version['id']}) was successfully updated"
224
+ f"Model version ({self.model_version['id']}) was successfully updated."
232
225
  )
233
226
 
234
227
  @skip_if_read_only
235
228
  def upload_to_s3(self, archive_path: Path) -> None:
236
229
  """
237
- Upload the archive of the model's files to an Amazon s3 compatible storage
230
+ Upload the archive of the model's files to an Amazon s3 compatible storage in multiple parts
238
231
  """
239
-
240
232
  assert self.model_version, (
241
233
  "You must create the model version before uploading an archive."
242
234
  )
243
235
  assert self.model_version["state"] != "Available", (
244
- "The model is already marked as available."
236
+ "The model version is already marked as available."
245
237
  )
246
238
 
247
- s3_put_url = self.model_version.get("s3_put_url")
248
- assert s3_put_url, (
249
- "S3 PUT URL is not set, please ensure you have the right to validate a model version."
250
- )
251
-
252
- logger.info("Uploading to s3...")
253
- # Upload the archive on s3
254
- with archive_path.open("rb") as archive:
255
- r = requests.put(
256
- url=s3_put_url,
257
- data=archive,
258
- headers={"Content-Type": "application/zstd"},
259
- )
260
- r.raise_for_status()
261
-
262
- @skip_if_read_only
263
- def validate_model_version(
264
- self,
265
- hash: str,
266
- size: int,
267
- archive_hash: str,
268
- ):
269
- """
270
- Sets the model version as `Available`, once its archive has been uploaded to S3.
271
-
272
- :param hash: MD5 hash of the files contained in the archive
273
- :param size: The size of the uploaded archive
274
- :param archive_hash: MD5 hash of the uploaded archive
275
- """
276
- assert self.model_version, (
277
- "You must create the model version and upload its archive before validating it."
239
+ multipart = MultipartUpload(
240
+ client=self.api_client,
241
+ file_path=archive_path,
242
+ object_type="model_version",
243
+ object_id=str(self.model_version["id"]),
278
244
  )
279
245
  try:
280
- self.model_version = self.api_client.request(
281
- "PartialUpdateModelVersion",
282
- id=self.model_version["id"],
283
- body={
284
- "state": "available",
285
- "size": size,
286
- "hash": hash,
287
- "archive_hash": archive_hash,
288
- },
246
+ multipart.upload()
247
+ multipart.complete()
248
+ except Exception:
249
+ multipart.abort()
250
+ raise
251
+ else:
252
+ logger.info(
253
+ f"Model version ({self.model_version['id']}) archive was successfully uploaded and is now available."
289
254
  )
290
- except ErrorResponse as e:
291
- model_version = e.content
292
- if not model_version or "id" not in model_version:
293
- raise e
294
-
295
- logger.warning(
296
- f"An available model version exists with hash {hash}, using it instead of the pending version."
297
- )
298
- pending_version_id = self.model_version["id"]
299
- logger.warning("Removing the pending model version.")
300
- try:
301
- self.api_client.request("DestroyModelVersion", id=pending_version_id)
302
- except ErrorResponse as e:
303
- msg = getattr(e, "content", str(e))
304
- logger.error(
305
- f"An error occurred removing the pending version {pending_version_id}: {msg}."
306
- )
307
-
308
- logger.info("Retrieving the existing model version.")
309
- existing_version_id = model_version["id"].pop()
310
- try:
311
- self.model_version = self.api_client.request(
312
- "RetrieveModelVersion", id=existing_version_id
313
- )
314
- except ErrorResponse as e:
315
- logger.error(
316
- f"An error occurred retrieving the existing version {existing_version_id}: {e.status_code} - {e.content}."
317
- )
318
- raise
319
-
320
- logger.info(f"Model version {self.model_version['id']} is now available.")
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "arkindex-base-worker"
7
- version = "0.5.2a2"
7
+ version = "0.5.2b1"
8
8
  description = "Base Worker to easily build Arkindex ML workflows"
9
9
  license-files = ["LICENSE"]
10
10
  dependencies = [
@@ -14,7 +14,7 @@ dependencies = [
14
14
  "python-gnupg==0.5.6",
15
15
  "python-magic==0.4.27",
16
16
  "shapely==2.0.6",
17
- "teklia-toolbox==0.1.12",
17
+ "teklia-toolbox==0.1.13",
18
18
  "zstandard==0.25.0",
19
19
  ]
20
20
  authors = [
@@ -27,16 +27,16 @@ def default_model_version():
27
27
  return {
28
28
  "id": "model_version_id",
29
29
  "model_id": "model_id",
30
- "state": "created",
31
30
  "parent": "42" * 16,
32
- "tag": "A simple tag",
33
31
  "description": "A description",
32
+ "tag": "A simple tag",
33
+ "state": "created",
34
+ "size": 42,
35
+ "archive_hash": "123456789",
34
36
  "configuration": {"test": "value"},
35
- "s3_url": None,
37
+ "s3_etag": None,
36
38
  "s3_put_url": "http://upload.archive",
37
- "hash": None,
38
- "archive_hash": None,
39
- "size": None,
39
+ "s3_url": None,
40
40
  "created": "2000-01-01T00:00:00Z",
41
41
  }
42
42
 
@@ -46,14 +46,11 @@ def test_create_archive(model_file_dir):
46
46
 
47
47
  with create_archive(path=model_file_dir) as (
48
48
  zst_archive_path,
49
- hash,
50
49
  size,
51
- archive_hash,
50
+ hash,
52
51
  ):
53
52
  assert zst_archive_path.exists(), "The archive was not created"
54
- assert hash == "c5aedde18a768757351068b840c8c8f9", (
55
- "Hash was not properly computed"
56
- )
53
+ assert len(hash) == 32
57
54
  assert 300 < size < 700
58
55
 
59
56
  assert not zst_archive_path.exists(), "Auto removal failed"
@@ -64,37 +61,16 @@ def test_create_archive_with_subfolder(model_file_dir_with_subfolder):
64
61
 
65
62
  with create_archive(path=model_file_dir_with_subfolder) as (
66
63
  zst_archive_path,
67
- hash,
68
64
  size,
69
- archive_hash,
65
+ hash,
70
66
  ):
71
67
  assert zst_archive_path.exists(), "The archive was not created"
72
- assert hash == "3e453881404689e6e125144d2db3e605", (
73
- "Hash was not properly computed"
74
- )
68
+ assert len(hash) == 32
75
69
  assert 300 < size < 1500
76
70
 
77
71
  assert not zst_archive_path.exists(), "Auto removal failed"
78
72
 
79
73
 
80
- def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_dir):
81
- s3_endpoint_url = "http://s3.localhost.com"
82
- responses.add_passthru(s3_endpoint_url)
83
- responses.add(responses.PUT, s3_endpoint_url, status=400)
84
-
85
- mock_training_worker.model_version = {
86
- "state": "Created",
87
- "s3_put_url": s3_endpoint_url,
88
- }
89
-
90
- file_path = model_file_dir / "model_file.pth"
91
- with pytest.raises(
92
- Exception,
93
- match="400 Client Error: Bad Request for url: http://s3.localhost.com/",
94
- ):
95
- mock_training_worker.upload_to_s3(file_path)
96
-
97
-
98
74
  @pytest.mark.parametrize(
99
75
  "method",
100
76
  [
@@ -102,7 +78,6 @@ def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_
102
78
  "create_model_version",
103
79
  "update_model_version",
104
80
  "upload_to_s3",
105
- "validate_model_version",
106
81
  ],
107
82
  )
108
83
  def test_training_mixin_read_only(mock_training_worker, method, caplog):
@@ -127,12 +102,16 @@ def test_create_model_version_already_created(mock_training_worker):
127
102
  with pytest.raises(
128
103
  AssertionError, match="A model version has already been created."
129
104
  ):
130
- mock_training_worker.create_model_version(model_id="model_id")
105
+ mock_training_worker.create_model_version(
106
+ model_id="model_id", size=42, archive_hash="123456789"
107
+ )
131
108
 
132
109
 
133
110
  @pytest.mark.parametrize("set_tag", [True, False])
134
111
  def test_create_model_version(mock_training_worker, default_model_version, set_tag):
135
112
  args = {
113
+ "size": 42,
114
+ "archive_hash": "123456789",
136
115
  "parent": "42" * 16,
137
116
  "tag": "A simple tag",
138
117
  "description": "A description",
@@ -154,12 +133,12 @@ def test_create_model_version(mock_training_worker, default_model_version, set_t
154
133
 
155
134
  def test_update_model_version_not_created(mock_training_worker):
156
135
  with pytest.raises(AssertionError, match="No model version has been created yet."):
157
- mock_training_worker.update_model_version()
136
+ mock_training_worker.update_model_version(size=42, archive_hash="123456789")
158
137
 
159
138
 
160
139
  def test_update_model_version(mock_training_worker, default_model_version):
161
140
  mock_training_worker.model_version = default_model_version
162
- args = {"tag": "A new tag"}
141
+ args = {"size": 42, "archive_hash": "123456789", "tag": "A new tag"}
163
142
  new_model_version = {**default_model_version, "tag": "A new tag"}
164
143
  mock_training_worker.api_client.add_response(
165
144
  "UpdateModelVersion",
@@ -169,101 +148,3 @@ def test_update_model_version(mock_training_worker, default_model_version):
169
148
  )
170
149
  mock_training_worker.update_model_version(**args)
171
150
  assert mock_training_worker.model_version == new_model_version
172
-
173
-
174
- def test_validate_model_version_not_created(mock_training_worker):
175
- with pytest.raises(
176
- AssertionError,
177
- match="You must create the model version and upload its archive before validating it.",
178
- ):
179
- mock_training_worker.validate_model_version(hash="a", size=1, archive_hash="b")
180
-
181
-
182
- @pytest.mark.parametrize("deletion_failed", [True, False])
183
- def test_validate_model_version_hash_conflict(
184
- mock_training_worker,
185
- default_model_version,
186
- caplog,
187
- deletion_failed,
188
- ):
189
- mock_training_worker.model_version = {"id": "another_id"}
190
- args = {
191
- "hash": "hash",
192
- "archive_hash": "archive_hash",
193
- "size": 30,
194
- }
195
- mock_training_worker.api_client.add_error_response(
196
- "PartialUpdateModelVersion",
197
- id="another_id",
198
- status_code=409,
199
- body={"state": "available", **args},
200
- content={"id": ["model_version_id"]},
201
- )
202
- if deletion_failed:
203
- mock_training_worker.api_client.add_error_response(
204
- "DestroyModelVersion",
205
- id="another_id",
206
- status_code=403,
207
- content="Not admin",
208
- )
209
- else:
210
- mock_training_worker.api_client.add_response(
211
- "DestroyModelVersion",
212
- id="another_id",
213
- response="No content",
214
- )
215
- mock_training_worker.api_client.add_response(
216
- "RetrieveModelVersion",
217
- id="model_version_id",
218
- response=default_model_version,
219
- )
220
-
221
- mock_training_worker.validate_model_version(**args)
222
- assert mock_training_worker.model_version == default_model_version
223
- error_msg = []
224
- if deletion_failed:
225
- error_msg = [
226
- (
227
- logging.ERROR,
228
- "An error occurred removing the pending version another_id: Not admin.",
229
- )
230
- ]
231
- assert [
232
- (level, message)
233
- for module, level, message in caplog.record_tuples
234
- if module == "arkindex_worker"
235
- ] == [
236
- (
237
- logging.WARNING,
238
- "An available model version exists with hash hash, using it instead of the pending version.",
239
- ),
240
- (logging.WARNING, "Removing the pending model version."),
241
- *error_msg,
242
- (logging.INFO, "Retrieving the existing model version."),
243
- (logging.INFO, "Model version model_version_id is now available."),
244
- ]
245
-
246
-
247
- def test_validate_model_version(mock_training_worker, default_model_version, caplog):
248
- mock_training_worker.model_version = {"id": "model_version_id"}
249
- args = {
250
- "hash": "hash",
251
- "archive_hash": "archive_hash",
252
- "size": 30,
253
- }
254
- mock_training_worker.api_client.add_response(
255
- "PartialUpdateModelVersion",
256
- id="model_version_id",
257
- body={"state": "available", **args},
258
- response=default_model_version,
259
- )
260
-
261
- mock_training_worker.validate_model_version(**args)
262
- assert mock_training_worker.model_version == default_model_version
263
- assert [
264
- (level, message)
265
- for module, level, message in caplog.record_tuples
266
- if module == "arkindex_worker"
267
- ] == [
268
- (logging.INFO, "Model version model_version_id is now available."),
269
- ]
@@ -79,3 +79,42 @@ def test_with_secrets(mock_base_worker_modern_conf, responses):
79
79
  assert mock_base_worker_modern_conf.secrets == {
80
80
  "a_secret": "My super duper secret value"
81
81
  }
82
+
83
+
84
+ def test_with_secrets_ce(mock_base_worker_modern_conf, responses, monkeypatch):
85
+ # Provide the full configuration directly from the worker run
86
+ responses.add(
87
+ responses.GET,
88
+ "http://testserver/api/v1/workers/runs/56785678-5678-5678-5678-567856785678/configuration/",
89
+ status=200,
90
+ json={
91
+ "configuration": [
92
+ {"key": "some_key", "value": "test", "secret": False},
93
+ {
94
+ "key": "a_secret",
95
+ "value": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
96
+ "secret": True,
97
+ },
98
+ ]
99
+ },
100
+ )
101
+
102
+ # Remove the RetrieveSecret endpoint to simulate Arkindex CE
103
+ monkeypatch.delitem(
104
+ mock_base_worker_modern_conf.api_client.document.links, "RetrieveSecret"
105
+ )
106
+
107
+ mock_base_worker_modern_conf.configure()
108
+
109
+ assert mock_base_worker_modern_conf.config == {
110
+ "a_secret": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
111
+ "some_key": "test",
112
+ }
113
+ assert (
114
+ mock_base_worker_modern_conf.user_configuration
115
+ == mock_base_worker_modern_conf.config
116
+ )
117
+ assert mock_base_worker_modern_conf.secrets == {
118
+ # The value is used directly instead of treated as a secret name
119
+ "a_secret": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
120
+ }