amzn-sagemaker-checkpointing 1.0.10__tar.gz → 1.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of amzn-sagemaker-checkpointing might be problematic. Click here for more details.
- amzn_sagemaker_checkpointing-1.0.12/DEVELOPING.md +22 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/PKG-INFO +3 -3
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/README.md +2 -2
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/pyproject.toml +2 -1
- amzn_sagemaker_checkpointing-1.0.12/setup-hatch.sh +19 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/filesystem.py +117 -59
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/inmemory_client.py +16 -11
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_checkpoint.py +148 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_namespace.py +70 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_checkpoint.py +217 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_latest_checkpoints.py +116 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_namespace_config.py +118 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_or_create_namespace.py +255 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_put_checkpoint.py +209 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_reset_cluster.py +69 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/__init__.py +0 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/test_base.py +120 -0
- amzn_sagemaker_checkpointing-1.0.10/DEVELOPING.md +0 -46
- amzn_sagemaker_checkpointing-1.0.10/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_inmemory_client.py +0 -77
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/.crux_dry_run_build +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/.gitignore +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/LICENSE.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/brazil.ion +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-build-tools.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-build.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-static-analysis.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.11.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.12.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/exceptions.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/in_memory_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/sagemaker_checkpoint_config.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/py.typed +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/exceptions.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/models.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/local/disk_fs.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client_manager.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/utils/logging_utils.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/scripts/test_inmemory_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/checkpointing/filesystem/test_filesystem.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum_test.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/s3/test_s3_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/test_dummy.py +0 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Developing SageMakerCheckpointing
|
|
2
|
+
|
|
3
|
+
This package uses the [hatch](https://hatch.pypa.io/latest/) build system.
|
|
4
|
+
|
|
5
|
+
### Building
|
|
6
|
+
|
|
7
|
+
A number of scripts and commands exist in `pyproject.toml` under the `scripts` configurations with more
|
|
8
|
+
documentation in the comments of `pyproject.toml`. Running a script for a specific environment is simply running
|
|
9
|
+
`hatch run <env_name>:<script>`. You can omit the `<env_name>` for those under the `default` environment.
|
|
10
|
+
|
|
11
|
+
You need to set up the hatch plugin first:
|
|
12
|
+
```
|
|
13
|
+
./setup-hatch.sh
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
### Available Hatch Commands
|
|
17
|
+
|
|
18
|
+
- **`hatch run release`** - Runs typing checks (mypy), tests, and coverage.
|
|
19
|
+
- **`hatch test --cover`** - Runs tests and coverage.
|
|
20
|
+
- **`hatch typing`** - Runs mypy type checking.
|
|
21
|
+
- **`hatch fmt`** - Formats code using ruff.
|
|
22
|
+
- **`hatch build`** - builds both source and wheel distributions in ./build directory.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: amzn-sagemaker-checkpointing
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.12
|
|
4
4
|
Summary: Amazon SageMaker Checkpointing Library
|
|
5
5
|
License: Apache 2.0
|
|
6
6
|
License-File: LICENSE.txt
|
|
@@ -95,12 +95,12 @@ following to your S3 bucket policy
|
|
|
95
95
|
```
|
|
96
96
|
|
|
97
97
|
## Installation
|
|
98
|
-
###
|
|
98
|
+
### Prerequisites
|
|
99
99
|
```bash
|
|
100
100
|
pip install s3torchconnector tenacity torch boto3 botocore
|
|
101
101
|
```
|
|
102
102
|
|
|
103
|
-
###
|
|
103
|
+
### SageMaker Checkpointing Library
|
|
104
104
|
```bash
|
|
105
105
|
pip install amzn-sagemaker-checkpointing
|
|
106
106
|
```
|
|
@@ -82,12 +82,12 @@ following to your S3 bucket policy
|
|
|
82
82
|
```
|
|
83
83
|
|
|
84
84
|
## Installation
|
|
85
|
-
###
|
|
85
|
+
### Prerequisites
|
|
86
86
|
```bash
|
|
87
87
|
pip install s3torchconnector tenacity torch boto3 botocore
|
|
88
88
|
```
|
|
89
89
|
|
|
90
|
-
###
|
|
90
|
+
### SageMaker Checkpointing Library
|
|
91
91
|
```bash
|
|
92
92
|
pip install amzn-sagemaker-checkpointing
|
|
93
93
|
```
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "amzn-sagemaker-checkpointing"
|
|
7
|
-
version = "1.0.
|
|
7
|
+
version = "1.0.12"
|
|
8
8
|
description = "Amazon SageMaker Checkpointing Library"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { "text" = "Apache 2.0" }
|
|
@@ -72,6 +72,7 @@ exclude = [ "./build", ".hatch", "private" ]
|
|
|
72
72
|
|
|
73
73
|
[tool.hatch.build]
|
|
74
74
|
directory = "./build"
|
|
75
|
+
exclude = ["DEVELOPING_INTERNAL.md"]
|
|
75
76
|
|
|
76
77
|
[tool.hatch.env]
|
|
77
78
|
requires = [ "hatch-pip-compile" ]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
set -e
|
|
3
|
+
|
|
4
|
+
mkdir -p .hatch
|
|
5
|
+
|
|
6
|
+
cat > .hatch/hatch_plugin.py << 'EOF'
|
|
7
|
+
from hatch.env.collectors.plugin.interface import EnvironmentCollectorInterface
|
|
8
|
+
|
|
9
|
+
class CustomEnvironmentCollector(EnvironmentCollectorInterface):
|
|
10
|
+
PLUGIN_NAME = 'custom'
|
|
11
|
+
|
|
12
|
+
def get_initial_config(self):
|
|
13
|
+
return {}
|
|
14
|
+
|
|
15
|
+
def finalize_config(self, config):
|
|
16
|
+
return config
|
|
17
|
+
EOF
|
|
18
|
+
|
|
19
|
+
echo "Hatch plugin created"
|
|
@@ -19,6 +19,7 @@ import pickle
|
|
|
19
19
|
import threading
|
|
20
20
|
import time
|
|
21
21
|
from dataclasses import dataclass
|
|
22
|
+
from enum import Enum
|
|
22
23
|
from logging import FileHandler
|
|
23
24
|
from typing import Any, Union
|
|
24
25
|
|
|
@@ -46,9 +47,6 @@ from torch.futures import Future
|
|
|
46
47
|
from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import (
|
|
47
48
|
SageMakerCheckpointConfig,
|
|
48
49
|
)
|
|
49
|
-
from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import (
|
|
50
|
-
InMemoryServerError,
|
|
51
|
-
)
|
|
52
50
|
from amzn_sagemaker_checkpointing.storage.clients.inmemory.inmemory_client import (
|
|
53
51
|
InMemoryCheckpointClient,
|
|
54
52
|
)
|
|
@@ -80,6 +78,15 @@ class _SageMakerStorageInfo:
|
|
|
80
78
|
offset: int
|
|
81
79
|
length: int
|
|
82
80
|
|
|
81
|
+
class StorageTier(Enum):
|
|
82
|
+
IN_MEMORY = 0
|
|
83
|
+
S3 = 1
|
|
84
|
+
|
|
85
|
+
def __str__(self):
|
|
86
|
+
return {
|
|
87
|
+
0: "IN_MEMORY",
|
|
88
|
+
1: "S3"
|
|
89
|
+
}[self.value]
|
|
83
90
|
|
|
84
91
|
def _get_step_val(step: int, path: str | os.PathLike) -> int:
|
|
85
92
|
"""
|
|
@@ -791,51 +798,42 @@ class SageMakerTieredStorageReader(StorageReader):
|
|
|
791
798
|
|
|
792
799
|
def read_metadata(self) -> Metadata:
|
|
793
800
|
"""
|
|
794
|
-
Retrieve and deserialize checkpoint metadata
|
|
801
|
+
Retrieve and deserialize checkpoint metadata.
|
|
795
802
|
|
|
796
803
|
Returns
|
|
797
804
|
-------
|
|
798
805
|
Metadata
|
|
799
806
|
Metadata object containing checkpoint information.
|
|
800
|
-
|
|
801
|
-
Raises
|
|
802
|
-
------
|
|
803
|
-
RuntimeError
|
|
804
|
-
If metadata retrieval fails.
|
|
807
|
+
(or) empty Metadata if not available
|
|
805
808
|
"""
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
self.step
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
self.logger.info(
|
|
835
|
-
f"[Rank {self.rank}] Step {self.step}: Successfully read metadata from S3, size={len(metadata_buffer)} bytes"
|
|
836
|
-
)
|
|
837
|
-
return pickle.loads(metadata_buffer)
|
|
838
|
-
return Metadata({})
|
|
809
|
+
metadata = Metadata({})
|
|
810
|
+
try:
|
|
811
|
+
if self.step is not None:
|
|
812
|
+
self.logger.info(f"[Rank {self.rank}] Step {self.step}: "
|
|
813
|
+
"reading metadata for configured step")
|
|
814
|
+
metadata = self._read_metadata_for_step(self.step)
|
|
815
|
+
else:
|
|
816
|
+
latest_step_all_tiers = self._get_latest_step_all_tiers()
|
|
817
|
+
for latest_step, tier in latest_step_all_tiers:
|
|
818
|
+
if tier == StorageTier.IN_MEMORY:
|
|
819
|
+
self.logger.info(f"[Rank {self.rank}] Attempting to read "
|
|
820
|
+
f"metadata from memory for {latest_step}")
|
|
821
|
+
step_metadata = self._read_metadata_from_memory(latest_step)
|
|
822
|
+
elif tier == StorageTier.S3:
|
|
823
|
+
self.logger.info(f"[Rank {self.rank}] Attempting to read "
|
|
824
|
+
f"metadata from S3 for {latest_step}")
|
|
825
|
+
step_metadata = self._read_metadata_from_s3(latest_step)
|
|
826
|
+
if step_metadata is not None:
|
|
827
|
+
metadata = step_metadata
|
|
828
|
+
self.step = latest_step
|
|
829
|
+
self.logger.info(f"[Rank {self.rank}] Metadata "
|
|
830
|
+
f"read from step {latest_step} of {tier} tier")
|
|
831
|
+
break
|
|
832
|
+
if self.step is None:
|
|
833
|
+
self.logger.error(f"[Rank {self.rank}] No checkpoints to read metadata")
|
|
834
|
+
except Exception as e:
|
|
835
|
+
self.logger.error(f"[Rank {self.rank}] Step {self.step}: read_metadata failed: {e}")
|
|
836
|
+
return metadata
|
|
839
837
|
|
|
840
838
|
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
|
|
841
839
|
"""
|
|
@@ -1083,23 +1081,6 @@ class SageMakerTieredStorageReader(StorageReader):
|
|
|
1083
1081
|
"""
|
|
1084
1082
|
return True
|
|
1085
1083
|
|
|
1086
|
-
def _find_latest_complete_step_across_tiers(self) -> int | None:
|
|
1087
|
-
"""Find latest step from both storage tiers."""
|
|
1088
|
-
memory_step = self.client.get_latest_checkpoints(limit=1)
|
|
1089
|
-
s3_step = self._find_latest_complete_step()
|
|
1090
|
-
latest_step = None
|
|
1091
|
-
if not memory_step:
|
|
1092
|
-
latest_step = s3_step
|
|
1093
|
-
elif not s3_step:
|
|
1094
|
-
latest_step = memory_step[0]
|
|
1095
|
-
else:
|
|
1096
|
-
latest_step = max(memory_step[0], s3_step)
|
|
1097
|
-
self.logger.info(
|
|
1098
|
-
f"[Rank {self.rank}] Step {self.step}: Latest steps: "
|
|
1099
|
-
f"memory:{memory_step}, s3:{s3_step}, across_tiers:{latest_step}"
|
|
1100
|
-
)
|
|
1101
|
-
return latest_step
|
|
1102
|
-
|
|
1103
1084
|
def _try_read_md_from_memory(self, step: int) -> bytes | None:
|
|
1104
1085
|
"""Try reading metadata from in-memory storage."""
|
|
1105
1086
|
try:
|
|
@@ -1252,3 +1233,80 @@ class SageMakerTieredStorageReader(StorageReader):
|
|
|
1252
1233
|
f"[Rank {self.rank}] Failed to read item {item_index} from step {step}: {e}"
|
|
1253
1234
|
)
|
|
1254
1235
|
return None
|
|
1236
|
+
|
|
1237
|
+
def _read_metadata_from_memory(self, step) -> Metadata | None:
|
|
1238
|
+
metadata = None
|
|
1239
|
+
try:
|
|
1240
|
+
metadata_buffer = self._try_read_md_from_memory(step)
|
|
1241
|
+
if metadata_buffer:
|
|
1242
|
+
self.logger.info(
|
|
1243
|
+
f"[Rank {self.rank}] Step {step}: Successfully read metadata from memory, "
|
|
1244
|
+
f"size={len(metadata_buffer)} bytes"
|
|
1245
|
+
)
|
|
1246
|
+
metadata = pickle.loads(metadata_buffer)
|
|
1247
|
+
else:
|
|
1248
|
+
self.logger.info(
|
|
1249
|
+
f"[Rank {self.rank}] Step {step}: "
|
|
1250
|
+
f"In-memory metadata not found"
|
|
1251
|
+
)
|
|
1252
|
+
except Exception as e:
|
|
1253
|
+
self.logger.error(f"[Rank {self.rank}] Step {step}: _read_metadata_from_memory failed: {e}")
|
|
1254
|
+
return metadata
|
|
1255
|
+
|
|
1256
|
+
def _read_metadata_from_s3(self, step) -> Metadata | None:
|
|
1257
|
+
metadata = None
|
|
1258
|
+
try:
|
|
1259
|
+
if self.s3_base_path:
|
|
1260
|
+
self.logger.info(
|
|
1261
|
+
f"[Rank {self.rank}] Step {step}: Attempting metadata read from S3"
|
|
1262
|
+
)
|
|
1263
|
+
metadata_buffer = self._try_read_md_from_s3(step)
|
|
1264
|
+
if metadata_buffer:
|
|
1265
|
+
self.logger.info(f"[Rank {self.rank}] Step {step}: "
|
|
1266
|
+
f"Successfully read metadata from size={len(metadata_buffer)} bytes")
|
|
1267
|
+
metadata = pickle.loads(metadata_buffer)
|
|
1268
|
+
else:
|
|
1269
|
+
self.logger.info(
|
|
1270
|
+
f"[Rank {self.rank}] Step {step}: "
|
|
1271
|
+
"S3 metadata not found")
|
|
1272
|
+
else:
|
|
1273
|
+
self.logger.info(
|
|
1274
|
+
f"[Rank {self.rank}] Step {step}: Unable to read metadata "
|
|
1275
|
+
"as S3 path is not provided"
|
|
1276
|
+
)
|
|
1277
|
+
except Exception as e:
|
|
1278
|
+
self.logger.error(f"[Rank {self.rank}] Step {step}: _read_metadata_from_s3 failed: {e}")
|
|
1279
|
+
return metadata
|
|
1280
|
+
|
|
1281
|
+
def _read_metadata_for_step(self, step) -> Metadata:
|
|
1282
|
+
metadata = Metadata({})
|
|
1283
|
+
try:
|
|
1284
|
+
in_memory_metadata = self._read_metadata_from_memory(step)
|
|
1285
|
+
if in_memory_metadata is not None:
|
|
1286
|
+
metadata = in_memory_metadata
|
|
1287
|
+
else:
|
|
1288
|
+
s3_metadata = self._read_metadata_from_s3(step)
|
|
1289
|
+
if s3_metadata is not None:
|
|
1290
|
+
metadata = s3_metadata
|
|
1291
|
+
except Exception as e:
|
|
1292
|
+
self.logger.error(f"[Rank {self.rank}] Step {step}: _read_metadata_for_step failed: {e}")
|
|
1293
|
+
return metadata
|
|
1294
|
+
|
|
1295
|
+
def _get_latest_step_all_tiers(self) -> list[tuple[int, StorageTier]]:
|
|
1296
|
+
latest_step_all_tiers = []
|
|
1297
|
+
try:
|
|
1298
|
+
memory_steps = self.client.get_latest_checkpoints(limit=3)
|
|
1299
|
+
if memory_steps:
|
|
1300
|
+
latest_step_all_tiers = [(step, StorageTier.IN_MEMORY) for step in memory_steps]
|
|
1301
|
+
except Exception as e:
|
|
1302
|
+
self.logger.error(f"[Rank {self.rank}]: Failed to get memory steps: {e}")
|
|
1303
|
+
try:
|
|
1304
|
+
s3_step = self._find_latest_complete_step()
|
|
1305
|
+
if s3_step:
|
|
1306
|
+
latest_step_all_tiers.append((s3_step, StorageTier.S3))
|
|
1307
|
+
except Exception as e:
|
|
1308
|
+
self.logger.error(f"[Rank {self.rank}]: Failed to get S3 step: {e}")
|
|
1309
|
+
|
|
1310
|
+
latest_step_all_tiers.sort(key=lambda tier_step: (-tier_step[0], tier_step[1].value))
|
|
1311
|
+
self.logger.info(f"[Rank {self.rank}] Latest steps across tiers: {latest_step_all_tiers}")
|
|
1312
|
+
return latest_step_all_tiers
|
|
@@ -471,17 +471,22 @@ class InMemoryCheckpointClient:
|
|
|
471
471
|
checksum=encode_base_64(hash_xxh3_128(data)), algorithm="xxh3_128"
|
|
472
472
|
).to_json()
|
|
473
473
|
}
|
|
474
|
-
if isinstance(data, str)
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
474
|
+
if isinstance(data, str):
|
|
475
|
+
try:
|
|
476
|
+
with open(data, "rb") as f:
|
|
477
|
+
self._make_request(
|
|
478
|
+
"POST",
|
|
479
|
+
endpoint,
|
|
480
|
+
data=f,
|
|
481
|
+
headers=headers,
|
|
482
|
+
timeout=timeout,
|
|
483
|
+
retries=retries,
|
|
484
|
+
retry_backoff=retry_backoff,
|
|
485
|
+
)
|
|
486
|
+
except Exception as e:
|
|
487
|
+
error_msg = f"Error opening file: {data}"
|
|
488
|
+
self._logger.error(error_msg)
|
|
489
|
+
raise InMemoryStorageError(error_msg) from e
|
|
485
490
|
else:
|
|
486
491
|
self._make_request(
|
|
487
492
|
"POST",
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from unittest.mock import Mock
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
|
|
5
|
+
from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import InMemoryConfigError
|
|
6
|
+
from utils.test_base import (
|
|
7
|
+
InMemoryCheckpointClientTest,
|
|
8
|
+
BASE_URL,
|
|
9
|
+
NAMESPACE,
|
|
10
|
+
RANK,
|
|
11
|
+
REQUEST_ERROR_CASES,
|
|
12
|
+
REQUEST_TIMEOUT,
|
|
13
|
+
WORLD_SIZE
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestDeleteCheckpoint(InMemoryCheckpointClientTest):
|
|
18
|
+
STEP = 42
|
|
19
|
+
|
|
20
|
+
def setup_method(self):
|
|
21
|
+
super().setup_method()
|
|
22
|
+
self.checkpoint_path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{self.STEP}"
|
|
23
|
+
|
|
24
|
+
def test_delete_checkpoint_success(self):
|
|
25
|
+
# Arrange
|
|
26
|
+
mock_response = Mock(status_code=200)
|
|
27
|
+
self.mock_session.request.return_value = mock_response
|
|
28
|
+
|
|
29
|
+
# Act
|
|
30
|
+
self.client.delete_checkpoint(step=self.STEP)
|
|
31
|
+
|
|
32
|
+
# Assert
|
|
33
|
+
self.assert_http_adapter_and_retry_config()
|
|
34
|
+
self.mock_session.request.assert_called_once_with(
|
|
35
|
+
method="DELETE",
|
|
36
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
37
|
+
params=None,
|
|
38
|
+
data=None,
|
|
39
|
+
headers=None,
|
|
40
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def test_delete_checkpoint_with_custom_rank(self):
|
|
44
|
+
# Arrange
|
|
45
|
+
mock_response = Mock(status_code=200)
|
|
46
|
+
self.mock_session.request.return_value = mock_response
|
|
47
|
+
custom_rank = 5
|
|
48
|
+
custom_path = f"v1/cp/checkpoints/{NAMESPACE}/{custom_rank}/{self.STEP}"
|
|
49
|
+
|
|
50
|
+
# Act
|
|
51
|
+
self.client.delete_checkpoint(step=self.STEP, rank=custom_rank)
|
|
52
|
+
|
|
53
|
+
# Assert
|
|
54
|
+
self.assert_http_adapter_and_retry_config()
|
|
55
|
+
self.mock_session.request.assert_called_once_with(
|
|
56
|
+
method="DELETE",
|
|
57
|
+
url=f"{BASE_URL}/{custom_path}",
|
|
58
|
+
params=None,
|
|
59
|
+
data=None,
|
|
60
|
+
headers=None,
|
|
61
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
def test_delete_checkpoint_with_metadata_index(self):
|
|
65
|
+
# Arrange
|
|
66
|
+
mock_response = Mock(status_code=200)
|
|
67
|
+
self.mock_session.request.return_value = mock_response
|
|
68
|
+
metadata_index = 0
|
|
69
|
+
metadata_rank = int(WORLD_SIZE) + metadata_index
|
|
70
|
+
metadata_path = f"v1/cp/checkpoints/{NAMESPACE}/{metadata_rank}/{self.STEP}"
|
|
71
|
+
|
|
72
|
+
# Act
|
|
73
|
+
self.client.delete_checkpoint(step=self.STEP, metadata_index=metadata_index)
|
|
74
|
+
|
|
75
|
+
# Assert
|
|
76
|
+
self.assert_http_adapter_and_retry_config()
|
|
77
|
+
self.mock_session.request.assert_called_once_with(
|
|
78
|
+
method="DELETE",
|
|
79
|
+
url=f"{BASE_URL}/{metadata_path}",
|
|
80
|
+
params=None,
|
|
81
|
+
data=None,
|
|
82
|
+
headers=None,
|
|
83
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
def test_delete_checkpoint_with_custom_timeout(self):
|
|
87
|
+
# Arrange
|
|
88
|
+
mock_response = Mock(status_code=200)
|
|
89
|
+
self.mock_session.request.return_value = mock_response
|
|
90
|
+
|
|
91
|
+
# Act
|
|
92
|
+
self.client.delete_checkpoint(step=self.STEP, timeout=REQUEST_TIMEOUT)
|
|
93
|
+
|
|
94
|
+
# Assert
|
|
95
|
+
self.assert_http_adapter_and_retry_config()
|
|
96
|
+
self.mock_session.request.assert_called_once_with(
|
|
97
|
+
method="DELETE",
|
|
98
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
99
|
+
params=None,
|
|
100
|
+
data=None,
|
|
101
|
+
headers=None,
|
|
102
|
+
timeout=REQUEST_TIMEOUT
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
def test_delete_checkpoint_with_string_step(self):
|
|
106
|
+
# Arrange
|
|
107
|
+
mock_response = Mock(status_code=200)
|
|
108
|
+
self.mock_session.request.return_value = mock_response
|
|
109
|
+
step = "latest"
|
|
110
|
+
path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{step}"
|
|
111
|
+
|
|
112
|
+
# Act
|
|
113
|
+
self.client.delete_checkpoint(step=step)
|
|
114
|
+
|
|
115
|
+
# Assert
|
|
116
|
+
self.assert_http_adapter_and_retry_config()
|
|
117
|
+
self.mock_session.request.assert_called_once_with(
|
|
118
|
+
method="DELETE",
|
|
119
|
+
url=f"{BASE_URL}/{path}",
|
|
120
|
+
params=None,
|
|
121
|
+
data=None,
|
|
122
|
+
headers=None,
|
|
123
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
def test_delete_checkpoint_invalid_metadata_index(self):
|
|
127
|
+
# Act & Assert
|
|
128
|
+
with pytest.raises(InMemoryConfigError) as exc_info:
|
|
129
|
+
self.client.delete_checkpoint(step=self.STEP, metadata_index=999)
|
|
130
|
+
assert "Invalid metadata_index" in str(exc_info.value)
|
|
131
|
+
|
|
132
|
+
@pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
|
|
133
|
+
def test_delete_checkpoint_request_errors(self, test_case):
|
|
134
|
+
# Arrange
|
|
135
|
+
if "exception" in test_case["response"]:
|
|
136
|
+
self.mock_session.request.side_effect = test_case["response"]["exception"]
|
|
137
|
+
else:
|
|
138
|
+
mock_response = Mock()
|
|
139
|
+
mock_response.status_code = test_case["response"]["status_code"]
|
|
140
|
+
mock_response.text = test_case["response"]["text"]
|
|
141
|
+
self.mock_session.request.return_value = mock_response
|
|
142
|
+
|
|
143
|
+
# Act & Assert
|
|
144
|
+
self.assert_request_error(
|
|
145
|
+
test_case,
|
|
146
|
+
self.client.delete_checkpoint,
|
|
147
|
+
step=self.STEP
|
|
148
|
+
)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from unittest.mock import Mock, call
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
|
|
5
|
+
from utils.test_base import (
|
|
6
|
+
InMemoryCheckpointClientTest,
|
|
7
|
+
BASE_URL,
|
|
8
|
+
NAMESPACE,
|
|
9
|
+
REQUEST_ERROR_CASES,
|
|
10
|
+
REQUEST_TIMEOUT,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestDeleteNamespace(InMemoryCheckpointClientTest):
|
|
15
|
+
def setup_method(self):
|
|
16
|
+
super().setup_method()
|
|
17
|
+
self.namespace_path = f"v1/cp/namespaces/{NAMESPACE}"
|
|
18
|
+
|
|
19
|
+
@pytest.mark.parametrize(
|
|
20
|
+
"params",
|
|
21
|
+
[
|
|
22
|
+
{"timeout": None},
|
|
23
|
+
{"timeout": REQUEST_TIMEOUT},
|
|
24
|
+
],
|
|
25
|
+
)
|
|
26
|
+
def test_delete_namespace_success(self, params):
|
|
27
|
+
# Arrange
|
|
28
|
+
mock_response = Mock()
|
|
29
|
+
mock_response.status_code = 200
|
|
30
|
+
self.mock_session.request.return_value = mock_response
|
|
31
|
+
|
|
32
|
+
# Act
|
|
33
|
+
self.client.delete_namespace(**params)
|
|
34
|
+
|
|
35
|
+
# Assert
|
|
36
|
+
self.assert_http_adapter_and_retry_config()
|
|
37
|
+
self.mock_session.request.assert_called_once_with(
|
|
38
|
+
method="DELETE",
|
|
39
|
+
url=f"{BASE_URL}/{self.namespace_path}",
|
|
40
|
+
params=None,
|
|
41
|
+
data=None,
|
|
42
|
+
headers=None,
|
|
43
|
+
timeout=params["timeout"] or InMemoryClientConfig.request_timeout
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
@pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
|
|
47
|
+
def test_delete_namespace_request_errors(self, test_case):
|
|
48
|
+
# Arrange
|
|
49
|
+
if "exception" in test_case["response"]:
|
|
50
|
+
self.mock_session.request.side_effect = test_case["response"]["exception"]
|
|
51
|
+
else:
|
|
52
|
+
mock_response = Mock()
|
|
53
|
+
mock_response.status_code = test_case["response"]["status_code"]
|
|
54
|
+
mock_response.text = test_case["response"]["text"]
|
|
55
|
+
self.mock_session.request.return_value = mock_response
|
|
56
|
+
|
|
57
|
+
# Act & Assert
|
|
58
|
+
self.assert_http_adapter_and_retry_config()
|
|
59
|
+
self.assert_request_error(
|
|
60
|
+
test_case,
|
|
61
|
+
self.client.delete_namespace
|
|
62
|
+
)
|
|
63
|
+
self.mock_session.request.assert_called_once_with(
|
|
64
|
+
method="DELETE",
|
|
65
|
+
url=f"{BASE_URL}/{self.namespace_path}",
|
|
66
|
+
params=None,
|
|
67
|
+
data=None,
|
|
68
|
+
headers=None,
|
|
69
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
70
|
+
)
|