arkindex-base-worker 0.5.2a1__py3-none-any.whl → 0.5.2b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/METADATA +3 -2
- {arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/RECORD +15 -15
- {arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/WHEEL +1 -1
- arkindex_worker/image.py +2 -2
- arkindex_worker/utils.py +7 -7
- arkindex_worker/worker/__init__.py +8 -7
- arkindex_worker/worker/base.py +10 -10
- arkindex_worker/worker/task.py +53 -0
- arkindex_worker/worker/training.py +58 -124
- tests/test_dataset_worker.py +50 -63
- tests/test_elements_worker/test_task.py +112 -0
- tests/test_elements_worker/test_training.py +17 -136
- tests/test_modern_config.py +39 -0
- {arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/licenses/LICENSE +0 -0
- {arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/top_level.txt +0 -0
{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: arkindex-base-worker
-Version: 0.5.2a1
+Version: 0.5.2b1
 Summary: Base Worker to easily build Arkindex ML workflows
 Author-email: Teklia <contact@teklia.com>
 Maintainer-email: Teklia <contact@teklia.com>
@@ -23,8 +23,9 @@ Requires-Dist: humanize==4.15.0
 Requires-Dist: peewee~=3.17
 Requires-Dist: Pillow==11.3.0
 Requires-Dist: python-gnupg==0.5.6
+Requires-Dist: python-magic==0.4.27
 Requires-Dist: shapely==2.0.6
-Requires-Dist: teklia-toolbox==0.1.
+Requires-Dist: teklia-toolbox==0.1.13
 Requires-Dist: zstandard==0.25.0
 Provides-Extra: tests
 Requires-Dist: pytest-mock==3.15.1; extra == "tests"

{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
-arkindex_base_worker-0.5.
+arkindex_base_worker-0.5.2b1.dist-info/licenses/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
 arkindex_worker/__init__.py,sha256=Sdt5KXn8EgURb2MurYVrUWaHbH3iFA1XLRo0Lc5AJ44,250
 arkindex_worker/cache.py,sha256=XpEXMSnbhYCvrJquwA9XXqZo-ajMLpaCxKG5wH3Gp6Y,10959
-arkindex_worker/image.py,sha256=
+arkindex_worker/image.py,sha256=9KeZHWNIDkwNJZR0y-mbyD_pvKfrgdktMB32jZqSMYk,20927
 arkindex_worker/models.py,sha256=DgKvAB_2e1cPcuUavZkyTkV10jBK8y083oVklB9idSk,10855
-arkindex_worker/utils.py,sha256=
-arkindex_worker/worker/__init__.py,sha256=
-arkindex_worker/worker/base.py,sha256
+arkindex_worker/utils.py,sha256=kqOTVLBh-0krD2ukTkroiMZ2820wNYxeR8Cf1AyoqNA,10859
+arkindex_worker/worker/__init__.py,sha256=tM_ynAARmtuJw5YWb_jI0AD5KNXbWN1K-VDiixIp7O4,18009
+arkindex_worker/worker/base.py,sha256=2nQdPGh2qQOUNmvV2Mc1KZeqE8d4Fhy9tCo6Q2nNdNQ,22214
 arkindex_worker/worker/classification.py,sha256=qvykymkgd4nGywHCxL8obo4egstoGsmWNS4Ztc1qNWQ,11024
 arkindex_worker/worker/corpus.py,sha256=MeIMod7jkWyX0frtD0a37rhumnMV3p9ZOC1xwAoXrAA,2291
 arkindex_worker/worker/dataset.py,sha256=tVaPx43vaH-KTtx4w5V06e26ha8XPfiJTRzBXlu928Y,5273
@@ -14,8 +14,8 @@ arkindex_worker/worker/entity.py,sha256=Aj6EOfzHEm7qQV-Egm0YKLZgCrLS_3ggOKTY81M2
 arkindex_worker/worker/image.py,sha256=L6Ikuf0Z0RxJk7JarY5PggJGrYSHLaPK0vn0dy0CIaQ,623
 arkindex_worker/worker/metadata.py,sha256=keZdOdUthSH2hAw9iet5pN7rzWihTUYjZHRGTEjaltw,6843
 arkindex_worker/worker/process.py,sha256=9TEHpMcBax1wc6PrWMMrdXe2uNfqyVj7n_dAYZRBGnY,1854
-arkindex_worker/worker/task.py,sha256=
-arkindex_worker/worker/training.py,sha256=
+arkindex_worker/worker/task.py,sha256=HASQU5LYVtgvCnRCLFC6iH7h7v6q_usZNZ-r_Wkv9A8,3306
+arkindex_worker/worker/training.py,sha256=b1YGeUiOWob_DacS4fphGkErJGsx84YVgr5NnsukoEQ,8420
 arkindex_worker/worker/transcription.py,sha256=sw718R119tsLNY8inPMVeIilvFJo94fMbMtYgH0zTM8,21250
 examples/standalone/python/worker.py,sha256=Zr4s4pHvgexEjlkixLFYZp1UuwMLeoTxjyNG5_S2iYE,6672
 examples/tooled/python/worker.py,sha256=kIYlHLsO5UpwX4XtERRq4tf2qTsvqKK30C-w8t0yyhA,1821
@@ -24,11 +24,11 @@ tests/__init__.py,sha256=DG--S6IpGl399rzSAjDdHL76CkOIeZIjajCcyUSDhOQ,241
 tests/conftest.py,sha256=Tp7YFK17NATwF2yAcBwi0QFNyKSXtLS0VhZ-zZngsQI,24343
 tests/test_base_worker.py,sha256=lwS4X3atS2ktEKd1XdogmN3mbzq-tO206-k_0EDITlw,29302
 tests/test_cache.py,sha256=_wztzh94EwVrb8UvpFqgl2aa2_FLaCcJKaqunCYR5Dw,10435
-tests/test_dataset_worker.py,sha256=
+tests/test_dataset_worker.py,sha256=LmL3ERF1__PUPkTLiAFC0IYglZTv5WQYA42Vm-uhe2w,22023
 tests/test_element.py,sha256=hlj5VSF4plwC7uz9R4LGOOXZJQcHZiYCIDZT5V6EIB8,14334
 tests/test_image.py,sha256=yAM5mMfpQcIurT1KLHmu0AhSX2Qm3YvCu7afyZ3XUdU,28314
 tests/test_merge.py,sha256=REpZ13jkq_qm_4L5URQgFy5lxvPZtXxQEiWfYLMdmF0,7956
-tests/test_modern_config.py,sha256=
+tests/test_modern_config.py,sha256=ZbMHT5b5RG3ZPX4MoqI8zitRg2y5fV1C6ynfyRkq828,4008
 tests/test_utils.py,sha256=tgzNqyJMpddpeFWEjgsew_yDzmqnCA9HDaA5IpevAcM,5353
 tests/test_elements_worker/__init__.py,sha256=2t3NciCIOun_N-Wv63FWGsTm5W9N3mbwAWVuFORlMg8,308
 tests/test_elements_worker/test_classification.py,sha256=nya7veSPR_O9G41Enodp2-o6AifMBcaSTWJP2vXSSJ4,30133
@@ -44,8 +44,8 @@ tests/test_elements_worker/test_entity.py,sha256=SNAZEsVVLnqlliOmjkgv_cZhw0bAuJU
 tests/test_elements_worker/test_image.py,sha256=BljMNKgec_9a5bzNzFpYZIvSbuvwsWDfdqLHVJaTa7M,2079
 tests/test_elements_worker/test_metadata.py,sha256=qtTDtlp3VnBkfck7PAguK2dEgTLlr1i1EVnmNTeNf3A,20515
 tests/test_elements_worker/test_process.py,sha256=y4RoVhPfyHzR795fw7-_FXElBcKo3fy4Ew_HI-kxJic,3088
-tests/test_elements_worker/test_task.py,sha256=
-tests/test_elements_worker/test_training.py,sha256=
+tests/test_elements_worker/test_task.py,sha256=oHwP1fbJftXFA2U4qA3Gb4vX-iJoV-sBvPHnfBBpRrc,8906
+tests/test_elements_worker/test_training.py,sha256=VY3YKYAm8IijAD6gWY0g06I27gXvMxa7SnAsCWm7G-8,4896
 tests/test_elements_worker/test_transcription_create.py,sha256=yznO9B_BVsOR0Z_VY5ZL8gJp0ZPCz_4sPUs5dXtixAg,29281
 tests/test_elements_worker/test_transcription_create_with_elements.py,sha256=tmcyglgssEqMnt1Mdy_u6X1m2wgLWTo_HdWst3GrK2k,33056
 tests/test_elements_worker/test_transcription_list.py,sha256=ikz7HYPCoQWTdTRCd382SB-y-T2BbigPLlIcx5Eow-I,15324
@@ -55,7 +55,7 @@ worker-demo/tests/conftest.py,sha256=XzNMNeg6pmABUAH8jN6eZTlZSFGLYjS3-DTXjiRN6Yc
 worker-demo/tests/test_worker.py,sha256=3DLd4NRK4bfyatG5P_PK4k9P9tJHx9XQq5_ryFEEFVg,304
 worker-demo/worker_demo/__init__.py,sha256=2BPomV8ZMNf3YXJgloatKeHQCE6QOkwmsHGkO6MkQuM,125
 worker-demo/worker_demo/worker.py,sha256=Rt-DjWa5iBP08k58NDZMfeyPuFbtNcbX6nc5jFX7GNo,440
-arkindex_base_worker-0.5.
-arkindex_base_worker-0.5.
-arkindex_base_worker-0.5.
-arkindex_base_worker-0.5.
+arkindex_base_worker-0.5.2b1.dist-info/METADATA,sha256=3dxm9h6sl-bra7LqCF5_plfIrX9qW8zgzF4cnAiUcoQ,1885
+arkindex_base_worker-0.5.2b1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+arkindex_base_worker-0.5.2b1.dist-info/top_level.txt,sha256=-vNjP2VfROx0j83mdi9aIqRZ88eoJjxeWz-R_gPgyXU,49
+arkindex_base_worker-0.5.2b1.dist-info/RECORD,,

arkindex_worker/image.py
CHANGED
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from arkindex_worker.models import Element

 # See http://docs.python-requests.org/en/master/user/advanced/#timeouts
-…
+REQUEST_TIMEOUT = (30, 60)

 BoundingBox = namedtuple("BoundingBox", ["x", "y", "width", "height"])

@@ -346,7 +346,7 @@ def _retried_request(url, *args, method=requests.get, **kwargs):
         url,
         *args,
         headers={"User-Agent": IIIF_USER_AGENT},
-        timeout=
+        timeout=REQUEST_TIMEOUT,
         verify=should_verify_cert(url),
         **kwargs,
     )

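Note (not part of the diff): the new `REQUEST_TIMEOUT` constant is a `(connect, read)` pair, which `requests` accepts directly as its `timeout` argument. A minimal sketch of that usage, with a hypothetical URL:

import requests

REQUEST_TIMEOUT = (30, 60)  # wait up to 30 s to connect, 60 s for the server to answer

resp = requests.get("https://iiif.example.com/iiif/2/page/info.json", timeout=REQUEST_TIMEOUT)
resp.raise_for_status()
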
arkindex_worker/utils.py
CHANGED
@@ -163,12 +163,12 @@ def zstd_compress(

 def create_tar_archive(
     path: Path, destination: Path | None = None
-) -> tuple[int | None, Path
+) -> tuple[int | None, Path]:
     """Create a tar archive using the content at specified location.

     :param path: Path to the file to archive
     :param destination: Optional path for the created TAR archive. A tempfile will be created if this is omitted.
-    :return: The file descriptor (if one was created) and path to the TAR archive
+    :return: The file descriptor (if one was created) and path to the TAR archive.
     """
     # Parse destination and create a tmpfile if none was specified
     file_d, destination = (
@@ -204,26 +204,26 @@ def create_tar_archive(
         with file_path.open("rb") as file_data:
             for chunk in iter(lambda: file_data.read(CHUNK_SIZE), b""):
                 content_hasher.update(chunk)
-    return file_d, destination
+    return file_d, destination


 def create_tar_zst_archive(
     source: Path, destination: Path | None = None
-) -> tuple[int | None, Path, str
+) -> tuple[int | None, Path, str]:
     """Helper to create a TAR+ZST archive from a source folder.

     :param source: Path to the folder whose content should be archived.
     :param destination: Path to the created archive, defaults to None. If unspecified, a temporary file will be created.
-    :return: The file descriptor of the created tempfile (if one was created), path to the archive
+    :return: The file descriptor of the created tempfile (if one was created), path to the archive and its hash.
     """
     # Create tar archive
-    tar_fd, tar_archive
+    tar_fd, tar_archive = create_tar_archive(source)

     zst_fd, zst_archive, zst_hash = zstd_compress(tar_archive, destination)

     close_delete_file(tar_fd, tar_archive)

-    return zst_fd, zst_archive, zst_hash
+    return zst_fd, zst_archive, zst_hash


 def create_zip_archive(source: Path, destination: Path | None = None) -> Path:

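Note (not part of the diff): a small sketch of how these helpers fit together after this change, with a hypothetical source folder. `create_tar_zst_archive` returns the tempfile descriptor, the archive path and the archive hash, and the caller cleans the tempfile up with `close_delete_file`:

from pathlib import Path

from arkindex_worker.utils import close_delete_file, create_tar_zst_archive

fd, archive_path, archive_hash = create_tar_zst_archive(Path("exported_model"))
try:
    print(f"Created {archive_path} ({archive_path.stat().st_size} bytes, hash {archive_hash})")
finally:
    # Remove the temporary archive once it has been used
    close_delete_file(fd, archive_path)
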
arkindex_worker/worker/__init__.py
CHANGED
@@ -424,12 +424,13 @@ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):
         failed = 0
         for i, dataset_set in enumerate(dataset_sets, start=1):
             try:
-… (6 lines)
+                if dataset_set.dataset.state == DatasetState.Complete.value:
+                    logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
+                    self.download_dataset_artifact(dataset_set.dataset)
+                else:
+                    logger.warning(
+                        f"The dataset {dataset_set.dataset} has its state set to `{dataset_set.dataset.state}`, its archive will not be downloaded"
+                    )

                 logger.info(f"Processing {dataset_set} ({i}/{count})")
                 self.process_set(dataset_set)
@@ -444,7 +445,7 @@ class DatasetWorker(DatasetMixin, BaseWorker, TaskMixin):

             logger.warning(message, exc_info=e if self.args.verbose else None)

-            # Cleanup the latest downloaded dataset artifact
+            # Cleanup the latest downloaded dataset artifact (if needed)
             self.cleanup_downloaded_artifact()

         message = f"Ran on {count} {pluralize('set', count)}: {count - failed} completed, {failed} failed"

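Note (not part of the diff): a sketch of the state gate the run loop now applies; only Complete datasets have an archive worth downloading, every other state is processed without one. The import path for `DatasetState` is an assumption:

from arkindex_worker.worker.dataset import DatasetState  # assumed import path

def maybe_download(worker, dataset_set) -> None:
    # Mirrors the new run() behaviour: only Complete datasets have a published archive
    if dataset_set.dataset.state == DatasetState.Complete.value:
        worker.download_dataset_artifact(dataset_set.dataset)
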
arkindex_worker/worker/base.py
CHANGED
@@ -15,7 +15,7 @@ import gnupg
 import yaml

 from arkindex import options_from_env
-from arkindex.exceptions import
+from arkindex.exceptions import ErrorResponse
 from arkindex_worker import logger
 from arkindex_worker.cache import (
     check_version,
@@ -261,6 +261,10 @@ class BaseWorker:

         logger.info(f"Loaded {worker_run['summary']} from API")

+        # The `RetrieveSecret` endpoint is only available in Arkindex EE.
+        # In CE, the values of `secret` fields should be used directly without calling `RetrieveSecret`.
+        can_retrieve_secret = "RetrieveSecret" in self.api_client.document.links
+
         def _process_config_item(item: dict) -> tuple[str, Any]:
             if not item["secret"]:
                 return (item["key"], item["value"])
@@ -270,16 +274,12 @@ class BaseWorker:
                 logger.info(f"Optional secret `{item['key']}` is not set")
                 return (item["key"], None)

-… (4 lines)
-                logger.error(
-                    f"Failed to retrieve the secret {item['value']}, probably an Arkindex Community Edition: {e}"
-                )
-                return (item["key"], None)
+            value = item["value"]
+            # Load secret when `RetrieveSecret` is available
+            if can_retrieve_secret:
+                value = self.load_secret(Path(item["value"]))

-            return (item["key"],
+            return (item["key"], value)

         # Load model version configuration when available
         # Workers will use model version ID and details to download the model

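Note (not part of the diff): the effect of the change above, paraphrased outside the class for clarity. This is a sketch, not the package's actual code; the `load_secret` helper and the OpenAPI `links` lookup are taken from the diff:

from pathlib import Path

def resolve_config_item(item: dict, api_client, load_secret):
    # Plain configuration values pass through untouched
    if not item["secret"]:
        return item["key"], item["value"]
    # Arkindex EE exposes RetrieveSecret: the value names a secret to fetch
    if "RetrieveSecret" in api_client.document.links:
        return item["key"], load_secret(Path(item["value"]))
    # Arkindex CE: the configured value is used directly
    return item["key"], item["value"]
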
arkindex_worker/worker/task.py
CHANGED
@@ -4,9 +4,16 @@ BaseWorker methods for tasks.

 import uuid
 from collections.abc import Iterator
+from http.client import REQUEST_TIMEOUT
+from pathlib import Path
+
+import magic
+import requests

 from arkindex.compat import DownloadedFile
+from arkindex_worker import logger
 from arkindex_worker.models import Artifact
+from teklia_toolbox.requests import should_verify_cert


 class TaskMixin:
@@ -45,3 +52,49 @@ class TaskMixin:
         return self.api_client.request(
             "DownloadArtifact", id=task_id, path=artifact.path
         )
+
+    def upload_artifact(self, path: Path) -> None:
+        """
+        Upload a single file as an Artifact of the current task.
+
+        :param path: Path of the single file to upload as an Artifact.
+        """
+        assert path and isinstance(path, Path) and path.exists(), (
+            "path shouldn't be null, should be a Path and should exist"
+        )
+
+        if self.is_read_only:
+            logger.warning("Cannot upload artifact as this worker is in read-only mode")
+            return
+
+        # Get path relative to task's data directory
+        relpath = str(path.relative_to(self.work_dir))
+
+        # Get file size
+        size = path.stat().st_size
+
+        # Detect content type
+        try:
+            content_type = magic.from_file(path, mime=True)
+        except Exception as e:
+            logger.warning(f"Failed to get a mime type for {path}: {e}")
+            content_type = "application/octet-stream"
+
+        # Create artifact on API to get an S3 url
+        artifact = self.api_client.request(
+            "CreateArtifact",
+            id=self.task_id,
+            body={"path": relpath, "content_type": content_type, "size": size},
+        )
+
+        # Upload the file content to S3
+        s3_put_url = artifact["s3_put_url"]
+        with path.open("rb") as content:
+            resp = requests.put(
+                s3_put_url,
+                data=content,
+                headers={"Content-Type": content_type},
+                timeout=REQUEST_TIMEOUT,
+                verify=should_verify_cert(s3_put_url),
+            )
+            resp.raise_for_status()

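Note (not part of the diff): a usage sketch for the new `upload_artifact` helper. The worker subclass and file name are hypothetical; the path must live under the task's `work_dir`, since the artifact path is computed relative to it:

from pathlib import Path

from arkindex_worker.worker import DatasetWorker  # DatasetWorker mixes in TaskMixin, as shown above

class DemoWorker(DatasetWorker):
    def process_set(self, dataset_set):
        output = Path(self.work_dir) / f"{dataset_set.name}_stats.json"
        output.write_text('{"processed": true}')
        # Creates the artifact through the API, then PUTs the file content to S3
        self.upload_artifact(output)
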
arkindex_worker/worker/training.py
CHANGED
@@ -3,16 +3,15 @@ BaseWorker methods for training.
 """

 import functools
+from collections.abc import Generator
 from contextlib import contextmanager
 from pathlib import Path
 from typing import NewType
 from uuid import UUID

-import requests
-
-from arkindex.exceptions import ErrorResponse
 from arkindex_worker import logger
 from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
+from teklia_toolbox.uploads import MultipartUpload

 DirPath = NewType("DirPath", Path)
 """Path to a directory"""
@@ -25,23 +24,21 @@ FileSize = NewType("FileSize", int)


 @contextmanager
-def create_archive(path: DirPath) -> tuple[Path,
+def create_archive(path: DirPath) -> Generator[tuple[Path, FileSize, Hash]]:
     """
     Create a tar archive from the files at the given location then compress it to a zst archive.

-    Yield its location, its
+    Yield its location, its size and its hash.

     :param path: Create a compressed tar archive from the files
-    :returns: The location of the created archive, its
+    :returns: The location of the created archive, its size and its hash
     """
     assert path.is_dir(), "create_archive needs a directory"

-    zst_descriptor, zst_archive, archive_hash
-        path
-    )
+    zst_descriptor, zst_archive, archive_hash = create_tar_zst_archive(path)

     # Get content hash, archive size and hash
-    yield zst_archive,
+    yield zst_archive, zst_archive.stat().st_size, archive_hash

     # Remove the zst archive
     close_delete_file(zst_descriptor, zst_archive)
@@ -112,62 +109,48 @@ class TrainingMixin:
         """

         configuration = configuration or {}
-        if not self.model_version:
-            self.create_model_version(
-                model_id=model_id,
-                tag=tag,
-                description=description,
-                configuration=configuration,
-                parent=parent,
-            )
-
-        elif tag or description or configuration or parent:
-            assert self.model_version.get("model_id") == model_id, (
-                "Given `model_id` does not match the current model version"
-            )
-            # If any attribute field has been defined, PATCH the current model version
-            self.update_model_version(
-                tag=tag,
-                description=description,
-                configuration=configuration,
-                parent=parent,
-            )

         # Create the zst archive, get its hash and size
-        # Validate the model version
         with create_archive(path=model_path) as (
             path_to_archive,
-            hash,
             size,
-…
+            hash,
         ):
-… (4 lines)
-            # Mark the model as valid
-            self.validate_model_version(
-                size=size,
-                hash=hash,
-                archive_hash=archive_hash,
-            )
-            if self.model_version["id"] != current_version_id and (
-                tag or description or configuration or parent
-            ):
-                logger.warning(
-                    "Updating the existing available model version with the given attributes."
+            # Update an existing model version with hash, size and any other defined attribute
+            if self.model_version:
+                assert self.model_version.get("model_id") == model_id, (
+                    "Given `model_id` does not match the current model version"
                 )
                 self.update_model_version(
+                    size=size,
+                    archive_hash=hash,
+                    tag=tag,
+                    description=description,
+                    configuration=configuration,
+                    parent=parent,
+                )
+
+            # Create a new model version with hash and size
+            else:
+                self.create_model_version(
+                    model_id=model_id,
+                    size=size,
+                    archive_hash=hash,
                     tag=tag,
                     description=description,
                     configuration=configuration,
                     parent=parent,
                 )

+            # Upload the archive in multiple parts (supports huge files)
+            self.upload_to_s3(path_to_archive)
+
     @skip_if_read_only
     def create_model_version(
         self,
         model_id: str,
+        size: FileSize,
+        archive_hash: Hash,
         tag: str | None = None,
         description: str | None = None,
         configuration: dict | None = None,
@@ -177,6 +160,8 @@ class TrainingMixin:
         Create a new version of the specified model with its base attributes.
         Once successfully created, the model version is accessible via `self.model_version`.

+        :param size: Size of uploaded archive
+        :param hash: MD5 hash of the uploaded archive
         :param tag: Tag of the model version
         :param description: Description of the model version
         :param configuration: Configuration of the model version
@@ -189,6 +174,8 @@ class TrainingMixin:
             "CreateModelVersion",
             id=model_id,
             body=build_clean_payload(
+                size=size,
+                archive_hash=archive_hash,
                 tag=tag,
                 description=description,
                 configuration=configuration,
@@ -197,12 +184,14 @@ class TrainingMixin:
         )

         logger.info(
-            f"Model version ({self.model_version['id']}) was successfully created"
+            f"Model version ({self.model_version['id']}) was successfully created."
         )

     @skip_if_read_only
     def update_model_version(
         self,
+        size: FileSize,
+        archive_hash: Hash,
         tag: str | None = None,
         description: str | None = None,
         configuration: dict | None = None,
@@ -211,6 +200,8 @@ class TrainingMixin:
         """
         Update the current model version with the given attributes.

+        :param size: Size of uploaded archive
+        :param hash: MD5 hash of the uploaded archive
         :param tag: Tag of the model version
         :param description: Description of the model version
         :param configuration: Configuration of the model version
@@ -221,6 +212,8 @@ class TrainingMixin:
             "UpdateModelVersion",
             id=self.model_version["id"],
             body=build_clean_payload(
+                size=size,
+                archive_hash=archive_hash,
                 tag=tag,
                 description=description,
                 configuration=configuration,
@@ -228,93 +221,34 @@ class TrainingMixin:
             ),
         )
         logger.info(
-            f"Model version ({self.model_version['id']}) was successfully updated"
+            f"Model version ({self.model_version['id']}) was successfully updated."
         )

     @skip_if_read_only
     def upload_to_s3(self, archive_path: Path) -> None:
         """
-        Upload the archive of the model's files to an Amazon s3 compatible storage
+        Upload the archive of the model's files to an Amazon s3 compatible storage in multiple parts
         """
-
         assert self.model_version, (
             "You must create the model version before uploading an archive."
         )
         assert self.model_version["state"] != "Available", (
-            "The model is already marked as available."
+            "The model version is already marked as available."
         )

-… (5 lines)
-        logger.info("Uploading to s3...")
-        # Upload the archive on s3
-        with archive_path.open("rb") as archive:
-            r = requests.put(
-                url=s3_put_url,
-                data=archive,
-                headers={"Content-Type": "application/zstd"},
-            )
-            r.raise_for_status()
-
-    @skip_if_read_only
-    def validate_model_version(
-        self,
-        hash: str,
-        size: int,
-        archive_hash: str,
-    ):
-        """
-        Sets the model version as `Available`, once its archive has been uploaded to S3.
-
-        :param hash: MD5 hash of the files contained in the archive
-        :param size: The size of the uploaded archive
-        :param archive_hash: MD5 hash of the uploaded archive
-        """
-        assert self.model_version, (
-            "You must create the model version and upload its archive before validating it."
+        multipart = MultipartUpload(
+            client=self.api_client,
+            file_path=archive_path,
+            object_type="model_version",
+            object_id=str(self.model_version["id"]),
         )
         try:
-… (9 lines)
+            multipart.upload()
+            multipart.complete()
+        except Exception:
+            multipart.abort()
+            raise
+        else:
+            logger.info(
+                f"Model version ({self.model_version['id']}) archive was successfully uploaded and is now available."
             )
-        except ErrorResponse as e:
-            model_version = e.content
-            if not model_version or "id" not in model_version:
-                raise e
-
-            logger.warning(
-                f"An available model version exists with hash {hash}, using it instead of the pending version."
-            )
-            pending_version_id = self.model_version["id"]
-            logger.warning("Removing the pending model version.")
-            try:
-                self.api_client.request("DestroyModelVersion", id=pending_version_id)
-            except ErrorResponse as e:
-                msg = getattr(e, "content", str(e))
-                logger.error(
-                    f"An error occurred removing the pending version {pending_version_id}: {msg}."
-                )
-
-            logger.info("Retrieving the existing model version.")
-            existing_version_id = model_version["id"].pop()
-            try:
-                self.model_version = self.api_client.request(
-                    "RetrieveModelVersion", id=existing_version_id
-                )
-            except ErrorResponse as e:
-                logger.error(
-                    f"An error occurred retrieving the existing version {existing_version_id}: {e.status_code} - {e.content}."
-                )
-                raise
-
-            logger.info(f"Model version {self.model_version['id']} is now available.")

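Note (not part of the diff): a usage sketch of the reworked training flow. The worker class name and model UUID are hypothetical, and the name of the enclosing method (`publish_model_version`) is an assumption based on this package's public API:

from pathlib import Path

from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.training import TrainingMixin

class ModelPublisher(BaseWorker, TrainingMixin):
    pass

worker = ModelPublisher()
worker.configure()
# Archives the directory, creates or updates the model version with the archive's
# size and hash, then uploads the archive through MultipartUpload
worker.publish_model_version(
    model_path=Path("trained_model"),
    model_id="11111111-1111-1111-1111-111111111111",
    tag="0.1.0",
    description="Demo model version",
)
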
tests/test_dataset_worker.py
CHANGED
@@ -435,34 +435,6 @@ def test_run_no_sets(mocker, caplog, mock_dataset_worker):
     ]


-def test_run_initial_dataset_state_error(
-    mocker, responses, caplog, mock_dataset_worker, default_dataset
-):
-    default_dataset.state = DatasetState.Building.value
-    mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_sets",
-        return_value=[Set(name="train", dataset=default_dataset)],
-    )
-
-    with pytest.raises(SystemExit):
-        mock_dataset_worker.run()
-
-    assert len(responses.calls) == len(BASE_API_CALLS) * 2
-    assert [
-        (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2
-
-    assert [(level, message) for _, level, message in caplog.record_tuples] == [
-        (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Modern configuration is not available"),
-        (
-            logging.WARNING,
-            "Failed running worker on Set (train) from Dataset (dataset_id): AssertionError('When processing a set, its dataset state should be Complete.')",
-        ),
-        (logging.ERROR, "Ran on 1 set: 0 completed, 1 failed"),
-    ]
-
-
 def test_run_download_dataset_artifact_api_error(
     mocker,
     tmp_path,
@@ -570,16 +542,18 @@ def test_run_no_downloaded_dataset_artifact_error(
     ]


+@pytest.mark.parametrize("dataset_state", DatasetState)
 def test_run(
     mocker,
     tmp_path,
     responses,
     caplog,
+    dataset_state,
     mock_dataset_worker,
     default_dataset,
     default_artifact,
 ):
-    default_dataset.state =
+    default_dataset.state = dataset_state.value
     mocker.patch(
         "arkindex_worker.worker.DatasetWorker.list_sets",
         return_value=[Set(name="train", dataset=default_dataset)],
@@ -590,55 +564,68 @@ def test_run(
     )
     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")

-… (19 lines)
+    if dataset_state == DatasetState.Complete:
+        archive_path = (
+            FIXTURES_DIR
+            / "extract_parent_archives"
+            / "first_parent"
+            / "arkindex_data.tar.zst"
+        )
+        responses.add(
+            responses.GET,
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+            status=200,
+            json=[default_artifact],
+        )
+        responses.add(
+            responses.GET,
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+            status=200,
+            body=archive_path.read_bytes(),
+            content_type="application/zstd",
+        )

     mock_dataset_worker.run()

     assert mock_process.call_count == 1

-…
+    # We only download the dataset archive when it is Complete
+    extra_calls = []
+    if dataset_state == DatasetState.Complete:
+        extra_calls = [
+            (
+                "GET",
+                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+            ),
+            (
+                "GET",
+                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+            ),
+        ]
+
+    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 +
-        (
-            "GET",
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-        ),
-    ]
+    ] == BASE_API_CALLS * 2 + extra_calls

-…
+    logs = [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
         (logging.INFO, "Modern configuration is not available"),
         (
-            logging.
-            "
+            logging.WARNING,
+            f"The dataset Dataset (dataset_id) has its state set to `{dataset_state.value}`, its archive will not be downloaded",
         ),
-        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
         (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
         (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
     ]
+    if dataset_state == DatasetState.Complete:
+        logs[2] = (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        )
+        logs.insert(3, (logging.INFO, "Downloading artifact for Dataset (dataset_id)"))
+
+    assert [(level, message) for _, level, message in caplog.record_tuples] == logs


 def test_run_read_only(

tests/test_elements_worker/test_task.py
CHANGED
@@ -1,6 +1,9 @@
+import tempfile
 import uuid
+from pathlib import Path

 import pytest
+from requests import HTTPError

 from arkindex.exceptions import ErrorResponse
 from arkindex_worker.models import Artifact
@@ -196,3 +199,112 @@ def test_download_artifact(
     ] == BASE_API_CALLS + [
         ("GET", f"http://testserver/api/v1/task/{TASK_ID}/artifact/dataset_id.tar.zst"),
     ]
+
+
+@pytest.mark.parametrize(
+    ("payload", "error"),
+    [
+        # Path
+        (
+            {"path": None},
+            "path shouldn't be null, should be a Path and should exist",
+        ),
+        (
+            {"path": "not path type"},
+            "path shouldn't be null, should be a Path and should exist",
+        ),
+        (
+            {"path": Path("i_do_no_exist.oops")},
+            "path shouldn't be null, should be a Path and should exist",
+        ),
+    ],
+)
+def test_upload_artifact_wrong_param_path(mock_dataset_worker, payload, error):
+    with pytest.raises(AssertionError, match=error):
+        mock_dataset_worker.upload_artifact(**payload)
+
+
+@pytest.fixture
+def tmp_file(mock_dataset_worker):
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".txt", dir=mock_dataset_worker.work_dir
+    ) as file:
+        file.write("Some content...")
+        file.seek(0)
+
+        yield Path(file.name)
+
+
+def test_upload_artifact_api_error(responses, mock_dataset_worker, tmp_file):
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/task/my_task/artifacts/",
+        status=418,
+    )
+
+    with pytest.raises(ErrorResponse):
+        mock_dataset_worker.upload_artifact(path=tmp_file)
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/task/my_task/artifacts/")]
+
+
+def test_upload_artifact_s3_upload_error(
+    responses,
+    mock_dataset_worker,
+    tmp_file,
+):
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/task/my_task/artifacts/",
+        json={
+            "id": "11111111-1111-1111-1111-111111111111",
+            "path": tmp_file.name,
+            "size": 15,
+            "content_type": "text/plain",
+            "s3_put_url": "http://example.com/oops.txt",
+        },
+    )
+    responses.add(responses.PUT, "http://example.com/oops.txt", status=500)
+
+    with pytest.raises(HTTPError):
+        mock_dataset_worker.upload_artifact(path=tmp_file)
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 2
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        ("POST", "http://testserver/api/v1/task/my_task/artifacts/"),
+        ("PUT", "http://example.com/oops.txt"),
+    ]
+
+
+def test_upload_artifact(
+    responses,
+    mock_dataset_worker,
+    tmp_file,
+):
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/task/my_task/artifacts/",
+        json={
+            "id": "11111111-1111-1111-1111-111111111111",
+            "path": tmp_file.name,
+            "size": 15,
+            "content_type": "text/plain",
+            "s3_put_url": "http://example.com/test.txt",
+        },
+    )
+    responses.add(responses.PUT, "http://example.com/test.txt")
+
+    mock_dataset_worker.upload_artifact(path=tmp_file)
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 2
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        ("POST", "http://testserver/api/v1/task/my_task/artifacts/"),
+        ("PUT", "http://example.com/test.txt"),
+    ]

tests/test_elements_worker/test_training.py
CHANGED
@@ -27,16 +27,16 @@ def default_model_version():
     return {
         "id": "model_version_id",
         "model_id": "model_id",
-        "state": "created",
         "parent": "42" * 16,
-        "tag": "A simple tag",
         "description": "A description",
+        "tag": "A simple tag",
+        "state": "created",
+        "size": 42,
+        "archive_hash": "123456789",
         "configuration": {"test": "value"},
-        "
+        "s3_etag": None,
         "s3_put_url": "http://upload.archive",
-        "
-        "archive_hash": None,
-        "size": None,
+        "s3_url": None,
         "created": "2000-01-01T00:00:00Z",
     }

@@ -46,14 +46,11 @@ def test_create_archive(model_file_dir):

     with create_archive(path=model_file_dir) as (
         zst_archive_path,
-        hash,
         size,
-…
+        hash,
     ):
         assert zst_archive_path.exists(), "The archive was not created"
-        assert hash ==
-            "Hash was not properly computed"
-        )
+        assert len(hash) == 32
         assert 300 < size < 700

         assert not zst_archive_path.exists(), "Auto removal failed"
@@ -64,37 +61,16 @@ def test_create_archive_with_subfolder(model_file_dir_with_subfolder):

     with create_archive(path=model_file_dir_with_subfolder) as (
         zst_archive_path,
-        hash,
         size,
-…
+        hash,
     ):
         assert zst_archive_path.exists(), "The archive was not created"
-        assert hash ==
-            "Hash was not properly computed"
-        )
+        assert len(hash) == 32
         assert 300 < size < 1500

         assert not zst_archive_path.exists(), "Auto removal failed"


-def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_dir):
-    s3_endpoint_url = "http://s3.localhost.com"
-    responses.add_passthru(s3_endpoint_url)
-    responses.add(responses.PUT, s3_endpoint_url, status=400)
-
-    mock_training_worker.model_version = {
-        "state": "Created",
-        "s3_put_url": s3_endpoint_url,
-    }
-
-    file_path = model_file_dir / "model_file.pth"
-    with pytest.raises(
-        Exception,
-        match="400 Client Error: Bad Request for url: http://s3.localhost.com/",
-    ):
-        mock_training_worker.upload_to_s3(file_path)
-
-
 @pytest.mark.parametrize(
     "method",
     [
@@ -102,7 +78,6 @@ def test_handle_s3_uploading_errors(responses, mock_training_worker, model_file_
         "create_model_version",
         "update_model_version",
         "upload_to_s3",
-        "validate_model_version",
     ],
 )
 def test_training_mixin_read_only(mock_training_worker, method, caplog):
@@ -127,12 +102,16 @@ def test_create_model_version_already_created(mock_training_worker):
     with pytest.raises(
         AssertionError, match="A model version has already been created."
     ):
-        mock_training_worker.create_model_version(
+        mock_training_worker.create_model_version(
+            model_id="model_id", size=42, archive_hash="123456789"
+        )


 @pytest.mark.parametrize("set_tag", [True, False])
 def test_create_model_version(mock_training_worker, default_model_version, set_tag):
     args = {
+        "size": 42,
+        "archive_hash": "123456789",
         "parent": "42" * 16,
         "tag": "A simple tag",
         "description": "A description",
@@ -154,12 +133,12 @@ def test_create_model_version(mock_training_worker, default_model_version, set_t

 def test_update_model_version_not_created(mock_training_worker):
     with pytest.raises(AssertionError, match="No model version has been created yet."):
-        mock_training_worker.update_model_version()
+        mock_training_worker.update_model_version(size=42, archive_hash="123456789")


 def test_update_model_version(mock_training_worker, default_model_version):
     mock_training_worker.model_version = default_model_version
-    args = {"tag": "A new tag"}
+    args = {"size": 42, "archive_hash": "123456789", "tag": "A new tag"}
     new_model_version = {**default_model_version, "tag": "A new tag"}
     mock_training_worker.api_client.add_response(
         "UpdateModelVersion",
@@ -169,101 +148,3 @@ def test_update_model_version(mock_training_worker, default_model_version):
     )
     mock_training_worker.update_model_version(**args)
     assert mock_training_worker.model_version == new_model_version
-
-
-def test_validate_model_version_not_created(mock_training_worker):
-    with pytest.raises(
-        AssertionError,
-        match="You must create the model version and upload its archive before validating it.",
-    ):
-        mock_training_worker.validate_model_version(hash="a", size=1, archive_hash="b")
-
-
-@pytest.mark.parametrize("deletion_failed", [True, False])
-def test_validate_model_version_hash_conflict(
-    mock_training_worker,
-    default_model_version,
-    caplog,
-    deletion_failed,
-):
-    mock_training_worker.model_version = {"id": "another_id"}
-    args = {
-        "hash": "hash",
-        "archive_hash": "archive_hash",
-        "size": 30,
-    }
-    mock_training_worker.api_client.add_error_response(
-        "PartialUpdateModelVersion",
-        id="another_id",
-        status_code=409,
-        body={"state": "available", **args},
-        content={"id": ["model_version_id"]},
-    )
-    if deletion_failed:
-        mock_training_worker.api_client.add_error_response(
-            "DestroyModelVersion",
-            id="another_id",
-            status_code=403,
-            content="Not admin",
-        )
-    else:
-        mock_training_worker.api_client.add_response(
-            "DestroyModelVersion",
-            id="another_id",
-            response="No content",
-        )
-    mock_training_worker.api_client.add_response(
-        "RetrieveModelVersion",
-        id="model_version_id",
-        response=default_model_version,
-    )
-
-    mock_training_worker.validate_model_version(**args)
-    assert mock_training_worker.model_version == default_model_version
-    error_msg = []
-    if deletion_failed:
-        error_msg = [
-            (
-                logging.ERROR,
-                "An error occurred removing the pending version another_id: Not admin.",
-            )
-        ]
-    assert [
-        (level, message)
-        for module, level, message in caplog.record_tuples
-        if module == "arkindex_worker"
-    ] == [
-        (
-            logging.WARNING,
-            "An available model version exists with hash hash, using it instead of the pending version.",
-        ),
-        (logging.WARNING, "Removing the pending model version."),
-        *error_msg,
-        (logging.INFO, "Retrieving the existing model version."),
-        (logging.INFO, "Model version model_version_id is now available."),
-    ]
-
-
-def test_validate_model_version(mock_training_worker, default_model_version, caplog):
-    mock_training_worker.model_version = {"id": "model_version_id"}
-    args = {
-        "hash": "hash",
-        "archive_hash": "archive_hash",
-        "size": 30,
-    }
-    mock_training_worker.api_client.add_response(
-        "PartialUpdateModelVersion",
-        id="model_version_id",
-        body={"state": "available", **args},
-        response=default_model_version,
-    )
-
-    mock_training_worker.validate_model_version(**args)
-    assert mock_training_worker.model_version == default_model_version
-    assert [
-        (level, message)
-        for module, level, message in caplog.record_tuples
-        if module == "arkindex_worker"
-    ] == [
-        (logging.INFO, "Model version model_version_id is now available."),
-    ]

tests/test_modern_config.py
CHANGED
@@ -79,3 +79,42 @@ def test_with_secrets(mock_base_worker_modern_conf, responses):
     assert mock_base_worker_modern_conf.secrets == {
         "a_secret": "My super duper secret value"
     }
+
+
+def test_with_secrets_ce(mock_base_worker_modern_conf, responses, monkeypatch):
+    # Provide the full configuration directly from the worker run
+    responses.add(
+        responses.GET,
+        "http://testserver/api/v1/workers/runs/56785678-5678-5678-5678-567856785678/configuration/",
+        status=200,
+        json={
+            "configuration": [
+                {"key": "some_key", "value": "test", "secret": False},
+                {
+                    "key": "a_secret",
+                    "value": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
+                    "secret": True,
+                },
+            ]
+        },
+    )
+
+    # Remove the RetrieveSecret endpoint to simulate Arkindex CE
+    monkeypatch.delitem(
+        mock_base_worker_modern_conf.api_client.document.links, "RetrieveSecret"
+    )
+
+    mock_base_worker_modern_conf.configure()
+
+    assert mock_base_worker_modern_conf.config == {
+        "a_secret": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
+        "some_key": "test",
+    }
+    assert (
+        mock_base_worker_modern_conf.user_configuration
+        == mock_base_worker_modern_conf.config
+    )
+    assert mock_base_worker_modern_conf.secrets == {
+        # The value is used directly instead of treated as a secret name
+        "a_secret": "471b9e64-29af-48dc-8bda-1a64a2da0c12",
+    }

{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/licenses/LICENSE
RENAMED
File without changes

{arkindex_base_worker-0.5.2a1.dist-info → arkindex_base_worker-0.5.2b1.dist-info}/top_level.txt
RENAMED
File without changes