nshtrainer 0.11.6__tar.gz → 0.11.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/PKG-INFO +1 -1
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/pyproject.toml +1 -1
- nshtrainer-0.11.7/src/nshtrainer/_checkpoint/saver.py +51 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +18 -22
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +41 -29
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +4 -13
- nshtrainer-0.11.6/src/nshtrainer/_checkpoint/saver.py +0 -61
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/README.md +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/util/typing_utils.py +0 -0
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch import Trainer
|
|
6
|
+
|
|
7
|
+
from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _link_checkpoint(
|
|
11
|
+
filepath: str | Path | os.PathLike,
|
|
12
|
+
linkpath: str | Path | os.PathLike,
|
|
13
|
+
*,
|
|
14
|
+
metadata: bool,
|
|
15
|
+
):
|
|
16
|
+
if not isinstance(filepath, Path):
|
|
17
|
+
filepath = Path(filepath)
|
|
18
|
+
if not isinstance(linkpath, Path):
|
|
19
|
+
linkpath = Path(linkpath)
|
|
20
|
+
|
|
21
|
+
if linkpath.exists():
|
|
22
|
+
if linkpath.is_symlink() or linkpath.is_file():
|
|
23
|
+
linkpath.unlink()
|
|
24
|
+
elif linkpath.is_dir():
|
|
25
|
+
shutil.rmtree(linkpath)
|
|
26
|
+
_remove_checkpoint_metadata(linkpath)
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
target_path = filepath.relative_to(linkpath.parent)
|
|
30
|
+
linkpath.symlink_to(target_path)
|
|
31
|
+
except OSError:
|
|
32
|
+
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
33
|
+
# fall back to copying the file
|
|
34
|
+
shutil.copy(filepath, linkpath)
|
|
35
|
+
|
|
36
|
+
if metadata:
|
|
37
|
+
_link_checkpoint_metadata(filepath, linkpath)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _remove_checkpoint(
|
|
41
|
+
trainer: Trainer,
|
|
42
|
+
filepath: str | Path | os.PathLike,
|
|
43
|
+
*,
|
|
44
|
+
metadata: bool,
|
|
45
|
+
):
|
|
46
|
+
if not isinstance(filepath, Path):
|
|
47
|
+
filepath = Path(filepath)
|
|
48
|
+
|
|
49
|
+
trainer.strategy.remove_checkpoint(filepath)
|
|
50
|
+
if metadata:
|
|
51
|
+
_remove_checkpoint_metadata(filepath)
|
{nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
|
@@ -90,10 +90,6 @@ class BestCheckpoint(Checkpoint):
|
|
|
90
90
|
filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
|
|
91
91
|
return self.dirpath / filename
|
|
92
92
|
|
|
93
|
-
def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
|
|
94
|
-
for ckpt_path in ckpt_paths:
|
|
95
|
-
_remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
|
|
96
|
-
|
|
97
93
|
def _get_metric_value(self, metrics: dict[str, Any]):
|
|
98
94
|
return metrics.get(
|
|
99
95
|
self.metric.validation_monitor,
|
|
@@ -101,11 +97,16 @@ class BestCheckpoint(Checkpoint):
|
|
|
101
97
|
)
|
|
102
98
|
|
|
103
99
|
def _sorted_ckpts(self):
|
|
100
|
+
"""
|
|
101
|
+
Get sorted checkpoints by the metric value.
|
|
102
|
+
|
|
103
|
+
Sort order: best -> worst
|
|
104
|
+
"""
|
|
104
105
|
ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
|
|
105
106
|
return _sort_ckpts_by_metadata(
|
|
106
107
|
ckpt_paths,
|
|
107
108
|
key=lambda meta, _: self._get_metric_value(meta.metrics),
|
|
108
|
-
reverse=(self.metric.mode == "
|
|
109
|
+
reverse=(self.metric.mode == "max"),
|
|
109
110
|
)
|
|
110
111
|
|
|
111
112
|
def _create_symlink(self, trainer: Trainer, best_ckpt_path: Path):
|
|
@@ -119,13 +120,7 @@ class BestCheckpoint(Checkpoint):
|
|
|
119
120
|
if symlink_path.exists() and symlink_path.resolve() == best_ckpt_path:
|
|
120
121
|
return
|
|
121
122
|
|
|
122
|
-
_link_checkpoint(
|
|
123
|
-
trainer,
|
|
124
|
-
best_ckpt_path,
|
|
125
|
-
symlink_path,
|
|
126
|
-
metadata=True,
|
|
127
|
-
barrier=False,
|
|
128
|
-
)
|
|
123
|
+
_link_checkpoint(best_ckpt_path, symlink_path, metadata=True)
|
|
129
124
|
log.debug(f"Created best symlink: {symlink_path}")
|
|
130
125
|
|
|
131
126
|
def _save_best_checkpoint(self, trainer: Trainer):
|
|
@@ -159,21 +154,22 @@ class BestCheckpoint(Checkpoint):
|
|
|
159
154
|
trainer.save_checkpoint(filepath, self.config.save_weights_only)
|
|
160
155
|
log.debug(f"Saved best checkpoint: {filepath}")
|
|
161
156
|
|
|
162
|
-
|
|
163
|
-
# NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
|
|
164
|
-
if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
|
|
157
|
+
if trainer.is_global_zero:
|
|
165
158
|
# Get the sorted checkpoints again because now we have added a new checkpoint.
|
|
166
159
|
# We could optimize this by adding the new checkpoint to the sorted list,
|
|
167
160
|
# and then sorting it in place, but this is simpler.
|
|
168
161
|
sorted_ckpts = self._sorted_ckpts()
|
|
169
|
-
self._remove_checkpoints(
|
|
170
|
-
trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
|
|
171
|
-
)
|
|
172
162
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
163
|
+
# Remove worst checkpoint if we've reached save_top_k
|
|
164
|
+
if (topk := self.config.save_top_k) != "all" and len(sorted_ckpts) > topk:
|
|
165
|
+
# NOTE: Sort order is best -> worst. Let's get the worst checkpoints.
|
|
166
|
+
for _, ckpt_path in sorted_ckpts[topk:]:
|
|
167
|
+
_remove_checkpoint(trainer, ckpt_path, metadata=True)
|
|
168
|
+
|
|
169
|
+
# Create symlink to best model
|
|
170
|
+
if sorted_ckpts:
|
|
171
|
+
_, best_ckpt_path = sorted_ckpts[0]
|
|
172
|
+
self._create_symlink(trainer, best_ckpt_path)
|
|
177
173
|
|
|
178
174
|
# Update the last global step saved
|
|
179
175
|
self._last_global_step_saved = trainer.global_step
|
|
@@ -51,6 +51,8 @@ class LatestEpochCheckpoint(Checkpoint):
|
|
|
51
51
|
self.config = config
|
|
52
52
|
self.dirpath = dirpath
|
|
53
53
|
|
|
54
|
+
self._last_global_step_saved = 0
|
|
55
|
+
|
|
54
56
|
@override
|
|
55
57
|
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
|
56
58
|
self._save_new_checkpoint(trainer)
|
|
@@ -67,53 +69,63 @@ class LatestEpochCheckpoint(Checkpoint):
|
|
|
67
69
|
filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
|
|
68
70
|
return self.dirpath / filename
|
|
69
71
|
|
|
70
|
-
def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
|
|
71
|
-
for ckpt_path in ckpt_paths:
|
|
72
|
-
_remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
|
|
73
|
-
|
|
74
72
|
def _remove_old_checkpoints(self, trainer: Trainer):
|
|
75
73
|
if (latest_k := self.config.latest_k) == "all":
|
|
76
74
|
return
|
|
77
75
|
|
|
78
|
-
# NOTE: We add 1 to the latest_k here because
|
|
79
|
-
# we're about to save a new checkpoint.
|
|
80
|
-
latest_k += 1
|
|
81
|
-
|
|
82
76
|
# Get all configs, ignoring the latest symlink
|
|
83
|
-
|
|
77
|
+
ckpts = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
|
|
84
78
|
# Ignore the latest symlink
|
|
85
79
|
if (latest_symlink_filename := self._latest_symlink_filename()) is not None:
|
|
86
|
-
|
|
80
|
+
ckpts = [p for p in ckpts if p.name != latest_symlink_filename]
|
|
87
81
|
|
|
88
82
|
# Sort by epoch, then step, then last modified
|
|
89
|
-
|
|
90
|
-
|
|
83
|
+
ckpts = _sort_ckpts_by_metadata(
|
|
84
|
+
ckpts,
|
|
91
85
|
key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
|
|
92
86
|
reverse=True,
|
|
93
87
|
)
|
|
94
88
|
|
|
95
89
|
# Remove all but the latest k checkpoints
|
|
96
|
-
|
|
97
|
-
|
|
90
|
+
# NOTE: We add 1 to the latest_k here because
|
|
91
|
+
# we're about to save a new checkpoint.
|
|
92
|
+
for _, ckpt_path in ckpts[latest_k:]:
|
|
93
|
+
_remove_checkpoint(trainer, ckpt_path, metadata=True)
|
|
98
94
|
|
|
99
95
|
def _save_new_checkpoint(self, trainer: Trainer):
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
self._remove_old_checkpoints(trainer)
|
|
103
|
-
trainer.strategy.barrier()
|
|
96
|
+
if self._should_skip_saving_checkpoint(trainer):
|
|
97
|
+
return
|
|
104
98
|
|
|
105
99
|
# Save the new checkpoint
|
|
106
100
|
filepath = self._ckpt_path(trainer)
|
|
107
101
|
trainer.save_checkpoint(filepath, self.config.save_weights_only)
|
|
108
102
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
symlink_path
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
103
|
+
if trainer.is_global_zero:
|
|
104
|
+
# Remove old checkpoints
|
|
105
|
+
self._remove_old_checkpoints(trainer)
|
|
106
|
+
|
|
107
|
+
# Create the latest symlink
|
|
108
|
+
if (symlink_filename := self._latest_symlink_filename()) is not None:
|
|
109
|
+
symlink_path = self.dirpath / symlink_filename
|
|
110
|
+
_link_checkpoint(filepath, symlink_path, metadata=True)
|
|
111
|
+
log.debug(f"Created latest symlink: {symlink_path}")
|
|
112
|
+
|
|
113
|
+
# Set the last global step saved
|
|
114
|
+
self._last_global_step_saved = trainer.global_step
|
|
115
|
+
|
|
116
|
+
# Barrier to ensure all processes have saved the checkpoint before continuing
|
|
117
|
+
trainer.strategy.barrier()
|
|
118
|
+
|
|
119
|
+
def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
|
|
120
|
+
from lightning.pytorch.trainer.states import TrainerFn
|
|
121
|
+
|
|
122
|
+
return (
|
|
123
|
+
bool(
|
|
124
|
+
getattr(trainer, "fast_dev_run", False)
|
|
125
|
+
) # disable checkpointing with fast_dev_run
|
|
126
|
+
or trainer.state.fn
|
|
127
|
+
!= TrainerFn.FITTING # don't save anything during non-fit
|
|
128
|
+
or trainer.sanity_checking # don't save anything during sanity check
|
|
129
|
+
or self._last_global_step_saved
|
|
130
|
+
== trainer.global_step # already saved at the last step
|
|
131
|
+
)
|
{nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py
RENAMED
|
@@ -198,19 +198,10 @@ class ModelCheckpoint(_ModelCheckpoint):
|
|
|
198
198
|
|
|
199
199
|
@override
|
|
200
200
|
def _link_checkpoint(self, trainer: Trainer, filepath: str, linkpath: str): # pyright: ignore[reportIncompatibleMethodOverride]
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
linkpath,
|
|
205
|
-
barrier=True,
|
|
206
|
-
metadata=True,
|
|
207
|
-
)
|
|
201
|
+
if trainer.is_global_zero:
|
|
202
|
+
_link_checkpoint(filepath, linkpath, metadata=True)
|
|
203
|
+
trainer.strategy.barrier()
|
|
208
204
|
|
|
209
205
|
@override
|
|
210
206
|
def _remove_checkpoint(self, trainer: Trainer, filepath: str):
|
|
211
|
-
|
|
212
|
-
trainer,
|
|
213
|
-
filepath,
|
|
214
|
-
metadata=True,
|
|
215
|
-
barrier=False,
|
|
216
|
-
)
|
|
207
|
+
_ckpt_saver_remove_checkpoint(trainer, filepath, metadata=True)
|
|
@@ -1,61 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import shutil
|
|
3
|
-
from pathlib import Path
|
|
4
|
-
|
|
5
|
-
from lightning.pytorch import Trainer
|
|
6
|
-
|
|
7
|
-
from .metadata import _link_checkpoint_metadata, _remove_checkpoint_metadata
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def _link_checkpoint(
|
|
11
|
-
trainer: Trainer,
|
|
12
|
-
filepath: str | Path | os.PathLike,
|
|
13
|
-
linkpath: str | Path | os.PathLike,
|
|
14
|
-
*,
|
|
15
|
-
barrier: bool,
|
|
16
|
-
metadata: bool,
|
|
17
|
-
):
|
|
18
|
-
if not isinstance(filepath, Path):
|
|
19
|
-
filepath = Path(filepath)
|
|
20
|
-
if not isinstance(linkpath, Path):
|
|
21
|
-
linkpath = Path(linkpath)
|
|
22
|
-
|
|
23
|
-
if trainer.is_global_zero:
|
|
24
|
-
if linkpath.exists():
|
|
25
|
-
if linkpath.is_symlink() or linkpath.is_file():
|
|
26
|
-
linkpath.unlink()
|
|
27
|
-
elif linkpath.is_dir():
|
|
28
|
-
shutil.rmtree(linkpath)
|
|
29
|
-
_remove_checkpoint_metadata(linkpath)
|
|
30
|
-
|
|
31
|
-
try:
|
|
32
|
-
target_path = filepath.relative_to(linkpath.parent)
|
|
33
|
-
linkpath.symlink_to(target_path)
|
|
34
|
-
except OSError:
|
|
35
|
-
# on Windows, special permissions are required to create symbolic links as a regular user
|
|
36
|
-
# fall back to copying the file
|
|
37
|
-
shutil.copy(filepath, linkpath)
|
|
38
|
-
|
|
39
|
-
if metadata:
|
|
40
|
-
_link_checkpoint_metadata(filepath, linkpath)
|
|
41
|
-
if barrier:
|
|
42
|
-
trainer.strategy.barrier()
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def _remove_checkpoint(
|
|
46
|
-
trainer: Trainer,
|
|
47
|
-
filepath: str | Path | os.PathLike,
|
|
48
|
-
*,
|
|
49
|
-
metadata: bool,
|
|
50
|
-
barrier: bool,
|
|
51
|
-
):
|
|
52
|
-
if not isinstance(filepath, Path):
|
|
53
|
-
filepath = Path(filepath)
|
|
54
|
-
|
|
55
|
-
if trainer.is_global_zero:
|
|
56
|
-
trainer.strategy.remove_checkpoint(filepath)
|
|
57
|
-
if metadata:
|
|
58
|
-
_remove_checkpoint_metadata(filepath)
|
|
59
|
-
|
|
60
|
-
if barrier:
|
|
61
|
-
trainer.strategy.barrier()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/_experimental/flops/module_tracker.py
RENAMED
|
File without changes
|
|
File without changes
|
{nshtrainer-0.11.6 → nshtrainer-0.11.7}/src/nshtrainer/callbacks/_throughput_monitor_callback.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|