amzn-sagemaker-checkpointing 1.0.10__tar.gz → 1.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of amzn-sagemaker-checkpointing might be problematic; see the package's registry page for more details.

Files changed (52)
  1. amzn_sagemaker_checkpointing-1.0.12/DEVELOPING.md +22 -0
  2. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/PKG-INFO +3 -3
  3. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/README.md +2 -2
  4. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/pyproject.toml +2 -1
  5. amzn_sagemaker_checkpointing-1.0.12/setup-hatch.sh +19 -0
  6. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/filesystem.py +117 -59
  7. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/inmemory_client.py +16 -11
  8. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_checkpoint.py +148 -0
  9. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_delete_namespace.py +70 -0
  10. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_checkpoint.py +217 -0
  11. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_latest_checkpoints.py +116 -0
  12. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_namespace_config.py +118 -0
  13. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_get_or_create_namespace.py +255 -0
  14. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_put_checkpoint.py +209 -0
  15. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_reset_cluster.py +69 -0
  16. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/__init__.py +0 -0
  17. amzn_sagemaker_checkpointing-1.0.12/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/utils/test_base.py +120 -0
  18. amzn_sagemaker_checkpointing-1.0.10/DEVELOPING.md +0 -46
  19. amzn_sagemaker_checkpointing-1.0.10/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/test_inmemory_client.py +0 -77
  20. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/.crux_dry_run_build +0 -0
  21. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/.gitignore +0 -0
  22. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/LICENSE.txt +0 -0
  23. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/brazil.ion +0 -0
  24. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-build-tools.txt +0 -0
  25. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-build.txt +0 -0
  26. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-static-analysis.txt +0 -0
  27. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.11.txt +0 -0
  28. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements/requirements-hatch-test.py3.12.txt +0 -0
  29. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/requirements.txt +0 -0
  30. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/__init__.py +0 -0
  31. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/__init__.py +0 -0
  32. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/checkpointing/filesystem/exceptions.py +0 -0
  33. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/__init__.py +0 -0
  34. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/in_memory_client.py +0 -0
  35. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/config/sagemaker_checkpoint_config.py +0 -0
  36. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/py.typed +0 -0
  37. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/__init__.py +0 -0
  38. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/__init__.py +0 -0
  39. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/__init__.py +0 -0
  40. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum.py +0 -0
  41. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/exceptions.py +0 -0
  42. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/inmemory/models.py +0 -0
  43. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/local/disk_fs.py +0 -0
  44. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/__init__.py +0 -0
  45. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client.py +0 -0
  46. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/storage/clients/s3/s3_client_manager.py +0 -0
  47. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/amzn_sagemaker_checkpointing/utils/logging_utils.py +0 -0
  48. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/src/scripts/test_inmemory_client.py +0 -0
  49. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/checkpointing/filesystem/test_filesystem.py +0 -0
  50. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/inmemory/checksum_test.py +0 -0
  51. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/amzn_sagemaker_checkpointing/storage/clients/s3/test_s3_client.py +0 -0
  52. {amzn_sagemaker_checkpointing-1.0.10 → amzn_sagemaker_checkpointing-1.0.12}/tests/test_dummy.py +0 -0
@@ -0,0 +1,22 @@
1
+ # Developing SageMakerCheckpointing
2
+
3
+ This package uses the [hatch](https://hatch.pypa.io/latest/) build system.
4
+
5
+ ### Building
6
+
7
+ A number of scripts and commands are defined under the `scripts` configurations in `pyproject.toml`, and the
8
+ comments in `pyproject.toml` document them further. To run a script for a specific environment, use
9
+ `hatch run <env_name>:<script>`. You can omit `<env_name>` for scripts in the `default` environment.
10
+
11
+ You need to set up the hatch plugin first:
12
+ ```
13
+ ./setup-hatch.sh
14
+ ```
15
+
16
+ ### Available Hatch Commands
17
+
18
+ - **`hatch run release`** - Runs typing checks (mypy), tests, and coverage.
19
+ - **`hatch test --cover`** - Runs tests and coverage.
20
+ - **`hatch typing`** - Runs mypy type checking.
21
+ - **`hatch fmt`** - Formats code using ruff.
22
+ - **`hatch build`** - Builds both source and wheel distributions in the `./build` directory.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: amzn-sagemaker-checkpointing
3
- Version: 1.0.10
3
+ Version: 1.0.12
4
4
  Summary: Amazon SageMaker Checkpointing Library
5
5
  License: Apache 2.0
6
6
  License-File: LICENSE.txt
@@ -95,12 +95,12 @@ following to your S3 bucket policy
95
95
  ```
96
96
 
97
97
  ## Installation
98
- ### PreRequisites
98
+ ### Prerequisites
99
99
  ```bash
100
100
  pip install s3torchconnector tenacity torch boto3 botocore
101
101
  ```
102
102
 
103
- ### Install amzn-sagemaker-checkpointing library
103
+ ### SageMaker Checkpointing Library
104
104
  ```bash
105
105
  pip install amzn-sagemaker-checkpointing
106
106
  ```
@@ -82,12 +82,12 @@ following to your S3 bucket policy
82
82
  ```
83
83
 
84
84
  ## Installation
85
- ### PreRequisites
85
+ ### Prerequisites
86
86
  ```bash
87
87
  pip install s3torchconnector tenacity torch boto3 botocore
88
88
  ```
89
89
 
90
- ### Install amzn-sagemaker-checkpointing library
90
+ ### SageMaker Checkpointing Library
91
91
  ```bash
92
92
  pip install amzn-sagemaker-checkpointing
93
93
  ```
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "amzn-sagemaker-checkpointing"
7
- version = "1.0.10"
7
+ version = "1.0.12"
8
8
  description = "Amazon SageMaker Checkpointing Library"
9
9
  readme = "README.md"
10
10
  license = { "text" = "Apache 2.0" }
@@ -72,6 +72,7 @@ exclude = [ "./build", ".hatch", "private" ]
72
72
 
73
73
  [tool.hatch.build]
74
74
  directory = "./build"
75
+ exclude = ["DEVELOPING_INTERNAL.md"]
75
76
 
76
77
  [tool.hatch.env]
77
78
  requires = [ "hatch-pip-compile" ]
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ mkdir -p .hatch
5
+
6
+ cat > .hatch/hatch_plugin.py << 'EOF'
7
+ from hatch.env.collectors.plugin.interface import EnvironmentCollectorInterface
8
+
9
+ class CustomEnvironmentCollector(EnvironmentCollectorInterface):
10
+ PLUGIN_NAME = 'custom'
11
+
12
+ def get_initial_config(self):
13
+ return {}
14
+
15
+ def finalize_config(self, config):
16
+ return config
17
+ EOF
18
+
19
+ echo "Hatch plugin created"
@@ -19,6 +19,7 @@ import pickle
19
19
  import threading
20
20
  import time
21
21
  from dataclasses import dataclass
22
+ from enum import Enum
22
23
  from logging import FileHandler
23
24
  from typing import Any, Union
24
25
 
@@ -46,9 +47,6 @@ from torch.futures import Future
46
47
  from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import (
47
48
  SageMakerCheckpointConfig,
48
49
  )
49
- from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import (
50
- InMemoryServerError,
51
- )
52
50
  from amzn_sagemaker_checkpointing.storage.clients.inmemory.inmemory_client import (
53
51
  InMemoryCheckpointClient,
54
52
  )
@@ -80,6 +78,15 @@ class _SageMakerStorageInfo:
80
78
  offset: int
81
79
  length: int
82
80
 
81
+ class StorageTier(Enum):
82
+ IN_MEMORY = 0
83
+ S3 = 1
84
+
85
+ def __str__(self):
86
+ return {
87
+ 0: "IN_MEMORY",
88
+ 1: "S3"
89
+ }[self.value]
83
90
 
84
91
  def _get_step_val(step: int, path: str | os.PathLike) -> int:
85
92
  """
@@ -791,51 +798,42 @@ class SageMakerTieredStorageReader(StorageReader):
791
798
 
792
799
  def read_metadata(self) -> Metadata:
793
800
  """
794
- Retrieve and deserialize checkpoint metadata from the in-memory storage.
801
+ Retrieve and deserialize checkpoint metadata.
795
802
 
796
803
  Returns
797
804
  -------
798
805
  Metadata
799
806
  Metadata object containing checkpoint information.
800
-
801
- Raises
802
- ------
803
- RuntimeError
804
- If metadata retrieval fails.
807
+ (or) empty Metadata if not available
805
808
  """
806
- # Use provided step or find latest available
807
- if self.step is None:
808
- self.step = self._find_latest_complete_step_across_tiers()
809
-
810
- if not self.step:
811
- self.logger.info(
812
- f"[Rank {self.rank}] Step {self.step}: No checkpoints found"
813
- )
814
- return Metadata({})
815
-
816
- # Try in-memory first (faster)
817
- metadata_buffer = self._try_read_md_from_memory(self.step)
818
- if metadata_buffer:
819
- self.logger.info(
820
- f"[Rank {self.rank}] Step {self.step}: Successfully read metadata from memory, size={len(metadata_buffer)} bytes"
821
- )
822
- return pickle.loads(metadata_buffer)
823
-
824
- self.logger.info(
825
- f"[Rank {self.rank}] Step {self.step}: In-memory metadata not found"
826
- )
827
- # Fallback to S3
828
- if self.s3_base_path:
829
- self.logger.info(
830
- f"[Rank {self.rank}] Step {self.step}: Attempting metadata read from S3"
831
- )
832
- metadata_buffer = self._try_read_md_from_s3(self.step)
833
- if metadata_buffer:
834
- self.logger.info(
835
- f"[Rank {self.rank}] Step {self.step}: Successfully read metadata from S3, size={len(metadata_buffer)} bytes"
836
- )
837
- return pickle.loads(metadata_buffer)
838
- return Metadata({})
809
+ metadata = Metadata({})
810
+ try:
811
+ if self.step is not None:
812
+ self.logger.info(f"[Rank {self.rank}] Step {self.step}: "
813
+ "reading metadata for configured step")
814
+ metadata = self._read_metadata_for_step(self.step)
815
+ else:
816
+ latest_step_all_tiers = self._get_latest_step_all_tiers()
817
+ for latest_step, tier in latest_step_all_tiers:
818
+ if tier == StorageTier.IN_MEMORY:
819
+ self.logger.info(f"[Rank {self.rank}] Attempting to read "
820
+ f"metadata from memory for {latest_step}")
821
+ step_metadata = self._read_metadata_from_memory(latest_step)
822
+ elif tier == StorageTier.S3:
823
+ self.logger.info(f"[Rank {self.rank}] Attempting to read "
824
+ f"metadata from S3 for {latest_step}")
825
+ step_metadata = self._read_metadata_from_s3(latest_step)
826
+ if step_metadata is not None:
827
+ metadata = step_metadata
828
+ self.step = latest_step
829
+ self.logger.info(f"[Rank {self.rank}] Metadata "
830
+ f"read from step {latest_step} of {tier} tier")
831
+ break
832
+ if self.step is None:
833
+ self.logger.error(f"[Rank {self.rank}] No checkpoints to read metadata")
834
+ except Exception as e:
835
+ self.logger.error(f"[Rank {self.rank}] Step {self.step}: read_metadata failed: {e}")
836
+ return metadata
839
837
 
840
838
  def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
841
839
  """
@@ -1083,23 +1081,6 @@ class SageMakerTieredStorageReader(StorageReader):
1083
1081
  """
1084
1082
  return True
1085
1083
 
1086
- def _find_latest_complete_step_across_tiers(self) -> int | None:
1087
- """Find latest step from both storage tiers."""
1088
- memory_step = self.client.get_latest_checkpoints(limit=1)
1089
- s3_step = self._find_latest_complete_step()
1090
- latest_step = None
1091
- if not memory_step:
1092
- latest_step = s3_step
1093
- elif not s3_step:
1094
- latest_step = memory_step[0]
1095
- else:
1096
- latest_step = max(memory_step[0], s3_step)
1097
- self.logger.info(
1098
- f"[Rank {self.rank}] Step {self.step}: Latest steps: "
1099
- f"memory:{memory_step}, s3:{s3_step}, across_tiers:{latest_step}"
1100
- )
1101
- return latest_step
1102
-
1103
1084
  def _try_read_md_from_memory(self, step: int) -> bytes | None:
1104
1085
  """Try reading metadata from in-memory storage."""
1105
1086
  try:
@@ -1252,3 +1233,80 @@ class SageMakerTieredStorageReader(StorageReader):
1252
1233
  f"[Rank {self.rank}] Failed to read item {item_index} from step {step}: {e}"
1253
1234
  )
1254
1235
  return None
1236
+
1237
+ def _read_metadata_from_memory(self, step) -> Metadata | None:
1238
+ metadata = None
1239
+ try:
1240
+ metadata_buffer = self._try_read_md_from_memory(step)
1241
+ if metadata_buffer:
1242
+ self.logger.info(
1243
+ f"[Rank {self.rank}] Step {step}: Successfully read metadata from memory, "
1244
+ f"size={len(metadata_buffer)} bytes"
1245
+ )
1246
+ metadata = pickle.loads(metadata_buffer)
1247
+ else:
1248
+ self.logger.info(
1249
+ f"[Rank {self.rank}] Step {step}: "
1250
+ f"In-memory metadata not found"
1251
+ )
1252
+ except Exception as e:
1253
+ self.logger.error(f"[Rank {self.rank}] Step {step}: _read_metadata_from_memory failed: {e}")
1254
+ return metadata
1255
+
1256
+ def _read_metadata_from_s3(self, step) -> Metadata | None:
1257
+ metadata = None
1258
+ try:
1259
+ if self.s3_base_path:
1260
+ self.logger.info(
1261
+ f"[Rank {self.rank}] Step {step}: Attempting metadata read from S3"
1262
+ )
1263
+ metadata_buffer = self._try_read_md_from_s3(step)
1264
+ if metadata_buffer:
1265
+ self.logger.info(f"[Rank {self.rank}] Step {step}: "
1266
+ f"Successfully read metadata from size={len(metadata_buffer)} bytes")
1267
+ metadata = pickle.loads(metadata_buffer)
1268
+ else:
1269
+ self.logger.info(
1270
+ f"[Rank {self.rank}] Step {step}: "
1271
+ "S3 metadata not found")
1272
+ else:
1273
+ self.logger.info(
1274
+ f"[Rank {self.rank}] Step {step}: Unable to read metadata "
1275
+ "as S3 path is not provided"
1276
+ )
1277
+ except Exception as e:
1278
+ self.logger.error(f"[Rank {self.rank}] Step {step}: _read_metadata_from_s3 failed: {e}")
1279
+ return metadata
1280
+
1281
+ def _read_metadata_for_step(self, step) -> Metadata:
1282
+ metadata = Metadata({})
1283
+ try:
1284
+ in_memory_metadata = self._read_metadata_from_memory(step)
1285
+ if in_memory_metadata is not None:
1286
+ metadata = in_memory_metadata
1287
+ else:
1288
+ s3_metadata = self._read_metadata_from_s3(step)
1289
+ if s3_metadata is not None:
1290
+ metadata = s3_metadata
1291
+ except Exception as e:
1292
+ self.logger.error(f"[Rank {self.rank}] Step {step}: _read_metadata_for_step failed: {e}")
1293
+ return metadata
1294
+
1295
+ def _get_latest_step_all_tiers(self) -> list[tuple[int, StorageTier]]:
1296
+ latest_step_all_tiers = []
1297
+ try:
1298
+ memory_steps = self.client.get_latest_checkpoints(limit=3)
1299
+ if memory_steps:
1300
+ latest_step_all_tiers = [(step, StorageTier.IN_MEMORY) for step in memory_steps]
1301
+ except Exception as e:
1302
+ self.logger.error(f"[Rank {self.rank}]: Failed to get memory steps: {e}")
1303
+ try:
1304
+ s3_step = self._find_latest_complete_step()
1305
+ if s3_step:
1306
+ latest_step_all_tiers.append((s3_step, StorageTier.S3))
1307
+ except Exception as e:
1308
+ self.logger.error(f"[Rank {self.rank}]: Failed to get S3 step: {e}")
1309
+
1310
+ latest_step_all_tiers.sort(key=lambda tier_step: (-tier_step[0], tier_step[1].value))
1311
+ self.logger.info(f"[Rank {self.rank}] Latest steps across tiers: {latest_step_all_tiers}")
1312
+ return latest_step_all_tiers
@@ -471,17 +471,22 @@ class InMemoryCheckpointClient:
471
471
  checksum=encode_base_64(hash_xxh3_128(data)), algorithm="xxh3_128"
472
472
  ).to_json()
473
473
  }
474
- if isinstance(data, str) and os.path.exists(data):
475
- with open(data, "rb") as f:
476
- self._make_request(
477
- "POST",
478
- endpoint,
479
- data=f,
480
- headers=headers,
481
- timeout=timeout,
482
- retries=retries,
483
- retry_backoff=retry_backoff,
484
- )
474
+ if isinstance(data, str):
475
+ try:
476
+ with open(data, "rb") as f:
477
+ self._make_request(
478
+ "POST",
479
+ endpoint,
480
+ data=f,
481
+ headers=headers,
482
+ timeout=timeout,
483
+ retries=retries,
484
+ retry_backoff=retry_backoff,
485
+ )
486
+ except Exception as e:
487
+ error_msg = f"Error opening file: {data}"
488
+ self._logger.error(error_msg)
489
+ raise InMemoryStorageError(error_msg) from e
485
490
  else:
486
491
  self._make_request(
487
492
  "POST",
@@ -0,0 +1,148 @@
1
+ from unittest.mock import Mock
2
+ import pytest
3
+
4
+ from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
5
+ from amzn_sagemaker_checkpointing.storage.clients.inmemory.exceptions import InMemoryConfigError
6
+ from utils.test_base import (
7
+ InMemoryCheckpointClientTest,
8
+ BASE_URL,
9
+ NAMESPACE,
10
+ RANK,
11
+ REQUEST_ERROR_CASES,
12
+ REQUEST_TIMEOUT,
13
+ WORLD_SIZE
14
+ )
15
+
16
+
17
+ class TestDeleteCheckpoint(InMemoryCheckpointClientTest):
18
+ STEP = 42
19
+
20
+ def setup_method(self):
21
+ super().setup_method()
22
+ self.checkpoint_path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{self.STEP}"
23
+
24
+ def test_delete_checkpoint_success(self):
25
+ # Arrange
26
+ mock_response = Mock(status_code=200)
27
+ self.mock_session.request.return_value = mock_response
28
+
29
+ # Act
30
+ self.client.delete_checkpoint(step=self.STEP)
31
+
32
+ # Assert
33
+ self.assert_http_adapter_and_retry_config()
34
+ self.mock_session.request.assert_called_once_with(
35
+ method="DELETE",
36
+ url=f"{BASE_URL}/{self.checkpoint_path}",
37
+ params=None,
38
+ data=None,
39
+ headers=None,
40
+ timeout=InMemoryClientConfig.request_timeout
41
+ )
42
+
43
+ def test_delete_checkpoint_with_custom_rank(self):
44
+ # Arrange
45
+ mock_response = Mock(status_code=200)
46
+ self.mock_session.request.return_value = mock_response
47
+ custom_rank = 5
48
+ custom_path = f"v1/cp/checkpoints/{NAMESPACE}/{custom_rank}/{self.STEP}"
49
+
50
+ # Act
51
+ self.client.delete_checkpoint(step=self.STEP, rank=custom_rank)
52
+
53
+ # Assert
54
+ self.assert_http_adapter_and_retry_config()
55
+ self.mock_session.request.assert_called_once_with(
56
+ method="DELETE",
57
+ url=f"{BASE_URL}/{custom_path}",
58
+ params=None,
59
+ data=None,
60
+ headers=None,
61
+ timeout=InMemoryClientConfig.request_timeout
62
+ )
63
+
64
+ def test_delete_checkpoint_with_metadata_index(self):
65
+ # Arrange
66
+ mock_response = Mock(status_code=200)
67
+ self.mock_session.request.return_value = mock_response
68
+ metadata_index = 0
69
+ metadata_rank = int(WORLD_SIZE) + metadata_index
70
+ metadata_path = f"v1/cp/checkpoints/{NAMESPACE}/{metadata_rank}/{self.STEP}"
71
+
72
+ # Act
73
+ self.client.delete_checkpoint(step=self.STEP, metadata_index=metadata_index)
74
+
75
+ # Assert
76
+ self.assert_http_adapter_and_retry_config()
77
+ self.mock_session.request.assert_called_once_with(
78
+ method="DELETE",
79
+ url=f"{BASE_URL}/{metadata_path}",
80
+ params=None,
81
+ data=None,
82
+ headers=None,
83
+ timeout=InMemoryClientConfig.request_timeout
84
+ )
85
+
86
+ def test_delete_checkpoint_with_custom_timeout(self):
87
+ # Arrange
88
+ mock_response = Mock(status_code=200)
89
+ self.mock_session.request.return_value = mock_response
90
+
91
+ # Act
92
+ self.client.delete_checkpoint(step=self.STEP, timeout=REQUEST_TIMEOUT)
93
+
94
+ # Assert
95
+ self.assert_http_adapter_and_retry_config()
96
+ self.mock_session.request.assert_called_once_with(
97
+ method="DELETE",
98
+ url=f"{BASE_URL}/{self.checkpoint_path}",
99
+ params=None,
100
+ data=None,
101
+ headers=None,
102
+ timeout=REQUEST_TIMEOUT
103
+ )
104
+
105
+ def test_delete_checkpoint_with_string_step(self):
106
+ # Arrange
107
+ mock_response = Mock(status_code=200)
108
+ self.mock_session.request.return_value = mock_response
109
+ step = "latest"
110
+ path = f"v1/cp/checkpoints/{NAMESPACE}/{RANK}/{step}"
111
+
112
+ # Act
113
+ self.client.delete_checkpoint(step=step)
114
+
115
+ # Assert
116
+ self.assert_http_adapter_and_retry_config()
117
+ self.mock_session.request.assert_called_once_with(
118
+ method="DELETE",
119
+ url=f"{BASE_URL}/{path}",
120
+ params=None,
121
+ data=None,
122
+ headers=None,
123
+ timeout=InMemoryClientConfig.request_timeout
124
+ )
125
+
126
+ def test_delete_checkpoint_invalid_metadata_index(self):
127
+ # Act & Assert
128
+ with pytest.raises(InMemoryConfigError) as exc_info:
129
+ self.client.delete_checkpoint(step=self.STEP, metadata_index=999)
130
+ assert "Invalid metadata_index" in str(exc_info.value)
131
+
132
+ @pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
133
+ def test_delete_checkpoint_request_errors(self, test_case):
134
+ # Arrange
135
+ if "exception" in test_case["response"]:
136
+ self.mock_session.request.side_effect = test_case["response"]["exception"]
137
+ else:
138
+ mock_response = Mock()
139
+ mock_response.status_code = test_case["response"]["status_code"]
140
+ mock_response.text = test_case["response"]["text"]
141
+ self.mock_session.request.return_value = mock_response
142
+
143
+ # Act & Assert
144
+ self.assert_request_error(
145
+ test_case,
146
+ self.client.delete_checkpoint,
147
+ step=self.STEP
148
+ )
@@ -0,0 +1,70 @@
1
+ from unittest.mock import Mock, call
2
+ import pytest
3
+
4
+ from amzn_sagemaker_checkpointing.config.in_memory_client import InMemoryClientConfig
5
+ from utils.test_base import (
6
+ InMemoryCheckpointClientTest,
7
+ BASE_URL,
8
+ NAMESPACE,
9
+ REQUEST_ERROR_CASES,
10
+ REQUEST_TIMEOUT,
11
+ )
12
+
13
+
14
+ class TestDeleteNamespace(InMemoryCheckpointClientTest):
15
+ def setup_method(self):
16
+ super().setup_method()
17
+ self.namespace_path = f"v1/cp/namespaces/{NAMESPACE}"
18
+
19
+ @pytest.mark.parametrize(
20
+ "params",
21
+ [
22
+ {"timeout": None},
23
+ {"timeout": REQUEST_TIMEOUT},
24
+ ],
25
+ )
26
+ def test_delete_namespace_success(self, params):
27
+ # Arrange
28
+ mock_response = Mock()
29
+ mock_response.status_code = 200
30
+ self.mock_session.request.return_value = mock_response
31
+
32
+ # Act
33
+ self.client.delete_namespace(**params)
34
+
35
+ # Assert
36
+ self.assert_http_adapter_and_retry_config()
37
+ self.mock_session.request.assert_called_once_with(
38
+ method="DELETE",
39
+ url=f"{BASE_URL}/{self.namespace_path}",
40
+ params=None,
41
+ data=None,
42
+ headers=None,
43
+ timeout=params["timeout"] or InMemoryClientConfig.request_timeout
44
+ )
45
+
46
+ @pytest.mark.parametrize("test_case", REQUEST_ERROR_CASES)
47
+ def test_delete_namespace_request_errors(self, test_case):
48
+ # Arrange
49
+ if "exception" in test_case["response"]:
50
+ self.mock_session.request.side_effect = test_case["response"]["exception"]
51
+ else:
52
+ mock_response = Mock()
53
+ mock_response.status_code = test_case["response"]["status_code"]
54
+ mock_response.text = test_case["response"]["text"]
55
+ self.mock_session.request.return_value = mock_response
56
+
57
+ # Act & Assert
58
+ self.assert_http_adapter_and_retry_config()
59
+ self.assert_request_error(
60
+ test_case,
61
+ self.client.delete_namespace
62
+ )
63
+ self.mock_session.request.assert_called_once_with(
64
+ method="DELETE",
65
+ url=f"{BASE_URL}/{self.namespace_path}",
66
+ params=None,
67
+ data=None,
68
+ headers=None,
69
+ timeout=InMemoryClientConfig.request_timeout
70
+ )