arkindex-base-worker 0.5.2a1__tar.gz → 0.5.2b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/PKG-INFO +3 -2
  2. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/PKG-INFO +3 -2
  3. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/requires.txt +2 -1
  4. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/image.py +2 -2
  5. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/utils.py +7 -7
  6. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/__init__.py +8 -7
  7. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/base.py +10 -10
  8. arkindex_base_worker-0.5.2b1/arkindex_worker/worker/task.py +100 -0
  9. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/training.py +58 -124
  10. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/pyproject.toml +3 -2
  11. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_dataset_worker.py +50 -63
  12. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_task.py +112 -0
  13. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_training.py +17 -136
  14. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_modern_config.py +39 -0
  15. arkindex_base_worker-0.5.2a1/arkindex_worker/worker/task.py +0 -47
  16. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/LICENSE +0 -0
  17. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/README.md +0 -0
  18. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/SOURCES.txt +0 -0
  19. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/dependency_links.txt +0 -0
  20. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_base_worker.egg-info/top_level.txt +0 -0
  21. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/__init__.py +0 -0
  22. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/cache.py +0 -0
  23. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/models.py +0 -0
  24. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/classification.py +0 -0
  25. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/corpus.py +0 -0
  26. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/dataset.py +0 -0
  27. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/element.py +0 -0
  28. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/entity.py +0 -0
  29. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/image.py +0 -0
  30. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/metadata.py +0 -0
  31. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/process.py +0 -0
  32. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/arkindex_worker/worker/transcription.py +0 -0
  33. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/examples/standalone/python/worker.py +0 -0
  34. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/examples/tooled/python/worker.py +0 -0
  35. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/hooks/pre_gen_project.py +0 -0
  36. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/setup.cfg +0 -0
  37. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/__init__.py +0 -0
  38. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/conftest.py +0 -0
  39. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_base_worker.py +0 -0
  40. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_cache.py +0 -0
  41. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_element.py +0 -0
  42. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/__init__.py +0 -0
  43. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_classification.py +0 -0
  44. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_cli.py +0 -0
  45. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_corpus.py +0 -0
  46. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_dataset.py +0 -0
  47. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element.py +0 -0
  48. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_create_multiple.py +0 -0
  49. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_create_single.py +0 -0
  50. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_list_children.py +0 -0
  51. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_element_list_parents.py +0 -0
  52. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_entity.py +0 -0
  53. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_image.py +0 -0
  54. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_metadata.py +0 -0
  55. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_process.py +0 -0
  56. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_transcription_create.py +0 -0
  57. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_transcription_create_with_elements.py +0 -0
  58. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_transcription_list.py +0 -0
  59. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_elements_worker/test_worker.py +0 -0
  60. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_image.py +0 -0
  61. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_merge.py +0 -0
  62. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/tests/test_utils.py +0 -0
  63. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/worker-demo/tests/__init__.py +0 -0
  64. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/worker-demo/tests/conftest.py +0 -0
  65. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/worker-demo/tests/test_worker.py +0 -0
  66. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/worker-demo/worker_demo/__init__.py +0 -0
  67. {arkindex_base_worker-0.5.2a1 → arkindex_base_worker-0.5.2b1}/worker-demo/worker_demo/worker.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arkindex-base-worker
3
- Version: 0.5.2a1
3
+ Version: 0.5.2b1
4
4
  Summary: Base Worker to easily build Arkindex ML workflows
5
5
  Author-email: Teklia <contact@teklia.com>
6
6
  Maintainer-email: Teklia <contact@teklia.com>
@@ -23,8 +23,9 @@ Requires-Dist: humanize==4.15.0
23
23
  Requires-Dist: peewee~=3.17
24
24
  Requires-Dist: Pillow==11.3.0
25
25
  Requires-Dist: python-gnupg==0.5.6
26
+ Requires-Dist: python-magic==0.4.27
26
27
  Requires-Dist: shapely==2.0.6
27
- Requires-Dist: teklia-toolbox==0.1.12
28
+ Requires-Dist: teklia-toolbox==0.1.13
28
29
  Requires-Dist: zstandard==0.25.0
29
30
  Provides-Extra: tests
30
31
  Requires-Dist: pytest-mock==3.15.1; extra == "tests"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arkindex-base-worker
3
- Version: 0.5.2a1
3
+ Version: 0.5.2b1
4
4
  Summary: Base Worker to easily build Arkindex ML workflows
5
5
  Author-email: Teklia <contact@teklia.com>
6
6
  Maintainer-email: Teklia <contact@teklia.com>
@@ -23,8 +23,9 @@ Requires-Dist: humanize==4.15.0
23
23
  Requires-Dist: peewee~=3.17
24
24
  Requires-Dist: Pillow==11.3.0
25
25
  Requires-Dist: python-gnupg==0.5.6
26
+ Requires-Dist: python-magic==0.4.27
26
27
  Requires-Dist: shapely==2.0.6
27
- Requires-Dist: teklia-toolbox==0.1.12
28
+ Requires-Dist: teklia-toolbox==0.1.13
28
29
  Requires-Dist: zstandard==0.25.0
29
30
  Provides-Extra: tests
30
31
  Requires-Dist: pytest-mock==3.15.1; extra == "tests"
@@ -2,8 +2,9 @@ humanize==4.15.0
2
2
  peewee~=3.17
3
3
  Pillow==11.3.0
4
4
  python-gnupg==0.5.6
5
+ python-magic==0.4.27
5
6
  shapely==2.0.6
6
- teklia-toolbox==0.1.12
7
+ teklia-toolbox==0.1.13
7
8
  zstandard==0.25.0
8
9
 
9
10
  [tests]
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
38
38
  from arkindex_worker.models import Element
39
39
 
40
40
  # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
41
- DOWNLOAD_TIMEOUT = (30, 60)
41
+ REQUEST_TIMEOUT = (30, 60)
42
42
 
43
43
  BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])
44
44
 
@@ -346,7 +346,7 @@ def _retried_request(url, *args, method=requests.get, **kwargs):
346
346
  url,
347
347
  *args,
348
348
  headers={"User-Agent": IIIF_USER_AGENT},
349
- timeout=DOWNLOAD_TIMEOUT,
349
+ timeout=REQUEST_TIMEOUT,
350
350
  verify=should_verify_cert(url),
351
351
  **kwargs,
352
352
  )
@@ -163,12 +163,12 @@ def zstd_compress(
163
163
 
164
164
  def create_tar_archive(
165
165
  path: Path, destination: Path | None = None
166
- ) -> tuple[int | None, Path, str]:
166
+ ) -> tuple[int | None, Path]:
167
167
  """Create a tar archive using the content at specified location.
168
168
 
169
169
  :param path: Path to the file to archive
170
170
  :param destination: Optional path for the created TAR archive. A tempfile will be created if this is omitted.
171
- :return: The file descriptor (if one was created) and path to the TAR archive, hash of its content.
171
+ :return: The file descriptor (if one was created) and path to the TAR archive.
172
172
  """
173
173
  # Parse destination and create a tmpfile if none was specified
174
174
  file_d, destination = (
@@ -204,26 +204,26 @@ def create_tar_archive(
204
204
  with file_path.open("rb") as file_data:
205
205
  for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
206
206
  content_hasher.update(chunk)
207
- return file_d, destination, content_hasher.hexdigest()
207
+ return file_d, destination
208
208
 
209
209
 
210
210
  def create_tar_zst_archive(
211
211
  source: Path, destination: Path | None = None
212
- ) -> tuple[int | None, Path, str, str]:
212
+ ) -> tuple[int | None, Path, str]:
213
213
  """Helper to create a TAR+ZST archive from a source folder.
214
214
 
215
215
  :param source: Path to the folder whose content should be archived.
216
216
  :param destination: Path to the created archive, defaults to None. If unspecified, a temporary file will be created.
217
- :return: The file descriptor of the created tempfile (if one was created), path to the archive, its hash and the hash of the tar archive's content.
217
+ :return: The file descriptor of the created tempfile (if one was created), path to the archive and its hash.
218
218
  """
219
219
  # Create tar archive
220
- tar_fd, tar_archive, tar_hash = create_tar_archive(source)
220
+ tar_fd, tar_archive = create_tar_archive(source)
221
221
 
222
222
  zst_fd, zst_archive, zst_hash = zstd_compress(tar_archive, destination)
223
223
 
224
224
  close_delete_file(tar_fd, tar_archive)
225
225
 
226
- return zst_fd, zst_archive, zst_hash, tar_hash
226
+ return zst_fd, zst_archive, zst_hash
227
227
 
228
228
 
229
229
  def create_zip_archive(source: Path, destination: Path | None = None) -> Path:
@@ -424,12 +424,13 @@ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):
424
424
  failed = 0
425
425
  for i, dataset_set in enumerate(dataset_sets, start=1):
426
426
  try:
427
- assert dataset_set.dataset.state == DatasetState.Complete.value, (
428
- "When processing a set, its dataset state should be Complete."
429
- )
430
-
431
- logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
432
- self.download_dataset_artifact(dataset_set.dataset)
427
+ if dataset_set.dataset.state == DatasetState.Complete.value:
428
+ logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
429
+ self.download_dataset_artifact(dataset_set.dataset)
430
+ else:
431
+ logger.warning(
432
+ f"The dataset {dataset_set.dataset} has its state set to `{dataset_set.dataset.state}`, its archive will not be downloaded"
433
+ )
433
434
 
434
435
  logger.info(f"Processing {dataset_set} ({i}/{count})")
435
436
  self.process_set(dataset_set)
@@ -444,7 +445,7 @@ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):
444
445
 
445
446
  logger.warning(message, exc_info=e if self.args.verbose else None)
446
447
 
447
- # Cleanup the latest downloaded dataset artifact
448
+ # Cleanup the latest downloaded dataset artifact (if needed)
448
449
  self.cleanup_downloaded_artifact()
449
450
 
450
451
  message = f"Ran on {count} {pluralize('set', count)}: {count - failed} completed, {failed} failed"
@@ -15,7 +15,7 @@ import gnupg
15
15
  import yaml
16
16
 
17
17
  from arkindex import options_from_env
18
- from arkindex.exceptions import ClientError, ErrorResponse
18
+ from arkindex.exceptions import ErrorResponse
19
19
  from arkindex_worker import logger
20
20
  from arkindex_worker.cache import (
21
21
  check_version,
@@ -261,6 +261,10 @@ class BaseWorker:
261
261
 
262
262
  logger.info(f"Loaded {worker_run['summary']} from API")
263
263
 
264
+ # The `RetrieveSecret` endpoint is only available in Arkindex EE.
265
+ # In CE, the values of `secret` fields should be used directly without calling `RetrieveSecret`.
266
+ can_retrieve_secret = "RetrieveSecret" in self.api_client.document.links
267
+
264
268
  def _process_config_item(item: dict) -> tuple[str, Any]:
265
269
  if not item["secret"]:
266
270
  return (item["key"], item["value"])
@@ -270,16 +274,12 @@ class BaseWorker:
270
274
  logger.info(f"Optional secret `{item['key']}` is not set")
271
275
  return (item["key"], None)
272
276
 
273
- # Load secret, only available in Arkindex EE
274
- try:
275
- secret = self.load_secret(Path(item["value"]))
276
- except ClientError as e:
277
- logger.error(
278
- f"Failed to retrieve the secret {item['value']}, probably an Arkindex Community Edition: {e}"
279
- )
280
- return (item["key"], None)
277
+ value = item["value"]
278
+ # Load secret when `RetrieveSecret` is available
279
+ if can_retrieve_secret:
280
+ value = self.load_secret(Path(item["value"]))
281
281
 
282
- return (item["key"], secret)
282
+ return (item["key"], value)
283
283
 
284
284
  # Load model version configuration when available
285
285
  # Workers will use model version ID and details to download the model
@@ -0,0 +1,100 @@
1
+ """
2
+ BaseWorker methods for tasks.
3
+ """
4
+
5
+ import uuid
6
+ from collections.abc import Iterator
7
+ from http.client import REQUEST_TIMEOUT
8
+ from pathlib import Path
9
+
10
+ import magic
11
+ import requests
12
+
13
+ from arkindex.compat import DownloadedFile
14
+ from arkindex_worker import logger
15
+ from arkindex_worker.models import Artifact
16
+ from teklia_toolbox.requests import should_verify_cert
17
+
18
+
19
+ class TaskMixin:
20
+ def list_artifacts(self, task_id: uuid.UUID) -> Iterator[Artifact]:
21
+ """
22
+ List artifacts associated to a task.
23
+
24
+ :param task_id: Task ID to find artifacts from.
25
+ :returns: An iterator of ``Artifact`` objects built from the ``ListArtifacts`` API endpoint.
26
+ """
27
+ assert task_id and isinstance(task_id, uuid.UUID), (
28
+ "task_id shouldn't be null and should be an UUID"
29
+ )
30
+
31
+ results = self.api_client.request("ListArtifacts", id=task_id)
32
+
33
+ return map(Artifact, results)
34
+
35
+ def download_artifact(
36
+ self, task_id: uuid.UUID, artifact: Artifact
37
+ ) -> DownloadedFile:
38
+ """
39
+ Download an artifact content.
40
+
41
+ :param task_id: Task ID the Artifact is from.
42
+ :param artifact: Artifact to download content from.
43
+ :returns: A temporary file containing the ``Artifact`` downloaded from the ``DownloadArtifact`` API endpoint.
44
+ """
45
+ assert task_id and isinstance(task_id, uuid.UUID), (
46
+ "task_id shouldn't be null and should be an UUID"
47
+ )
48
+ assert artifact and isinstance(artifact, Artifact), (
49
+ "artifact shouldn't be null and should be an Artifact"
50
+ )
51
+
52
+ return self.api_client.request(
53
+ "DownloadArtifact", id=task_id, path=artifact.path
54
+ )
55
+
56
+ def upload_artifact(self, path: Path) -> None:
57
+ """
58
+ Upload a single file as an Artifact of the current task.
59
+
60
+ :param path: Path of the single file to upload as an Artifact.
61
+ """
62
+ assert path and isinstance(path, Path) and path.exists(), (
63
+ "path shouldn't be null, should be a Path and should exist"
64
+ )
65
+
66
+ if self.is_read_only:
67
+ logger.warning("Cannot upload artifact as this worker is in read-only mode")
68
+ return
69
+
70
+ # Get path relative to task's data directory
71
+ relpath = str(path.relative_to(self.work_dir))
72
+
73
+ # Get file size
74
+ size = path.stat().st_size
75
+
76
+ # Detect content type
77
+ try:
78
+ content_type = magic.from_file(path, mime=True)
79
+ except Exception as e:
80
+ logger.warning(f"Failed to get a mime type for {path}: {e}")
81
+ content_type = "application/octet-stream"
82
+
83
+ # Create artifact on API to get an S3 url
84
+ artifact = self.api_client.request(
85
+ "CreateArtifact",
86
+ id=self.task_id,
87
+ body={"path": relpath, "content_type": content_type, "size": size},
88
+ )
89
+
90
+ # Upload the file content to S3
91
+ s3_put_url = artifact["s3_put_url"]
92
+ with path.open("rb") as content:
93
+ resp = requests.put(
94
+ s3_put_url,
95
+ data=content,
96
+ headers={"Content-Type": content_type},
97
+ timeout=REQUEST_TIMEOUT,
98
+ verify=should_verify_cert(s3_put_url),
99
+ )
100
+ resp.raise_for_status()
@@ -3,16 +3,15 @@ BaseWorker methods for training.
3
3
  """
4
4
 
5
5
  import functools
6
+ from collections.abc import Generator
6
7
  from contextlib import contextmanager
7
8
  from pathlib import Path
8
9
  from typing import NewType
9
10
  from uuid import UUID
10
11
 
11
- import requests
12
-
13
- from arkindex.exceptions import ErrorResponse
14
12
  from arkindex_worker import logger
15
13
  from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
14
+ from teklia_toolbox.uploads import MultipartUpload
16
15
 
17
16
  DirPath = NewType("DirPath", Path)
18
17
  """Path to a directory"""
@@ -25,23 +24,21 @@ FileSize = NewType("FileSize", int)
25
24
 
26
25
 
27
26
  @contextmanager
28
- def create_archive(path: DirPath) -> tuple[Path, Hash, FileSize, Hash]:
27
+ def create_archive(path: DirPath) -> Generator[tuple[Path, FileSize, Hash]]:
29
28
  """
30
29
  Create a tar archive from the files at the given location then compress it to a zst archive.
31
30
 
32
- Yield its location, its hash, its size and its content's hash.
31
+ Yield its location, its size and its hash.
33
32
 
34
33
  :param path: Create a compressed tar archive from the files
35
- :returns: The location of the created archive, its hash, its size and its content's hash
34
+ :returns: The location of the created archive, its size and its hash
36
35
  """
37
36
  assert path.is_dir(), "create_archive needs a directory"
38
37
 
39
- zst_descriptor, zst_archive, archive_hash, content_hash = create_tar_zst_archive(
40
- path
41
- )
38
+ zst_descriptor, zst_archive, archive_hash = create_tar_zst_archive(path)
42
39
 
43
40
  # Get content hash, archive size and hash
44
- yield zst_archive, content_hash, zst_archive.stat().st_size, archive_hash
41
+ yield zst_archive, zst_archive.stat().st_size, archive_hash
45
42
 
46
43
  # Remove the zst archive
47
44
  close_delete_file(zst_descriptor, zst_archive)
@@ -112,62 +109,48 @@ class TrainingMixin:
112
109
  """
113
110
 
114
111
  configuration = configuration or {}
115
- if not self.model_version:
116
- self.create_model_version(
117
- model_id=model_id,
118
- tag=tag,
119
- description=description,
120
- configuration=configuration,
121
- parent=parent,
122
- )
123
-
124
- elif tag or description or configuration or parent:
125
- assert self.model_version.get("model_id") == model_id, (
126
- "Given `model_id` does not match the current model version"
127
- )
128
- # If any attribute field has been defined, PATCH the current model version
129
- self.update_model_version(
130
- tag=tag,
131
- description=description,
132
- configuration=configuration,
133
- parent=parent,
134
- )
135
112
 
136
113
  # Create the zst archive, get its hash and size
137
- # Validate the model version
138
114
  with create_archive(path=model_path) as (
139
115
  path_to_archive,
140
- hash,
141
116
  size,
142
- archive_hash,
117
+ hash,
143
118
  ):
144
- # Create a new model version with hash and size
145
- self.upload_to_s3(archive_path=path_to_archive)
146
-
147
- current_version_id = self.model_version["id"]
148
- # Mark the model as valid
149
- self.validate_model_version(
150
- size=size,
151
- hash=hash,
152
- archive_hash=archive_hash,
153
- )
154
- if self.model_version["id"] != current_version_id and (
155
- tag or description or configuration or parent
156
- ):
157
- logger.warning(
158
- "Updating the existing available model version with the given attributes."
119
+ # Update an existing model version with hash, size and any other defined attribute
120
+ if self.model_version:
121
+ assert self.model_version.get("model_id") == model_id, (
122
+ "Given `model_id` does not match the current model version"
159
123
  )
160
124
  self.update_model_version(
125
+ size=size,
126
+ archive_hash=hash,
127
+ tag=tag,
128
+ description=description,
129
+ configuration=configuration,
130
+ parent=parent,
131
+ )
132
+
133
+ # Create a new model version with hash and size
134
+ else:
135
+ self.create_model_version(
136
+ model_id=model_id,
137
+ size=size,
138
+ archive_hash=hash,
161
139
  tag=tag,
162
140
  description=description,
163
141
  configuration=configuration,
164
142
  parent=parent,
165
143
  )
166
144
 
145
+ # Upload the archive in multiple parts (supports huge files)
146
+ self.upload_to_s3(path_to_archive)
147
+
167
148
  @skip_if_read_only
168
149
  def create_model_version(
169
150
  self,
170
151
  model_id: str,
152
+ size: FileSize,
153
+ archive_hash: Hash,
171
154
  tag: str | None = None,
172
155
  description: str | None = None,
173
156
  configuration: dict | None = None,
@@ -177,6 +160,8 @@ class TrainingMixin:
177
160
  Create a new version of the specified model with its base attributes.
178
161
  Once successfully created, the model version is accessible via `self.model_version`.
179
162
 
163
+ :param size: Size of uploaded archive
164
+ :param hash: MD5 hash of the uploaded archive
180
165
  :param tag: Tag of the model version
181
166
  :param description: Description of the model version
182
167
  :param configuration: Configuration of the model version
@@ -189,6 +174,8 @@ class TrainingMixin:
189
174
  "CreateModelVersion",
190
175
  id=model_id,
191
176
  body=build_clean_payload(
177
+ size=size,
178
+ archive_hash=archive_hash,
192
179
  tag=tag,
193
180
  description=description,
194
181
  configuration=configuration,
@@ -197,12 +184,14 @@ class TrainingMixin:
197
184
  )
198
185
 
199
186
  logger.info(
200
- f"Model version ({self.model_version['id']}) was successfully created"
187
+ f"Model version ({self.model_version['id']}) was successfully created."
201
188
  )
202
189
 
203
190
  @skip_if_read_only
204
191
  def update_model_version(
205
192
  self,
193
+ size: FileSize,
194
+ archive_hash: Hash,
206
195
  tag: str | None = None,
207
196
  description: str | None = None,
208
197
  configuration: dict | None = None,
@@ -211,6 +200,8 @@ class TrainingMixin:
211
200
  """
212
201
  Update the current model version with the given attributes.
213
202
 
203
+ :param size: Size of uploaded archive
204
+ :param hash: MD5 hash of the uploaded archive
214
205
  :param tag: Tag of the model version
215
206
  :param description: Description of the model version
216
207
  :param configuration: Configuration of the model version
@@ -221,6 +212,8 @@ class TrainingMixin:
221
212
  "UpdateModelVersion",
222
213
  id=self.model_version["id"],
223
214
  body=build_clean_payload(
215
+ size=size,
216
+ archive_hash=archive_hash,
224
217
  tag=tag,
225
218
  description=description,
226
219
  configuration=configuration,
@@ -228,93 +221,34 @@ class TrainingMixin:
228
221
  ),
229
222
  )
230
223
  logger.info(
231
- f"Model version ({self.model_version['id']}) was successfully updated"
224
+ f"Model version ({self.model_version['id']}) was successfully updated."
232
225
  )
233
226
 
234
227
  @skip_if_read_only
235
228
  def upload_to_s3(self, archive_path: Path) -> None:
236
229
  """
237
- Upload the archive of the model's files to an Amazon s3 compatible storage
230
+ Upload the archive of the model's files to an Amazon s3 compatible storage in multiple parts
238
231
  """
239
-
240
232
  assert self.model_version, (
241
233
  "You must create the model version before uploading an archive."
242
234
  )
243
235
  assert self.model_version["state"] != "Available", (
244
- "The model is already marked as available."
236
+ "The model version is already marked as available."
245
237
  )
246
238
 
247
- s3_put_url = self.model_version.get("s3_put_url")
248
- assert s3_put_url, (
249
- "S3 PUT URL is not set, please ensure you have the right to validate a model version."
250
- )
251
-
252
- logger.info("Uploading to s3...")
253
- # Upload the archive on s3
254
- with archive_path.open("rb") as archive:
255
- r = requests.put(
256
- url=s3_put_url,
257
- data=archive,
258
- headers={"Content-Type": "application/zstd"},
259
- )
260
- r.raise_for_status()
261
-
262
- @skip_if_read_only
263
- def validate_model_version(
264
- self,
265
- hash: str,
266
- size: int,
267
- archive_hash: str,
268
- ):
269
- """
270
- Sets the model version as `Available`, once its archive has been uploaded to S3.
271
-
272
- :param hash: MD5 hash of the files contained in the archive
273
- :param size: The size of the uploaded archive
274
- :param archive_hash: MD5 hash of the uploaded archive
275
- """
276
- assert self.model_version, (
277
- "You must create the model version and upload its archive before validating it."
239
+ multipart = MultipartUpload(
240
+ client=self.api_client,
241
+ file_path=archive_path,
242
+ object_type="model_version",
243
+ object_id=str(self.model_version["id"]),
278
244
  )
279
245
  try:
280
- self.model_version = self.api_client.request(
281
- "PartialUpdateModelVersion",
282
- id=self.model_version["id"],
283
- body={
284
- "state": "available",
285
- "size": size,
286
- "hash": hash,
287
- "archive_hash": archive_hash,
288
- },
246
+ multipart.upload()
247
+ multipart.complete()
248
+ except Exception:
249
+ multipart.abort()
250
+ raise
251
+ else:
252
+ logger.info(
253
+ f"Model version ({self.model_version['id']}) archive was successfully uploaded and is now available."
289
254
  )
290
- except ErrorResponse as e:
291
- model_version = e.content
292
- if not model_version or "id" not in model_version:
293
- raise e
294
-
295
- logger.warning(
296
- f"An available model version exists with hash {hash}, using it instead of the pending version."
297
- )
298
- pending_version_id = self.model_version["id"]
299
- logger.warning("Removing the pending model version.")
300
- try:
301
- self.api_client.request("DestroyModelVersion", id=pending_version_id)
302
- except ErrorResponse as e:
303
- msg = getattr(e, "content", str(e))
304
- logger.error(
305
- f"An error occurred removing the pending version {pending_version_id}: {msg}."
306
- )
307
-
308
- logger.info("Retrieving the existing model version.")
309
- existing_version_id = model_version["id"].pop()
310
- try:
311
- self.model_version = self.api_client.request(
312
- "RetrieveModelVersion", id=existing_version_id
313
- )
314
- except ErrorResponse as e:
315
- logger.error(
316
- f"An error occurred retrieving the existing version {existing_version_id}: {e.status_code} - {e.content}."
317
- )
318
- raise
319
-
320
- logger.info(f"Model version {self.model_version['id']} is now available.")
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "arkindex-base-worker"
7
- version = "0.5.2a1"
7
+ version = "0.5.2b1"
8
8
  description = "Base Worker to easily build Arkindex ML workflows"
9
9
  license-files = ["LICENSE"]
10
10
  dependencies = [
@@ -12,8 +12,9 @@ dependencies = [
12
12
  "peewee~=3.17",
13
13
  "Pillow==11.3.0",
14
14
  "python-gnupg==0.5.6",
15
+ "python-magic==0.4.27",
15
16
  "shapely==2.0.6",
16
- "teklia-toolbox==0.1.12",
17
+ "teklia-toolbox==0.1.13",
17
18
  "zstandard==0.25.0",
18
19
  ]
19
20
  authors = [