nshtrainer 0.11.6__tar.gz → 0.11.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/pyproject.toml +1 -1
  3. nshtrainer-0.11.7/src/nshtrainer/_checkpoint/saver.py +51 -0
  4. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +18 -22
  5. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +41 -29
  6. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +4 -13
  7. nshtrainer-0.11.6/src/nshtrainer/_checkpoint/saver.py +0 -61
  8. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/README.md +0 -0
  9. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/__init__.py +0 -0
  10. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_checkpoint/loader.py +0 -0
  11. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  12. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/__init__.py +0 -0
  13. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  14. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  15. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  16. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/__init__.py +0 -0
  17. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  18. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/actsave.py +0 -0
  19. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/base.py +0 -0
  20. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  21. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  22. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  23. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/ema.py +0 -0
  24. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  25. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  26. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/interval.py +0 -0
  27. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  28. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  29. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/print_table.py +0 -0
  30. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  31. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/timer.py +0 -0
  32. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  33. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/data/__init__.py +0 -0
  34. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  35. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/data/transform.py +0 -0
  36. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/__init__.py +0 -0
  37. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/_experimental.py +0 -0
  38. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/actsave.py +0 -0
  39. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/callbacks.py +0 -0
  40. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/config.py +0 -0
  41. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/data.py +0 -0
  42. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/log.py +0 -0
  43. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  44. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/model.py +0 -0
  45. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/nn.py +0 -0
  46. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/optimizer.py +0 -0
  47. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/runner.py +0 -0
  48. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/snapshot.py +0 -0
  49. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/snoop.py +0 -0
  50. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/trainer.py +0 -0
  51. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/typecheck.py +0 -0
  52. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/util.py +0 -0
  53. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  54. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  55. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  56. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  57. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/metrics/__init__.py +0 -0
  58. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/metrics/_config.py +0 -0
  59. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/__init__.py +0 -0
  60. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/base.py +0 -0
  61. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/config.py +0 -0
  62. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/callback.py +0 -0
  63. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/debug.py +0 -0
  64. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/distributed.py +0 -0
  65. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/logger.py +0 -0
  66. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/profiler.py +0 -0
  67. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  68. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  69. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/__init__.py +0 -0
  70. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/mlp.py +0 -0
  71. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/module_dict.py +0 -0
  72. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/module_list.py +0 -0
  73. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/nonlinearity.py +0 -0
  74. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/optimizer.py +0 -0
  75. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/runner.py +0 -0
  76. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/scripts/find_packages.py +0 -0
  77. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/__init__.py +0 -0
  78. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  79. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  80. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/signal_connector.py +0 -0
  81. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/trainer.py +0 -0
  82. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/_environment_info.py +0 -0
  83. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/_useful_types.py +0 -0
  84. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/environment.py +0 -0
  85. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/seed.py +0 -0
  86. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/slurm.py +0 -0
  87. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/typed.py +0 -0
  88. {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.11.6
3
+ Version: 0.11.7
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.11.6"
3
+ version = "0.11.7"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -0,0 +1,51 @@
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ from lightning.pytorch import Trainer
6
+
7
+ from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
8
+
9
+
10
+ def _link_checkpoint(
11
+ filepath: str | Path | os.PathLike,
12
+ linkpath: str | Path | os.PathLike,
13
+ *,
14
+ metadata: bool,
15
+ ):
16
+ if not isinstance(filepath, Path):
17
+ filepath = Path(filepath)
18
+ if not isinstance(linkpath, Path):
19
+ linkpath = Path(linkpath)
20
+
21
+ if linkpath.exists():
22
+ if linkpath.is_symlink() or linkpath.is_file():
23
+ linkpath.unlink()
24
+ elif linkpath.is_dir():
25
+ shutil.rmtree(linkpath)
26
+ _remove_checkpoint_metadata(linkpath)
27
+
28
+ try:
29
+ target_path = filepath.relative_to(linkpath.parent)
30
+ linkpath.symlink_to(target_path)
31
+ except OSError:
32
+ # on Windows, special permissions are required to create symbolic links as a regular user
33
+ # fall back to copying the file
34
+ shutil.copy(filepath, linkpath)
35
+
36
+ if metadata:
37
+ _link_checkpoint_metadata(filepath, linkpath)
38
+
39
+
40
+ def _remove_checkpoint(
41
+ trainer: Trainer,
42
+ filepath: str | Path | os.PathLike,
43
+ *,
44
+ metadata: bool,
45
+ ):
46
+ if not isinstance(filepath, Path):
47
+ filepath = Path(filepath)
48
+
49
+ trainer.strategy.remove_checkpoint(filepath)
50
+ if metadata:
51
+ _remove_checkpoint_metadata(filepath)
@@ -90,10 +90,6 @@ class BestCheckpoint(Checkpoint):
90
90
  filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
91
91
  return self.dirpath / filename
92
92
 
93
- def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
94
- for ckpt_path in ckpt_paths:
95
- _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
96
-
97
93
  def _get_metric_value(self, metrics: dict[str, Any]):
98
94
  return metrics.get(
99
95
  self.metric.validation_monitor,
@@ -101,11 +97,16 @@ class BestCheckpoint(Checkpoint):
101
97
  )
102
98
 
103
99
  def _sorted_ckpts(self):
100
+ """
101
+ Get sorted checkpoints by the metric value.
102
+
103
+ Sort order: best -> worst
104
+ """
104
105
  ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
105
106
  return _sort_ckpts_by_metadata(
106
107
  ckpt_paths,
107
108
  key=lambda meta, _: self._get_metric_value(meta.metrics),
108
- reverse=(self.metric.mode == "min"),
109
+ reverse=(self.metric.mode == "max"),
109
110
  )
110
111
 
111
112
  def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
@@ -119,13 +120,7 @@ class BestCheckpoint(Checkpoint):
119
120
  if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
120
121
  return
121
122
 
122
- _link_checkpoint(
123
- trainer,
124
- best_ckpt_path,
125
- symlink_path,
126
- metadata=True,
127
- barrier=False,
128
- )
123
+ _link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
129
124
  log.debug(f"Created best symlink: {symlink_path}")
130
125
 
131
126
  def _save_best_checkpoint(self, trainer: Trainer):
@@ -159,21 +154,22 @@ class BestCheckpoint(Checkpoint):
159
154
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
160
155
  log.debug(f"Saved best checkpoint: {filepath}")
161
156
 
162
- # Remove worst checkpoint if we've reached save_top_k
163
- # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
164
- if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
157
+ if trainer.is_global_zero:
165
158
  # Get the sorted checkpoints again because now we have added a new checkpoint.
166
159
  # We could optimize this by adding the new checkpoint to the sorted list,
167
160
  # and then sorting it in place, but this is simpler.
168
161
  sorted_ckpts = self._sorted_ckpts()
169
- self._remove_checkpoints(
170
- trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
171
- )
172
162
 
173
- # Create symlink to best model
174
- if sorted_ckpts:
175
- _, best_ckpt_path = sorted_ckpts[0]
176
- self._create_symlink(trainer, best_ckpt_path)
163
+ # Remove worst checkpoint if we've reached save_top_k
164
+ if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
165
+ # NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
166
+ for _, ckpt_path in sorted_ckpts[topk:]:
167
+ _remove_checkpoint(trainer, ckpt_path, metadata=True)
168
+
169
+ # Create symlink to best model
170
+ if sorted_ckpts:
171
+ _, best_ckpt_path = sorted_ckpts[0]
172
+ self._create_symlink(trainer, best_ckpt_path)
177
173
 
178
174
  # Update the last global step saved
179
175
  self._last_global_step_saved = trainer.global_step
@@ -51,6 +51,8 @@ class LatestEpochCheckpoint(Checkpoint):
51
51
  self.config = config
52
52
  self.dirpath = dirpath
53
53
 
54
+ self._last_global_step_saved = 0
55
+
54
56
  @override
55
57
  def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
56
58
  self._save_new_checkpoint(trainer)
@@ -67,53 +69,63 @@ class LatestEpochCheckpoint(Checkpoint):
67
69
  filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
68
70
  return self.dirpath / filename
69
71
 
70
- def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
71
- for ckpt_path in ckpt_paths:
72
- _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
73
-
74
72
  def _remove_old_checkpoints(self, trainer: Trainer):
75
73
  if (latest_k := self.config.latest_k) == "all":
76
74
  return
77
75
 
78
- # NOTE: We add 1 to the latest_k here because
79
- # we're about to save a new checkpoint.
80
- latest_k += 1
81
-
82
76
  # Get all configs, ignoring the latest symlink
83
- ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
77
+ ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
84
78
  # Ignore the latest symlink
85
79
  if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
86
- ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]
80
+ ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
87
81
 
88
82
  # Sort by epoch, then step, then last modified
89
- metadata_and_ckpt_paths = _sort_ckpts_by_metadata(
90
- ckpt_paths,
83
+ ckpts = _sort_ckpts_by_metadata(
84
+ ckpts,
91
85
  key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
92
86
  reverse=True,
93
87
  )
94
88
 
95
89
  # Remove all but the latest k checkpoints
96
- ckpts_to_remove = metadata_and_ckpt_paths[latest_k:]
97
- self._remove_checkpoints(trainer, [p for _, p in ckpts_to_remove])
90
+ # NOTE: We add 1 to the latest_k here because
91
+ # we're about to save a new checkpoint.
92
+ for _, ckpt_path in ckpts[latest_k:]:
93
+ _remove_checkpoint(trainer, ckpt_path, metadata=True)
98
94
 
99
95
  def _save_new_checkpoint(self, trainer: Trainer):
100
- # Remove old checkpoints
101
- if trainer.is_global_zero:
102
- self._remove_old_checkpoints(trainer)
103
- trainer.strategy.barrier()
96
+ if self._should_skip_saving_checkpoint(trainer):
97
+ return
104
98
 
105
99
  # Save the new checkpoint
106
100
  filepath = self._ckpt_path(trainer)
107
101
  trainer.save_checkpoint(filepath, self.config.save_weights_only)
108
102
 
109
- # Create the latest symlink
110
- if (symlink_filename := self._latest_symlink_filename()) is not None:
111
- symlink_path = self.dirpath / symlink_filename
112
- _link_checkpoint(
113
- trainer,
114
- filepath,
115
- symlink_path,
116
- barrier=True,
117
- metadata=True,
118
- )
119
- log.debug(f"Created latest symlink: {symlink_path}")
103
+ if trainer.is_global_zero:
104
+ # Remove old checkpoints
105
+ self._remove_old_checkpoints(trainer)
106
+
107
+ # Create the latest symlink
108
+ if (symlink_filename := self._latest_symlink_filename()) is not None:
109
+ symlink_path = self.dirpath / symlink_filename
110
+ _link_checkpoint(filepath, symlink_path, metadata=True)
111
+ log.debug(f"Created latest symlink: {symlink_path}")
112
+
113
+ # Set the last global step saved
114
+ self._last_global_step_saved = trainer.global_step
115
+
116
+ # Barrier to ensure all processes have saved the checkpoint before continuing
117
+ trainer.strategy.barrier()
118
+
119
+ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
120
+ from lightning.pytorch.trainer.states import TrainerFn
121
+
122
+ return (
123
+ bool(
124
+ getattr(trainer, "fast_dev_run", False)
125
+ ) # disable checkpointing with fast_dev_run
126
+ or trainer.state.fn
127
+ != TrainerFn.FITTING # don't save anything during non-fit
128
+ or trainer.sanity_checking # don't save anything during sanity check
129
+ or self._last_global_step_saved
130
+ == trainer.global_step # already saved at the last step
131
+ )
@@ -198,19 +198,10 @@ class ModelCheckpoint(_ModelCheckpoint):
198
198
 
199
199
  @override
200
200
  def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str): # pyright: ignore[reportIncompatibleMethodOverride]
201
- return _link_checkpoint(
202
- trainer,
203
- filepath,
204
- linkpath,
205
- barrier=True,
206
- metadata=True,
207
- )
201
+ if trainer.is_global_zero:
202
+ _link_checkpoint(filepath, linkpath, metadata=True)
203
+ trainer.strategy.barrier()
208
204
 
209
205
  @override
210
206
  def _remove_checkpoint(self, trainer: Trainer, filepath: str):
211
- return _ckpt_saver_remove_checkpoint(
212
- trainer,
213
- filepath,
214
- metadata=True,
215
- barrier=False,
216
- )
207
+ _ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
@@ -1,61 +0,0 @@
1
- import os
2
- import shutil
3
- from pathlib import Path
4
-
5
- from lightning.pytorch import Trainer
6
-
7
- from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
8
-
9
-
10
- def _link_checkpoint(
11
- trainer: Trainer,
12
- filepath: str | Path | os.PathLike,
13
- linkpath: str | Path | os.PathLike,
14
- *,
15
- barrier: bool,
16
- metadata: bool,
17
- ):
18
- if not isinstance(filepath, Path):
19
- filepath = Path(filepath)
20
- if not isinstance(linkpath, Path):
21
- linkpath = Path(linkpath)
22
-
23
- if trainer.is_global_zero:
24
- if linkpath.exists():
25
- if linkpath.is_symlink() or linkpath.is_file():
26
- linkpath.unlink()
27
- elif linkpath.is_dir():
28
- shutil.rmtree(linkpath)
29
- _remove_checkpoint_metadata(linkpath)
30
-
31
- try:
32
- target_path = filepath.relative_to(linkpath.parent)
33
- linkpath.symlink_to(target_path)
34
- except OSError:
35
- # on Windows, special permissions are required to create symbolic links as a regular user
36
- # fall back to copying the file
37
- shutil.copy(filepath, linkpath)
38
-
39
- if metadata:
40
- _link_checkpoint_metadata(filepath, linkpath)
41
- if barrier:
42
- trainer.strategy.barrier()
43
-
44
-
45
- def _remove_checkpoint(
46
- trainer: Trainer,
47
- filepath: str | Path | os.PathLike,
48
- *,
49
- metadata: bool,
50
- barrier: bool,
51
- ):
52
- if not isinstance(filepath, Path):
53
- filepath = Path(filepath)
54
-
55
- if trainer.is_global_zero:
56
- trainer.strategy.remove_checkpoint(filepath)
57
- if metadata:
58
- _remove_checkpoint_metadata(filepath)
59
-
60
- if barrier:
61
- trainer.strategy.barrier()
File without changes