amzn-sagemaker-checkpointing 1.0.11__tar.gz → 1.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of amzn-sagemaker-checkpointing might be problematic. Click here for more details.

Files changed (52) hide show
  1. amzn_sagemaker_checkpointing-1.0.12/DEVELOPING.md +22 -0
  2. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/PKG-INFO +1 -1
  3. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/pyproject.toml +2 -1
  4. amzn_sagemaker_checkpointing-1.0.12/setup-hatch.sh +19 -0
  5. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/inmemory_client.py +16 -11
  6. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_checkpoint.py +148 -0
  7. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_namespace.py +70 -0
  8. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_checkpoint.py +217 -0
  9. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_latest_checkpoints.py +116 -0
  10. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_namespace_config.py +118 -0
  11. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_or_create_namespace.py +255 -0
  12. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_put_checkpoint.py +209 -0
  13. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_reset_cluster.py +69 -0
  14. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/__init__.py +0 -0
  15. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/test_base.py +120 -0
  16. amzn_sagemaker_checkpointing-1.0.11/DEVELOPING.md +0 -46
  17. amzn_sagemaker_checkpointing-1.0.11/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_inmemory_client.py +0 -77
  18. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/.crux_dry_run_build +0 -0
  19. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/.gitignore +0 -0
  20. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/LICENSE.txt +0 -0
  21. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/README.md +0 -0
  22. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/brazil.ion +0 -0
  23. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-build-tools.txt +0 -0
  24. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-build.txt +0 -0
  25. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-static-analysis.txt +0 -0
  26. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.11.txt +0 -0
  27. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.12.txt +0 -0
  28. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/requirements.txt +0 -0
  29. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/__init__.py +0 -0
  30. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/__init__.py +0 -0
  31. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/exceptions.py +0 -0
  32. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/filesystem.py +0 -0
  33. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/__init__.py +0 -0
  34. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/in_memory_client.py +0 -0
  35. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/sagemaker_checkpoint_config.py +0 -0
  36. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/py.typed +0 -0
  37. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/__init__.py +0 -0
  38. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/__init__.py +0 -0
  39. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/__init__.py +0 -0
  40. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum.py +0 -0
  41. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/exceptions.py +0 -0
  42. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/models.py +0 -0
  43. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/local/disk_fs.py +0 -0
  44. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/__init__.py +0 -0
  45. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client.py +0 -0
  46. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client_manager.py +0 -0
  47. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/utils/logging_utils.py +0 -0
  48. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/src/scripts/test_inmemory_client.py +0 -0
  49. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/checkpointing/filesystem/test_filesystem.py +0 -0
  50. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum_test.py +0 -0
  51. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/s3/test_s3_client.py +0 -0
  52. {amzn_sagemaker_checkpointing-1.0.11 → amzn_sagemaker_checkpointing-1.0.12}/tests/test_dummy.py +0 -0
@@ -0,0 +1,22 @@
1
+ # Developing SageMakerCheckpointing
2
+
3
+ This package uses the [hatch](https://hatch.pypa.io/latest/) build system.
4
+
5
+ ### Building
6
+
7
+ Several scripts and commands are defined in `pyproject.toml` under the `scripts` configuration, with more
8
+ documentation in the comments of `pyproject.toml`. To run a script in a specific environment, run
9
+ `hatch run <env_name>:<script>`. You can omit the `<env_name>` for those under the `default` environment.
10
+
11
+ You need to set up the hatch plugin first:
12
+ ```
13
+ ./setup-hatch.sh
14
+ ```
15
+
16
+ ### Available Hatch Commands
17
+
18
+ - **`hatch run release`** - Runs typing checks (mypy), tests, and coverage.
19
+ - **`hatch test --cover`** - Runs tests and coverage.
20
+ - **`hatch typing`** - Runs mypy type checking.
21
+ - **`hatch fmt`** - Formats code using ruff.
22
+ - **`hatch build`** - Builds both source and wheel distributions in the `./build` directory.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: amzn-sagemaker-checkpointing
3
- Version: 1.0.11
3
+ Version: 1.0.12
4
4
  Summary: Amazon SageMaker Checkpointing Library
5
5
  License: Apache 2.0
6
6
  License-File: LICENSE.txt
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "amzn-sagemaker-checkpointing"
7
- version = "1.0.11"
7
+ version = "1.0.12"
8
8
  description = "Amazon SageMaker Checkpointing Library"
9
9
  readme = "README.md"
10
10
  license = { "text" = "Apache 2.0" }
@@ -72,6 +72,7 @@ exclude = [ "./build", ".hatch", "private" ]
72
72
 
73
73
  [tool.hatch.build]
74
74
  directory = "./build"
75
+ exclude = ["DEVELOPING_INTERNAL.md"]
75
76
 
76
77
  [tool.hatch.env]
77
78
  requires = [ "hatch-pip-compile" ]
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ mkdir -p .hatch
5
+
6
+ cat > .hatch/hatch_plugin.py << 'EOF'
7
+ from hatch.env.collectors.plugin.interface import EnvironmentCollectorInterface
8
+
9
+ class CustomEnvironmentCollector(EnvironmentCollectorInterface):
10
+ PLUGIN_NAME = 'custom'
11
+
12
+ def get_initial_config(self):
13
+ return {}
14
+
15
+ def finalize_config(self, config):
16
+ return config
17
+ EOF
18
+
19
+ echo "Hatch plugin created"
@@ -471,17 +471,22 @@ class InMemoryCheckpointClient:
471
471
  checksum=encode_base_64(hash_xxh3_128(data)), algorithm="xxh3_128"
472
472
  ).to_json()
473
473
  }
474
- if isinstance(data, str) and os.path.exists(data):
475
- with open(data, "rb") as f:
476
- self._make_request(
477
- "POST",
478
- endpoint,
479
- data=f,
480
- headers=headers,
481
- timeout=timeout,
482
- retries=retries,
483
- retry_backoff=retry_backoff,
484
- )
474
+ if isinstance(data, str):
475
+ try:
476
+ with open(data, "rb") as f:
477
+ self._make_request(
478
+ "POST",
479
+ endpoint,
480
+ data=f,
481
+ headers=headers,
482
+ timeout=timeout,
483
+ retries=retries,
484
+ retry_backoff=retry_backoff,
485
+ )
486
+ except Exception as e:
487
+ error_msg = f"Error opening file: {data}"
488
+ self._logger.error(error_msg)
489
+ raise InMemoryStorageError(error_msg) from e
485
490
  else:
486
491
  self._make_request(
487
492
  "POST",
@@ -0,0 +1,148 @@
1
+ from unittest.mock import Mock
2
+ import pytest
3
+
4
+ from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
5
+ from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import InMemoryConfigError
6
+ from utils.test_base import (
7
+ InMemoryCheckpointClientTest,
8
+ BASE_URL,
9
+ NAMESPACE,
10
+ RANK,
11
+ REQUEST_ERROR_CASES,
12
+ REQUEST_TIMEOUT,
13
+ WORLD_SIZE
14
+ )
15
+
16
+
17
+ class TestDeleteCheckpoint(InMemoryCheckpointClientTest):
18
+ STEP = 42
19
+
20
+ def setup_method(self):
21
+ super().setup_method()
22
+ self.checkpoint_path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{self.STEP}"
23
+
24
+ def test_delete_checkpoint_success(self):
25
+ # Arrange
26
+ mock_response = Mock(status_code=200)
27
+ self.mock_session.request.return_value = mock_response
28
+
29
+ # Act
30
+ self.client.delete_checkpoint(step=self.STEP)
31
+
32
+ # Assert
33
+ self.assert_http_adapter_and_retry_config()
34
+ self.mock_session.request.assert_called_once_with(
35
+ method="DELETE",
36
+ url=f"{BASE_URL}/{self.checkpoint_path}",
37
+ params=None,
38
+ data=None,
39
+ headers=None,
40
+ timeout=InMemoryClientConfig.request_timeout
41
+ )
42
+
43
+ def test_delete_checkpoint_with_custom_rank(self):
44
+ # Arrange
45
+ mock_response = Mock(status_code=200)
46
+ self.mock_session.request.return_value = mock_response
47
+ custom_rank = 5
48
+ custom_path = f"v1/cp/checkpoints/{NAMESPACE}/{custom_rank}/{self.STEP}"
49
+
50
+ # Act
51
+ self.client.delete_checkpoint(step=self.STEP, rank=custom_rank)
52
+
53
+ # Assert
54
+ self.assert_http_adapter_and_retry_config()
55
+ self.mock_session.request.assert_called_once_with(
56
+ method="DELETE",
57
+ url=f"{BASE_URL}/{custom_path}",
58
+ params=None,
59
+ data=None,
60
+ headers=None,
61
+ timeout=InMemoryClientConfig.request_timeout
62
+ )
63
+
64
+ def test_delete_checkpoint_with_metadata_index(self):
65
+ # Arrange
66
+ mock_response = Mock(status_code=200)
67
+ self.mock_session.request.return_value = mock_response
68
+ metadata_index = 0
69
+ metadata_rank = int(WORLD_SIZE) + metadata_index
70
+ metadata_path = f"v1/cp/checkpoints/{NAMESPACE}/{metadata_rank}/{self.STEP}"
71
+
72
+ # Act
73
+ self.client.delete_checkpoint(step=self.STEP, metadata_index=metadata_index)
74
+
75
+ # Assert
76
+ self.assert_http_adapter_and_retry_config()
77
+ self.mock_session.request.assert_called_once_with(
78
+ method="DELETE",
79
+ url=f"{BASE_URL}/{metadata_path}",
80
+ params=None,
81
+ data=None,
82
+ headers=None,
83
+ timeout=InMemoryClientConfig.request_timeout
84
+ )
85
+
86
+ def test_delete_checkpoint_with_custom_timeout(self):
87
+ # Arrange
88
+ mock_response = Mock(status_code=200)
89
+ self.mock_session.request.return_value = mock_response
90
+
91
+ # Act
92
+ self.client.delete_checkpoint(step=self.STEP, timeout=REQUEST_TIMEOUT)
93
+
94
+ # Assert
95
+ self.assert_http_adapter_and_retry_config()
96
+ self.mock_session.request.assert_called_once_with(
97
+ method="DELETE",
98
+ url=f"{BASE_URL}/{self.checkpoint_path}",
99
+ params=None,
100
+ data=None,
101
+ headers=None,
102
+ timeout=REQUEST_TIMEOUT
103
+ )
104
+
105
+ def test_delete_checkpoint_with_string_step(self):
106
+ # Arrange
107
+ mock_response = Mock(status_code=200)
108
+ self.mock_session.request.return_value = mock_response
109
+ step = "latest"
110
+ path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{step}"
111
+
112
+ # Act
113
+ self.client.delete_checkpoint(step=step)
114
+
115
+ # Assert
116
+ self.assert_http_adapter_and_retry_config()
117
+ self.mock_session.request.assert_called_once_with(
118
+ method="DELETE",
119
+ url=f"{BASE_URL}/{path}",
120
+ params=None,
121
+ data=None,
122
+ headers=None,
123
+ timeout=InMemoryClientConfig.request_timeout
124
+ )
125
+
126
+ def test_delete_checkpoint_invalid_metadata_index(self):
127
+ # Act & Assert
128
+ with pytest.raises(InMemoryConfigError) as exc_info:
129
+ self.client.delete_checkpoint(step=self.STEP, metadata_index=999)
130
+ assert "Invalid metadata_index" in str(exc_info.value)
131
+
132
+ @pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
133
+ def test_delete_checkpoint_request_errors(self, test_case):
134
+ # Arrange
135
+ if "exception" in test_case["response"]:
136
+ self.mock_session.request.side_effect = test_case["response"]["exception"]
137
+ else:
138
+ mock_response = Mock()
139
+ mock_response.status_code = test_case["response"]["status_code"]
140
+ mock_response.text = test_case["response"]["text"]
141
+ self.mock_session.request.return_value = mock_response
142
+
143
+ # Act & Assert
144
+ self.assert_request_error(
145
+ test_case,
146
+ self.client.delete_checkpoint,
147
+ step=self.STEP
148
+ )
@@ -0,0 +1,70 @@
1
+ from unittest.mock import Mock, call
2
+ import pytest
3
+
4
+ from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
5
+ from utils.test_base import (
6
+ InMemoryCheckpointClientTest,
7
+ BASE_URL,
8
+ NAMESPACE,
9
+ REQUEST_ERROR_CASES,
10
+ REQUEST_TIMEOUT,
11
+ )
12
+
13
+
14
+ class TestDeleteNamespace(InMemoryCheckpointClientTest):
15
+ def setup_method(self):
16
+ super().setup_method()
17
+ self.namespace_path = f"v1/cp/namespaces/{NAMESPACE}"
18
+
19
+ @pytest.mark.parametrize(
20
+ "params",
21
+ [
22
+ {"timeout": None},
23
+ {"timeout": REQUEST_TIMEOUT},
24
+ ],
25
+ )
26
+ def test_delete_namespace_success(self, params):
27
+ # Arrange
28
+ mock_response = Mock()
29
+ mock_response.status_code = 200
30
+ self.mock_session.request.return_value = mock_response
31
+
32
+ # Act
33
+ self.client.delete_namespace(**params)
34
+
35
+ # Assert
36
+ self.assert_http_adapter_and_retry_config()
37
+ self.mock_session.request.assert_called_once_with(
38
+ method="DELETE",
39
+ url=f"{BASE_URL}/{self.namespace_path}",
40
+ params=None,
41
+ data=None,
42
+ headers=None,
43
+ timeout=params["timeout"] or InMemoryClientConfig.request_timeout
44
+ )
45
+
46
+ @pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
47
+ def test_delete_namespace_request_errors(self, test_case):
48
+ # Arrange
49
+ if "exception" in test_case["response"]:
50
+ self.mock_session.request.side_effect = test_case["response"]["exception"]
51
+ else:
52
+ mock_response = Mock()
53
+ mock_response.status_code = test_case["response"]["status_code"]
54
+ mock_response.text = test_case["response"]["text"]
55
+ self.mock_session.request.return_value = mock_response
56
+
57
+ # Act & Assert
58
+ self.assert_http_adapter_and_retry_config()
59
+ self.assert_request_error(
60
+ test_case,
61
+ self.client.delete_namespace
62
+ )
63
+ self.mock_session.request.assert_called_once_with(
64
+ method="DELETE",
65
+ url=f"{BASE_URL}/{self.namespace_path}",
66
+ params=None,
67
+ data=None,
68
+ headers=None,
69
+ timeout=InMemoryClientConfig.request_timeout
70
+ )
@@ -0,0 +1,217 @@
1
+ from unittest.mock import Mock, call, patch
2
+ import pytest
3
+
4
+ from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
5
+ from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import (
6
+ InMemoryConfigError,
7
+ InMemoryStorageError
8
+ )
9
+ from utils.test_base import (
10
+ InMemoryCheckpointClientTest,
11
+ BASE_URL,
12
+ NAMESPACE,
13
+ RANK,
14
+ REQUEST_ERROR_CASES,
15
+ REQUEST_TIMEOUT,
16
+ WORLD_SIZE
17
+ )
18
+
19
+
20
+ class TestGetCheckpoint(InMemoryCheckpointClientTest):
21
+ # Constants
22
+ STEP = 42
23
+ TEST_CONTENT = b"test checkpoint data"
24
+ MOCK_CHECKSUM = "mock-checksum"
25
+ MOCK_ENCODED_CHECKSUM = "bW9jay1jaGVja3N1bQ==" # base64 of mock-checksum
26
+ CONTENT_LENGTH = len(TEST_CONTENT)
27
+ RESPONSE_HEADERS = {
28
+ "Content-Length": str(CONTENT_LENGTH),
29
+ "Shard-Meta": '{"checksum": "bW9jay1jaGVja3N1bQ==", "algorithm": "xxh3_128"}'
30
+ }
31
+
32
+ def setup_method(self):
33
+ super().setup_method()
34
+ self.checkpoint_path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{self.STEP}"
35
+
36
+ # Mock checksum functions
37
+ self.mock_decode_base_64 = patch(
38
+ "amzn_sagemaker_checkpointing.storage.clients.inmemory.inmemory_client.decode_base_64"
39
+ ).start()
40
+ self.mock_hash_xxh3_128 = patch(
41
+ "amzn_sagemaker_checkpointing.storage.clients.inmemory.inmemory_client.hash_xxh3_128"
42
+ ).start()
43
+ self.mock_decode_base_64.return_value = self.MOCK_CHECKSUM
44
+ self.mock_hash_xxh3_128.return_value = self.MOCK_CHECKSUM
45
+
46
+ def test_get_checkpoint_success(self):
47
+ # Arrange
48
+ mock_response = Mock()
49
+ mock_response.status_code = 200
50
+ mock_response.content = self.TEST_CONTENT
51
+ mock_response.headers = self.RESPONSE_HEADERS
52
+ self.mock_session.request.return_value = mock_response
53
+
54
+ # Act
55
+ result = self.client.get_checkpoint(step=self.STEP)
56
+
57
+ # Assert
58
+ self.assert_http_adapter_and_retry_config()
59
+ self.mock_hash_xxh3_128.assert_called_once_with(self.TEST_CONTENT)
60
+ self.mock_decode_base_64.assert_called_once_with(self.MOCK_ENCODED_CHECKSUM)
61
+ self.mock_session.request.assert_called_once_with(
62
+ method="GET",
63
+ url=f"{BASE_URL}/{self.checkpoint_path}",
64
+ params=None,
65
+ data=None,
66
+ headers=None,
67
+ timeout=InMemoryClientConfig.request_timeout
68
+ )
69
+ assert result == self.TEST_CONTENT
70
+
71
+ def test_get_checkpoint_with_custom_rank(self):
72
+ # Arrange
73
+ mock_response = Mock()
74
+ mock_response.status_code = 200
75
+ mock_response.content = self.TEST_CONTENT
76
+ mock_response.headers = self.RESPONSE_HEADERS
77
+ self.mock_session.request.return_value = mock_response
78
+ custom_rank = 5
79
+ custom_path = f"v1/cp/checkpoints/{NAMESPACE}/{custom_rank}/{self.STEP}"
80
+
81
+ # Act
82
+ result = self.client.get_checkpoint(step=self.STEP, rank=custom_rank)
83
+
84
+ # Assert
85
+ self.assert_http_adapter_and_retry_config()
86
+ self.mock_session.request.assert_called_once_with(
87
+ method="GET",
88
+ url=f"{BASE_URL}/{custom_path}",
89
+ params=None,
90
+ data=None,
91
+ headers=None,
92
+ timeout=InMemoryClientConfig.request_timeout
93
+ )
94
+ assert result == self.TEST_CONTENT
95
+
96
+ def test_get_checkpoint_with_metadata_index(self):
97
+ # Arrange
98
+ mock_response = Mock()
99
+ mock_response.status_code = 200
100
+ mock_response.content = self.TEST_CONTENT
101
+ mock_response.headers = self.RESPONSE_HEADERS
102
+ self.mock_session.request.return_value = mock_response
103
+ metadata_index = 0
104
+ metadata_rank = int(WORLD_SIZE) + metadata_index
105
+ metadata_path = f"v1/cp/checkpoints/{NAMESPACE}/{metadata_rank}/{self.STEP}"
106
+
107
+ # Act
108
+ result = self.client.get_checkpoint(step=self.STEP, metadata_index=metadata_index)
109
+
110
+ # Assert
111
+ self.assert_http_adapter_and_retry_config()
112
+ self.mock_session.request.assert_called_once_with(
113
+ method="GET",
114
+ url=f"{BASE_URL}/{metadata_path}",
115
+ params=None,
116
+ data=None,
117
+ headers=None,
118
+ timeout=InMemoryClientConfig.request_timeout
119
+ )
120
+ assert result == self.TEST_CONTENT
121
+
122
+ def test_get_checkpoint_with_custom_timeout(self):
123
+ # Arrange
124
+ mock_response = Mock()
125
+ mock_response.status_code = 200
126
+ mock_response.content = self.TEST_CONTENT
127
+ mock_response.headers = self.RESPONSE_HEADERS
128
+ self.mock_session.request.return_value = mock_response
129
+
130
+ # Act
131
+ result = self.client.get_checkpoint(step=self.STEP, timeout=REQUEST_TIMEOUT)
132
+
133
+ # Assert
134
+ self.assert_http_adapter_and_retry_config()
135
+ self.mock_session.request.assert_called_once_with(
136
+ method="GET",
137
+ url=f"{BASE_URL}/{self.checkpoint_path}",
138
+ params=None,
139
+ data=None,
140
+ headers=None,
141
+ timeout=REQUEST_TIMEOUT
142
+ )
143
+ assert result == self.TEST_CONTENT
144
+
145
+ def test_get_checkpoint_not_found(self):
146
+ # Arrange
147
+ mock_response = Mock(status_code=404)
148
+ self.mock_session.request.return_value = mock_response
149
+
150
+ # Act
151
+ result = self.client.get_checkpoint(step=self.STEP)
152
+
153
+ # Assert
154
+ self.assert_http_adapter_and_retry_config()
155
+ self.mock_session.request.assert_called_once_with(
156
+ method="GET",
157
+ url=f"{BASE_URL}/{self.checkpoint_path}",
158
+ params=None,
159
+ data=None,
160
+ headers=None,
161
+ timeout=InMemoryClientConfig.request_timeout
162
+ )
163
+ assert result is None
164
+
165
+ def test_get_checkpoint_checksum_mismatch(self):
166
+ # Arrange
167
+ mock_response = Mock()
168
+ mock_response.status_code = 200
169
+ mock_response.content = self.TEST_CONTENT
170
+ mock_response.headers = self.RESPONSE_HEADERS
171
+ self.mock_session.request.return_value = mock_response
172
+ self.mock_hash_xxh3_128.return_value = "different-checksum"
173
+
174
+ # Act & Assert
175
+ with pytest.raises(InMemoryStorageError) as exc_info:
176
+ self.client.get_checkpoint(step=self.STEP)
177
+ assert "Checksum mismatch in response" in str(exc_info.value)
178
+
179
+ def test_get_checkpoint_content_length_mismatch(self):
180
+ # Arrange
181
+ mock_response = Mock()
182
+ mock_response.status_code = 200
183
+ mock_response.content = self.TEST_CONTENT
184
+ mock_response.headers = {
185
+ "Content-Length": str(self.CONTENT_LENGTH + 1), # Wrong content length
186
+ "Shard-Meta": '{"checksum": "bW9jay1jaGVja3N1bQ==", "algorithm": "xxh3_128"}'
187
+ }
188
+ self.mock_session.request.return_value = mock_response
189
+
190
+ # Act & Assert
191
+ with pytest.raises(InMemoryStorageError) as exc_info:
192
+ self.client.get_checkpoint(step=self.STEP)
193
+ assert "Content length mismatch" in str(exc_info.value)
194
+
195
+ def test_get_checkpoint_invalid_metadata_index(self):
196
+ # Act & Assert
197
+ with pytest.raises(InMemoryConfigError) as exc_info:
198
+ self.client.get_checkpoint(step=self.STEP, metadata_index=999)
199
+ assert "Invalid metadata_index" in str(exc_info.value)
200
+
201
+ @pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
202
+ def test_get_checkpoint_request_errors(self, test_case):
203
+ # Arrange
204
+ if "exception" in test_case["response"]:
205
+ self.mock_session.request.side_effect = test_case["response"]["exception"]
206
+ else:
207
+ mock_response = Mock()
208
+ mock_response.status_code = test_case["response"]["status_code"]
209
+ mock_response.text = test_case["response"]["text"]
210
+ self.mock_session.request.return_value = mock_response
211
+
212
+ # Act & Assert
213
+ self.assert_request_error(
214
+ test_case,
215
+ self.client.get_checkpoint,
216
+ step=self.STEP
217
+ )
@@ -0,0 +1,116 @@
1
+ from unittest.mock import Mock, call
2
+ import pytest
3
+
4
+ from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
5
+ from utils.test_base import (
6
+ InMemoryCheckpointClientTest,
7
+ BASE_URL,
8
+ NAMESPACE,
9
+ REQUEST_ERROR_CASES,
10
+ REQUEST_TIMEOUT,
11
+ )
12
+
13
+
14
+ class TestGetLatestCheckpoints(InMemoryCheckpointClientTest):
15
+ def setup_method(self):
16
+ super().setup_method()
17
+ self.checkpoints_path = f"v1/cp/checkpoints/{NAMESPACE}/latest"
18
+
19
+ @pytest.mark.parametrize(
20
+ "params, api_steps, expected_steps",
21
+ [
22
+ # Default limit, ordered steps
23
+ (
24
+ {"limit": 5, "timeout": None},
25
+ ["10", "8", "6", "4", "2"],
26
+ [10, 8, 6, 4, 2]
27
+ ),
28
+ # Custom limit
29
+ (
30
+ {"limit": 2, "timeout": None},
31
+ ["10", "8"],
32
+ [10, 8]
33
+ ),
34
+ # Custom timeout
35
+ (
36
+ {"limit": 5, "timeout": REQUEST_TIMEOUT},
37
+ ["10", "8", "6", "4", "2"],
38
+ [10, 8, 6, 4, 2]
39
+ ),
40
+ # Unordered steps
41
+ (
42
+ {"limit": 5, "timeout": None},
43
+ ["4", "8", "2", "10", "6"],
44
+ [10, 8, 6, 4, 2]
45
+ ),
46
+ ],
47
+ )
48
+ def test_get_latest_checkpoints_success(self, params, api_steps, expected_steps):
49
+ # Arrange
50
+ mock_response = Mock()
51
+ mock_response.status_code = 200
52
+ mock_response.json.return_value = api_steps
53
+ self.mock_session.request.return_value = mock_response
54
+
55
+ # Act
56
+ result = self.client.get_latest_checkpoints(**params)
57
+
58
+ # Assert
59
+ self.assert_http_adapter_and_retry_config()
60
+ self.mock_session.request.assert_called_once_with(
61
+ method="GET",
62
+ url=f"{BASE_URL}/{self.checkpoints_path}",
63
+ params={"limit": params["limit"]},
64
+ data=None,
65
+ headers=None,
66
+ timeout=params["timeout"] or InMemoryClientConfig.request_timeout
67
+ )
68
+ assert result == expected_steps
69
+
70
+ @pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
71
+ def test_get_latest_checkpoints_request_errors(self, test_case):
72
+ # Arrange
73
+ if "exception" in test_case["response"]:
74
+ self.mock_session.request.side_effect = test_case["response"]["exception"]
75
+ else:
76
+ mock_response = Mock()
77
+ mock_response.status_code = test_case["response"]["status_code"]
78
+ mock_response.text = test_case["response"]["text"]
79
+ self.mock_session.request.return_value = mock_response
80
+
81
+ # Act
82
+ result = self.client.get_latest_checkpoints()
83
+
84
+ # Assert
85
+ self.assert_http_adapter_and_retry_config()
86
+ self.mock_session.request.assert_called_once_with(
87
+ method="GET",
88
+ url=f"{BASE_URL}/{self.checkpoints_path}",
89
+ params={"limit": 5},
90
+ data=None,
91
+ headers=None,
92
+ timeout=InMemoryClientConfig.request_timeout
93
+ )
94
+ assert result == []
95
+
96
+ def test_get_latest_checkpoints_json_decode_error(self):
97
+ # Arrange
98
+ mock_response = Mock()
99
+ mock_response.status_code = 200
100
+ mock_response.json.side_effect = ValueError("Invalid JSON")
101
+ self.mock_session.request.return_value = mock_response
102
+
103
+ # Act
104
+ result = self.client.get_latest_checkpoints()
105
+
106
+ # Assert
107
+ self.assert_http_adapter_and_retry_config()
108
+ self.mock_session.request.assert_called_once_with(
109
+ method="GET",
110
+ url=f"{BASE_URL}/{self.checkpoints_path}",
111
+ params={"limit": 5},
112
+ data=None,
113
+ headers=None,
114
+ timeout=InMemoryClientConfig.request_timeout
115
+ )
116
+ assert result == []