amzn-sagemaker-checkpointing 1.0.11__tar.gz → 1.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of amzn-sagemaker-checkpointing might be problematic. Click here for more details.
- amzn_sagemaker_checkpointing-1.0.12/DEVELOPING.md +22 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/PKG-INFO +1 -1
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/pyproject.toml +2 -1
- amzn_sagemaker_checkpointing-1.0.12/setup-hatch.sh +19 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/inmemory_client.py +16 -11
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_checkpoint.py +148 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_namespace.py +70 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_checkpoint.py +217 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_latest_checkpoints.py +116 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_namespace_config.py +118 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_or_create_namespace.py +255 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_put_checkpoint.py +209 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_reset_cluster.py +69 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/__init__.py +0 -0
- amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/test_base.py +120 -0
- amzn_sagemaker_checkpointing-1.0.11/DEVELOPING.md +0 -46
- amzn_sagemaker_checkpointing-1.0.11/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_inmemory_client.py +0 -77
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/.crux_dry_run_build +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/.gitignore +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/LICENSE.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/README.md +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/brazil.ion +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-build-tools.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-build.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-static-analysis.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.11.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.12.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements.txt +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/exceptions.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/filesystem.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/in_memory_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/sagemaker_checkpoint_config.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/py.typed +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/exceptions.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/models.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/local/disk_fs.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/__init__.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client_manager.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/utils/logging_utils.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/scripts/test_inmemory_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/checkpointing/filesystem/test_filesystem.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum_test.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/s3/test_s3_client.py +0 -0
- {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/test_dummy.py +0 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Developing SageMakerCheckpointing
|
|
2
|
+
|
|
3
|
+
This package uses the [hatch](https://hatch.pypa.io/latest/) build system.
|
|
4
|
+
|
|
5
|
+
### Building
|
|
6
|
+
|
|
7
|
+
A number of scripts and commands exist in `pyproject.toml` under the `scripts` configurations with more
|
|
8
|
+
documentation in the comments of `pyproject.toml`. Running a script for a specific environment is simply running
|
|
9
|
+
`hatch run <env_name>:<script>`. You can omit the `<env_name>` for those under the `default` environment.
|
|
10
|
+
|
|
11
|
+
You need to set up hatch pluging first:
|
|
12
|
+
```
|
|
13
|
+
./setup-hatch.sh
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
### Available Hatch Commands
|
|
17
|
+
|
|
18
|
+
- **`hatch run release`** - Runs typing checks (mypy), tests, and coverage.
|
|
19
|
+
- **`hatch test --cover`** - Runs tests and coverage.
|
|
20
|
+
- **`hatch typing`** - Runs mypy type checking.
|
|
21
|
+
- **`hatch fmt`** - Formats code using ruff.
|
|
22
|
+
- **`hatch build`** - builds both source and wheel distributions in ./build directory.
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "amzn-sagemaker-checkpointing"
|
|
7
|
-
version = "1.0.
|
|
7
|
+
version = "1.0.12"
|
|
8
8
|
description = "Amazon SageMaker Checkpointing Library"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { "text" = "Apache 2.0" }
|
|
@@ -72,6 +72,7 @@ exclude = [ "./build", ".hatch", "private" ]
|
|
|
72
72
|
|
|
73
73
|
[tool.hatch.build]
|
|
74
74
|
directory = "./build"
|
|
75
|
+
exclude = ["DEVELOPING_INTERNAL.md"]
|
|
75
76
|
|
|
76
77
|
[tool.hatch.env]
|
|
77
78
|
requires = [ "hatch-pip-compile" ]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
set -e
|
|
3
|
+
|
|
4
|
+
mkdir -p .hatch
|
|
5
|
+
|
|
6
|
+
cat > .hatch/hatch_plugin.py << 'EOF'
|
|
7
|
+
from hatch.env.collectors.plugin.interface import EnvironmentCollectorInterface
|
|
8
|
+
|
|
9
|
+
class CustomEnvironmentCollector(EnvironmentCollectorInterface):
|
|
10
|
+
PLUGIN_NAME = 'custom'
|
|
11
|
+
|
|
12
|
+
def get_initial_config(self):
|
|
13
|
+
return {}
|
|
14
|
+
|
|
15
|
+
def finalize_config(self, config):
|
|
16
|
+
return config
|
|
17
|
+
EOF
|
|
18
|
+
|
|
19
|
+
echo "Hatch plugin created"
|
|
@@ -471,17 +471,22 @@ class InMemoryCheckpointClient:
|
|
|
471
471
|
checksum=encode_base_64(hash_xxh3_128(data)), algorithm="xxh3_128"
|
|
472
472
|
).to_json()
|
|
473
473
|
}
|
|
474
|
-
if isinstance(data, str)
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
474
|
+
if isinstance(data, str):
|
|
475
|
+
try:
|
|
476
|
+
with open(data, "rb") as f:
|
|
477
|
+
self._make_request(
|
|
478
|
+
"POST",
|
|
479
|
+
endpoint,
|
|
480
|
+
data=f,
|
|
481
|
+
headers=headers,
|
|
482
|
+
timeout=timeout,
|
|
483
|
+
retries=retries,
|
|
484
|
+
retry_backoff=retry_backoff,
|
|
485
|
+
)
|
|
486
|
+
except Exception as e:
|
|
487
|
+
error_msg = f"Error opening file: {data}"
|
|
488
|
+
self._logger.error(error_msg)
|
|
489
|
+
raise InMemoryStorageError(error_msg) from e
|
|
485
490
|
else:
|
|
486
491
|
self._make_request(
|
|
487
492
|
"POST",
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from unittest.mock import Mock
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
|
|
5
|
+
from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import InMemoryConfigError
|
|
6
|
+
from utils.test_base import (
|
|
7
|
+
InMemoryCheckpointClientTest,
|
|
8
|
+
BASE_URL,
|
|
9
|
+
NAMESPACE,
|
|
10
|
+
RANK,
|
|
11
|
+
REQUEST_ERROR_CASES,
|
|
12
|
+
REQUEST_TIMEOUT,
|
|
13
|
+
WORLD_SIZE
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestDeleteCheckpoint(InMemoryCheckpointClientTest):
|
|
18
|
+
STEP = 42
|
|
19
|
+
|
|
20
|
+
def setup_method(self):
|
|
21
|
+
super().setup_method()
|
|
22
|
+
self.checkpoint_path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{self.STEP}"
|
|
23
|
+
|
|
24
|
+
def test_delete_checkpoint_success(self):
|
|
25
|
+
# Arrange
|
|
26
|
+
mock_response = Mock(status_code=200)
|
|
27
|
+
self.mock_session.request.return_value = mock_response
|
|
28
|
+
|
|
29
|
+
# Act
|
|
30
|
+
self.client.delete_checkpoint(step=self.STEP)
|
|
31
|
+
|
|
32
|
+
# Assert
|
|
33
|
+
self.assert_http_adapter_and_retry_config()
|
|
34
|
+
self.mock_session.request.assert_called_once_with(
|
|
35
|
+
method="DELETE",
|
|
36
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
37
|
+
params=None,
|
|
38
|
+
data=None,
|
|
39
|
+
headers=None,
|
|
40
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def test_delete_checkpoint_with_custom_rank(self):
|
|
44
|
+
# Arrange
|
|
45
|
+
mock_response = Mock(status_code=200)
|
|
46
|
+
self.mock_session.request.return_value = mock_response
|
|
47
|
+
custom_rank = 5
|
|
48
|
+
custom_path = f"v1/cp/checkpoints/{NAMESPACE}/{custom_rank}/{self.STEP}"
|
|
49
|
+
|
|
50
|
+
# Act
|
|
51
|
+
self.client.delete_checkpoint(step=self.STEP, rank=custom_rank)
|
|
52
|
+
|
|
53
|
+
# Assert
|
|
54
|
+
self.assert_http_adapter_and_retry_config()
|
|
55
|
+
self.mock_session.request.assert_called_once_with(
|
|
56
|
+
method="DELETE",
|
|
57
|
+
url=f"{BASE_URL}/{custom_path}",
|
|
58
|
+
params=None,
|
|
59
|
+
data=None,
|
|
60
|
+
headers=None,
|
|
61
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
def test_delete_checkpoint_with_metadata_index(self):
|
|
65
|
+
# Arrange
|
|
66
|
+
mock_response = Mock(status_code=200)
|
|
67
|
+
self.mock_session.request.return_value = mock_response
|
|
68
|
+
metadata_index = 0
|
|
69
|
+
metadata_rank = int(WORLD_SIZE) + metadata_index
|
|
70
|
+
metadata_path = f"v1/cp/checkpoints/{NAMESPACE}/{metadata_rank}/{self.STEP}"
|
|
71
|
+
|
|
72
|
+
# Act
|
|
73
|
+
self.client.delete_checkpoint(step=self.STEP, metadata_index=metadata_index)
|
|
74
|
+
|
|
75
|
+
# Assert
|
|
76
|
+
self.assert_http_adapter_and_retry_config()
|
|
77
|
+
self.mock_session.request.assert_called_once_with(
|
|
78
|
+
method="DELETE",
|
|
79
|
+
url=f"{BASE_URL}/{metadata_path}",
|
|
80
|
+
params=None,
|
|
81
|
+
data=None,
|
|
82
|
+
headers=None,
|
|
83
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
def test_delete_checkpoint_with_custom_timeout(self):
|
|
87
|
+
# Arrange
|
|
88
|
+
mock_response = Mock(status_code=200)
|
|
89
|
+
self.mock_session.request.return_value = mock_response
|
|
90
|
+
|
|
91
|
+
# Act
|
|
92
|
+
self.client.delete_checkpoint(step=self.STEP, timeout=REQUEST_TIMEOUT)
|
|
93
|
+
|
|
94
|
+
# Assert
|
|
95
|
+
self.assert_http_adapter_and_retry_config()
|
|
96
|
+
self.mock_session.request.assert_called_once_with(
|
|
97
|
+
method="DELETE",
|
|
98
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
99
|
+
params=None,
|
|
100
|
+
data=None,
|
|
101
|
+
headers=None,
|
|
102
|
+
timeout=REQUEST_TIMEOUT
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
def test_delete_checkpoint_with_string_step(self):
|
|
106
|
+
# Arrange
|
|
107
|
+
mock_response = Mock(status_code=200)
|
|
108
|
+
self.mock_session.request.return_value = mock_response
|
|
109
|
+
step = "latest"
|
|
110
|
+
path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{step}"
|
|
111
|
+
|
|
112
|
+
# Act
|
|
113
|
+
self.client.delete_checkpoint(step=step)
|
|
114
|
+
|
|
115
|
+
# Assert
|
|
116
|
+
self.assert_http_adapter_and_retry_config()
|
|
117
|
+
self.mock_session.request.assert_called_once_with(
|
|
118
|
+
method="DELETE",
|
|
119
|
+
url=f"{BASE_URL}/{path}",
|
|
120
|
+
params=None,
|
|
121
|
+
data=None,
|
|
122
|
+
headers=None,
|
|
123
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
def test_delete_checkpoint_invalid_metadata_index(self):
|
|
127
|
+
# Act & Assert
|
|
128
|
+
with pytest.raises(InMemoryConfigError) as exc_info:
|
|
129
|
+
self.client.delete_checkpoint(step=self.STEP, metadata_index=999)
|
|
130
|
+
assert "Invalid metadata_index" in str(exc_info.value)
|
|
131
|
+
|
|
132
|
+
@pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
|
|
133
|
+
def test_delete_checkpoint_request_errors(self, test_case):
|
|
134
|
+
# Arrange
|
|
135
|
+
if "exception" in test_case["response"]:
|
|
136
|
+
self.mock_session.request.side_effect = test_case["response"]["exception"]
|
|
137
|
+
else:
|
|
138
|
+
mock_response = Mock()
|
|
139
|
+
mock_response.status_code = test_case["response"]["status_code"]
|
|
140
|
+
mock_response.text = test_case["response"]["text"]
|
|
141
|
+
self.mock_session.request.return_value = mock_response
|
|
142
|
+
|
|
143
|
+
# Act & Assert
|
|
144
|
+
self.assert_request_error(
|
|
145
|
+
test_case,
|
|
146
|
+
self.client.delete_checkpoint,
|
|
147
|
+
step=self.STEP
|
|
148
|
+
)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from unittest.mock import Mock, call
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
|
|
5
|
+
from utils.test_base import (
|
|
6
|
+
InMemoryCheckpointClientTest,
|
|
7
|
+
BASE_URL,
|
|
8
|
+
NAMESPACE,
|
|
9
|
+
REQUEST_ERROR_CASES,
|
|
10
|
+
REQUEST_TIMEOUT,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestDeleteNamespace(InMemoryCheckpointClientTest):
|
|
15
|
+
def setup_method(self):
|
|
16
|
+
super().setup_method()
|
|
17
|
+
self.namespace_path = f"v1/cp/namespaces/{NAMESPACE}"
|
|
18
|
+
|
|
19
|
+
@pytest.mark.parametrize(
|
|
20
|
+
"params",
|
|
21
|
+
[
|
|
22
|
+
{"timeout": None},
|
|
23
|
+
{"timeout": REQUEST_TIMEOUT},
|
|
24
|
+
],
|
|
25
|
+
)
|
|
26
|
+
def test_delete_namespace_success(self, params):
|
|
27
|
+
# Arrange
|
|
28
|
+
mock_response = Mock()
|
|
29
|
+
mock_response.status_code = 200
|
|
30
|
+
self.mock_session.request.return_value = mock_response
|
|
31
|
+
|
|
32
|
+
# Act
|
|
33
|
+
self.client.delete_namespace(**params)
|
|
34
|
+
|
|
35
|
+
# Assert
|
|
36
|
+
self.assert_http_adapter_and_retry_config()
|
|
37
|
+
self.mock_session.request.assert_called_once_with(
|
|
38
|
+
method="DELETE",
|
|
39
|
+
url=f"{BASE_URL}/{self.namespace_path}",
|
|
40
|
+
params=None,
|
|
41
|
+
data=None,
|
|
42
|
+
headers=None,
|
|
43
|
+
timeout=params["timeout"] or InMemoryClientConfig.request_timeout
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
@pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
|
|
47
|
+
def test_delete_namespace_request_errors(self, test_case):
|
|
48
|
+
# Arrange
|
|
49
|
+
if "exception" in test_case["response"]:
|
|
50
|
+
self.mock_session.request.side_effect = test_case["response"]["exception"]
|
|
51
|
+
else:
|
|
52
|
+
mock_response = Mock()
|
|
53
|
+
mock_response.status_code = test_case["response"]["status_code"]
|
|
54
|
+
mock_response.text = test_case["response"]["text"]
|
|
55
|
+
self.mock_session.request.return_value = mock_response
|
|
56
|
+
|
|
57
|
+
# Act & Assert
|
|
58
|
+
self.assert_http_adapter_and_retry_config()
|
|
59
|
+
self.assert_request_error(
|
|
60
|
+
test_case,
|
|
61
|
+
self.client.delete_namespace
|
|
62
|
+
)
|
|
63
|
+
self.mock_session.request.assert_called_once_with(
|
|
64
|
+
method="DELETE",
|
|
65
|
+
url=f"{BASE_URL}/{self.namespace_path}",
|
|
66
|
+
params=None,
|
|
67
|
+
data=None,
|
|
68
|
+
headers=None,
|
|
69
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
70
|
+
)
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
from unittest.mock import Mock, call, patch
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
|
|
5
|
+
from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import (
|
|
6
|
+
InMemoryConfigError,
|
|
7
|
+
InMemoryStorageError
|
|
8
|
+
)
|
|
9
|
+
from utils.test_base import (
|
|
10
|
+
InMemoryCheckpointClientTest,
|
|
11
|
+
BASE_URL,
|
|
12
|
+
NAMESPACE,
|
|
13
|
+
RANK,
|
|
14
|
+
REQUEST_ERROR_CASES,
|
|
15
|
+
REQUEST_TIMEOUT,
|
|
16
|
+
WORLD_SIZE
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestGetCheckpoint(InMemoryCheckpointClientTest):
|
|
21
|
+
# Constants
|
|
22
|
+
STEP = 42
|
|
23
|
+
TEST_CONTENT = b"test checkpoint data"
|
|
24
|
+
MOCK_CHECKSUM = "mock-checksum"
|
|
25
|
+
MOCK_ENCODED_CHECKSUM = "bW9jay1jaGVja3N1bQ==" # base64 of mock-checksum
|
|
26
|
+
CONTENT_LENGTH = len(TEST_CONTENT)
|
|
27
|
+
RESPONSE_HEADERS = {
|
|
28
|
+
"Content-Length": str(CONTENT_LENGTH),
|
|
29
|
+
"Shard-Meta": '{"checksum": "bW9jay1jaGVja3N1bQ==", "algorithm": "xxh3_128"}'
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
def setup_method(self):
|
|
33
|
+
super().setup_method()
|
|
34
|
+
self.checkpoint_path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{self.STEP}"
|
|
35
|
+
|
|
36
|
+
# Mock checksum functions
|
|
37
|
+
self.mock_decode_base_64 = patch(
|
|
38
|
+
"amzn_sagemaker_checkpointing.storage.clients.inmemory.inmemory_client.decode_base_64"
|
|
39
|
+
).start()
|
|
40
|
+
self.mock_hash_xxh3_128 = patch(
|
|
41
|
+
"amzn_sagemaker_checkpointing.storage.clients.inmemory.inmemory_client.hash_xxh3_128"
|
|
42
|
+
).start()
|
|
43
|
+
self.mock_decode_base_64.return_value = self.MOCK_CHECKSUM
|
|
44
|
+
self.mock_hash_xxh3_128.return_value = self.MOCK_CHECKSUM
|
|
45
|
+
|
|
46
|
+
def test_get_checkpoint_success(self):
|
|
47
|
+
# Arrange
|
|
48
|
+
mock_response = Mock()
|
|
49
|
+
mock_response.status_code = 200
|
|
50
|
+
mock_response.content = self.TEST_CONTENT
|
|
51
|
+
mock_response.headers = self.RESPONSE_HEADERS
|
|
52
|
+
self.mock_session.request.return_value = mock_response
|
|
53
|
+
|
|
54
|
+
# Act
|
|
55
|
+
result = self.client.get_checkpoint(step=self.STEP)
|
|
56
|
+
|
|
57
|
+
# Assert
|
|
58
|
+
self.assert_http_adapter_and_retry_config()
|
|
59
|
+
self.mock_hash_xxh3_128.assert_called_once_with(self.TEST_CONTENT)
|
|
60
|
+
self.mock_decode_base_64.assert_called_once_with(self.MOCK_ENCODED_CHECKSUM)
|
|
61
|
+
self.mock_session.request.assert_called_once_with(
|
|
62
|
+
method="GET",
|
|
63
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
64
|
+
params=None,
|
|
65
|
+
data=None,
|
|
66
|
+
headers=None,
|
|
67
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
68
|
+
)
|
|
69
|
+
assert result == self.TEST_CONTENT
|
|
70
|
+
|
|
71
|
+
def test_get_checkpoint_with_custom_rank(self):
|
|
72
|
+
# Arrange
|
|
73
|
+
mock_response = Mock()
|
|
74
|
+
mock_response.status_code = 200
|
|
75
|
+
mock_response.content = self.TEST_CONTENT
|
|
76
|
+
mock_response.headers = self.RESPONSE_HEADERS
|
|
77
|
+
self.mock_session.request.return_value = mock_response
|
|
78
|
+
custom_rank = 5
|
|
79
|
+
custom_path = f"v1/cp/checkpoints/{NAMESPACE}/{custom_rank}/{self.STEP}"
|
|
80
|
+
|
|
81
|
+
# Act
|
|
82
|
+
result = self.client.get_checkpoint(step=self.STEP, rank=custom_rank)
|
|
83
|
+
|
|
84
|
+
# Assert
|
|
85
|
+
self.assert_http_adapter_and_retry_config()
|
|
86
|
+
self.mock_session.request.assert_called_once_with(
|
|
87
|
+
method="GET",
|
|
88
|
+
url=f"{BASE_URL}/{custom_path}",
|
|
89
|
+
params=None,
|
|
90
|
+
data=None,
|
|
91
|
+
headers=None,
|
|
92
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
93
|
+
)
|
|
94
|
+
assert result == self.TEST_CONTENT
|
|
95
|
+
|
|
96
|
+
def test_get_checkpoint_with_metadata_index(self):
|
|
97
|
+
# Arrange
|
|
98
|
+
mock_response = Mock()
|
|
99
|
+
mock_response.status_code = 200
|
|
100
|
+
mock_response.content = self.TEST_CONTENT
|
|
101
|
+
mock_response.headers = self.RESPONSE_HEADERS
|
|
102
|
+
self.mock_session.request.return_value = mock_response
|
|
103
|
+
metadata_index = 0
|
|
104
|
+
metadata_rank = int(WORLD_SIZE) + metadata_index
|
|
105
|
+
metadata_path = f"v1/cp/checkpoints/{NAMESPACE}/{metadata_rank}/{self.STEP}"
|
|
106
|
+
|
|
107
|
+
# Act
|
|
108
|
+
result = self.client.get_checkpoint(step=self.STEP, metadata_index=metadata_index)
|
|
109
|
+
|
|
110
|
+
# Assert
|
|
111
|
+
self.assert_http_adapter_and_retry_config()
|
|
112
|
+
self.mock_session.request.assert_called_once_with(
|
|
113
|
+
method="GET",
|
|
114
|
+
url=f"{BASE_URL}/{metadata_path}",
|
|
115
|
+
params=None,
|
|
116
|
+
data=None,
|
|
117
|
+
headers=None,
|
|
118
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
119
|
+
)
|
|
120
|
+
assert result == self.TEST_CONTENT
|
|
121
|
+
|
|
122
|
+
def test_get_checkpoint_with_custom_timeout(self):
|
|
123
|
+
# Arrange
|
|
124
|
+
mock_response = Mock()
|
|
125
|
+
mock_response.status_code = 200
|
|
126
|
+
mock_response.content = self.TEST_CONTENT
|
|
127
|
+
mock_response.headers = self.RESPONSE_HEADERS
|
|
128
|
+
self.mock_session.request.return_value = mock_response
|
|
129
|
+
|
|
130
|
+
# Act
|
|
131
|
+
result = self.client.get_checkpoint(step=self.STEP, timeout=REQUEST_TIMEOUT)
|
|
132
|
+
|
|
133
|
+
# Assert
|
|
134
|
+
self.assert_http_adapter_and_retry_config()
|
|
135
|
+
self.mock_session.request.assert_called_once_with(
|
|
136
|
+
method="GET",
|
|
137
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
138
|
+
params=None,
|
|
139
|
+
data=None,
|
|
140
|
+
headers=None,
|
|
141
|
+
timeout=REQUEST_TIMEOUT
|
|
142
|
+
)
|
|
143
|
+
assert result == self.TEST_CONTENT
|
|
144
|
+
|
|
145
|
+
def test_get_checkpoint_not_found(self):
|
|
146
|
+
# Arrange
|
|
147
|
+
mock_response = Mock(status_code=404)
|
|
148
|
+
self.mock_session.request.return_value = mock_response
|
|
149
|
+
|
|
150
|
+
# Act
|
|
151
|
+
result = self.client.get_checkpoint(step=self.STEP)
|
|
152
|
+
|
|
153
|
+
# Assert
|
|
154
|
+
self.assert_http_adapter_and_retry_config()
|
|
155
|
+
self.mock_session.request.assert_called_once_with(
|
|
156
|
+
method="GET",
|
|
157
|
+
url=f"{BASE_URL}/{self.checkpoint_path}",
|
|
158
|
+
params=None,
|
|
159
|
+
data=None,
|
|
160
|
+
headers=None,
|
|
161
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
162
|
+
)
|
|
163
|
+
assert result is None
|
|
164
|
+
|
|
165
|
+
def test_get_checkpoint_checksum_mismatch(self):
|
|
166
|
+
# Arrange
|
|
167
|
+
mock_response = Mock()
|
|
168
|
+
mock_response.status_code = 200
|
|
169
|
+
mock_response.content = self.TEST_CONTENT
|
|
170
|
+
mock_response.headers = self.RESPONSE_HEADERS
|
|
171
|
+
self.mock_session.request.return_value = mock_response
|
|
172
|
+
self.mock_hash_xxh3_128.return_value = "different-checksum"
|
|
173
|
+
|
|
174
|
+
# Act & Assert
|
|
175
|
+
with pytest.raises(InMemoryStorageError) as exc_info:
|
|
176
|
+
self.client.get_checkpoint(step=self.STEP)
|
|
177
|
+
assert "Checksum mismatch in response" in str(exc_info.value)
|
|
178
|
+
|
|
179
|
+
def test_get_checkpoint_content_length_mismatch(self):
|
|
180
|
+
# Arrange
|
|
181
|
+
mock_response = Mock()
|
|
182
|
+
mock_response.status_code = 200
|
|
183
|
+
mock_response.content = self.TEST_CONTENT
|
|
184
|
+
mock_response.headers = {
|
|
185
|
+
"Content-Length": str(self.CONTENT_LENGTH + 1), # Wrong content length
|
|
186
|
+
"Shard-Meta": '{"checksum": "bW9jay1jaGVja3N1bQ==", "algorithm": "xxh3_128"}'
|
|
187
|
+
}
|
|
188
|
+
self.mock_session.request.return_value = mock_response
|
|
189
|
+
|
|
190
|
+
# Act & Assert
|
|
191
|
+
with pytest.raises(InMemoryStorageError) as exc_info:
|
|
192
|
+
self.client.get_checkpoint(step=self.STEP)
|
|
193
|
+
assert "Content length mismatch" in str(exc_info.value)
|
|
194
|
+
|
|
195
|
+
def test_get_checkpoint_invalid_metadata_index(self):
|
|
196
|
+
# Act & Assert
|
|
197
|
+
with pytest.raises(InMemoryConfigError) as exc_info:
|
|
198
|
+
self.client.get_checkpoint(step=self.STEP, metadata_index=999)
|
|
199
|
+
assert "Invalid metadata_index" in str(exc_info.value)
|
|
200
|
+
|
|
201
|
+
@pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
|
|
202
|
+
def test_get_checkpoint_request_errors(self, test_case):
|
|
203
|
+
# Arrange
|
|
204
|
+
if "exception" in test_case["response"]:
|
|
205
|
+
self.mock_session.request.side_effect = test_case["response"]["exception"]
|
|
206
|
+
else:
|
|
207
|
+
mock_response = Mock()
|
|
208
|
+
mock_response.status_code = test_case["response"]["status_code"]
|
|
209
|
+
mock_response.text = test_case["response"]["text"]
|
|
210
|
+
self.mock_session.request.return_value = mock_response
|
|
211
|
+
|
|
212
|
+
# Act & Assert
|
|
213
|
+
self.assert_request_error(
|
|
214
|
+
test_case,
|
|
215
|
+
self.client.get_checkpoint,
|
|
216
|
+
step=self.STEP
|
|
217
|
+
)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from unittest.mock import Mock, call
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
|
|
5
|
+
from utils.test_base import (
|
|
6
|
+
InMemoryCheckpointClientTest,
|
|
7
|
+
BASE_URL,
|
|
8
|
+
NAMESPACE,
|
|
9
|
+
REQUEST_ERROR_CASES,
|
|
10
|
+
REQUEST_TIMEOUT,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestGetLatestCheckpoints(InMemoryCheckpointClientTest):
|
|
15
|
+
def setup_method(self):
|
|
16
|
+
super().setup_method()
|
|
17
|
+
self.checkpoints_path = f"v1/cp/checkpoints/{NAMESPACE}/latest"
|
|
18
|
+
|
|
19
|
+
@pytest.mark.parametrize(
|
|
20
|
+
"params, api_steps, expected_steps",
|
|
21
|
+
[
|
|
22
|
+
# Default limit, ordered steps
|
|
23
|
+
(
|
|
24
|
+
{"limit": 5, "timeout": None},
|
|
25
|
+
["10", "8", "6", "4", "2"],
|
|
26
|
+
[10, 8, 6, 4, 2]
|
|
27
|
+
),
|
|
28
|
+
# Custom limit
|
|
29
|
+
(
|
|
30
|
+
{"limit": 2, "timeout": None},
|
|
31
|
+
["10", "8"],
|
|
32
|
+
[10, 8]
|
|
33
|
+
),
|
|
34
|
+
# Custom timeout
|
|
35
|
+
(
|
|
36
|
+
{"limit": 5, "timeout": REQUEST_TIMEOUT},
|
|
37
|
+
["10", "8", "6", "4", "2"],
|
|
38
|
+
[10, 8, 6, 4, 2]
|
|
39
|
+
),
|
|
40
|
+
# Unordered steps
|
|
41
|
+
(
|
|
42
|
+
{"limit": 5, "timeout": None},
|
|
43
|
+
["4", "8", "2", "10", "6"],
|
|
44
|
+
[10, 8, 6, 4, 2]
|
|
45
|
+
),
|
|
46
|
+
],
|
|
47
|
+
)
|
|
48
|
+
def test_get_latest_checkpoints_success(self, params, api_steps, expected_steps):
|
|
49
|
+
# Arrange
|
|
50
|
+
mock_response = Mock()
|
|
51
|
+
mock_response.status_code = 200
|
|
52
|
+
mock_response.json.return_value = api_steps
|
|
53
|
+
self.mock_session.request.return_value = mock_response
|
|
54
|
+
|
|
55
|
+
# Act
|
|
56
|
+
result = self.client.get_latest_checkpoints(**params)
|
|
57
|
+
|
|
58
|
+
# Assert
|
|
59
|
+
self.assert_http_adapter_and_retry_config()
|
|
60
|
+
self.mock_session.request.assert_called_once_with(
|
|
61
|
+
method="GET",
|
|
62
|
+
url=f"{BASE_URL}/{self.checkpoints_path}",
|
|
63
|
+
params={"limit": params["limit"]},
|
|
64
|
+
data=None,
|
|
65
|
+
headers=None,
|
|
66
|
+
timeout=params["timeout"] or InMemoryClientConfig.request_timeout
|
|
67
|
+
)
|
|
68
|
+
assert result == expected_steps
|
|
69
|
+
|
|
70
|
+
@pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
|
|
71
|
+
def test_get_latest_checkpoints_request_errors(self, test_case):
|
|
72
|
+
# Arrange
|
|
73
|
+
if "exception" in test_case["response"]:
|
|
74
|
+
self.mock_session.request.side_effect = test_case["response"]["exception"]
|
|
75
|
+
else:
|
|
76
|
+
mock_response = Mock()
|
|
77
|
+
mock_response.status_code = test_case["response"]["status_code"]
|
|
78
|
+
mock_response.text = test_case["response"]["text"]
|
|
79
|
+
self.mock_session.request.return_value = mock_response
|
|
80
|
+
|
|
81
|
+
# Act
|
|
82
|
+
result = self.client.get_latest_checkpoints()
|
|
83
|
+
|
|
84
|
+
# Assert
|
|
85
|
+
self.assert_http_adapter_and_retry_config()
|
|
86
|
+
self.mock_session.request.assert_called_once_with(
|
|
87
|
+
method="GET",
|
|
88
|
+
url=f"{BASE_URL}/{self.checkpoints_path}",
|
|
89
|
+
params={"limit": 5},
|
|
90
|
+
data=None,
|
|
91
|
+
headers=None,
|
|
92
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
93
|
+
)
|
|
94
|
+
assert result == []
|
|
95
|
+
|
|
96
|
+
def test_get_latest_checkpoints_json_decode_error(self):
|
|
97
|
+
# Arrange
|
|
98
|
+
mock_response = Mock()
|
|
99
|
+
mock_response.status_code = 200
|
|
100
|
+
mock_response.json.side_effect = ValueError("Invalid JSON")
|
|
101
|
+
self.mock_session.request.return_value = mock_response
|
|
102
|
+
|
|
103
|
+
# Act
|
|
104
|
+
result = self.client.get_latest_checkpoints()
|
|
105
|
+
|
|
106
|
+
# Assert
|
|
107
|
+
self.assert_http_adapter_and_retry_config()
|
|
108
|
+
self.mock_session.request.assert_called_once_with(
|
|
109
|
+
method="GET",
|
|
110
|
+
url=f"{BASE_URL}/{self.checkpoints_path}",
|
|
111
|
+
params={"limit": 5},
|
|
112
|
+
data=None,
|
|
113
|
+
headers=None,
|
|
114
|
+
timeout=InMemoryClientConfig.request_timeout
|
|
115
|
+
)
|
|
116
|
+
assert result == []
|