nshtrainer 0.11.7.tar.gz → 0.11.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/PKG-INFO +1 -1
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/pyproject.toml +1 -1
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/loader.py +4 -4
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/metadata.py +37 -35
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/__init__.py +3 -8
- nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/__init__.py +12 -0
- nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/_base.py +175 -0
- nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +70 -0
- nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +39 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/__init__.py +2 -4
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/config.py +4 -37
- nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -16
- nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -192
- nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/latest_epoch_checkpoint.py +0 -131
- nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/model_checkpoint.py +0 -207
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/README.md +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/loader.py

```diff
@@ -236,7 +236,7 @@ def _load_ckpt_meta(
             error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
             match on_error:
                 case "warn":
-                    log.
+                    log.warning(error_msg)
                 case "raise":
                     raise ValueError(error_msg)
                 case _:
@@ -325,13 +325,13 @@ def _resolve_checkpoint(
                 ),
             ]
             if not candidates:
-                log.
+                log.warning(
                     "No checkpoint candidates found for `best` checkpoint strategy."
                 )
                 continue

             if (metric := strategy.metric or root_config.primary_metric) is None:
-                log.
+                log.warning(
                     "No metric specified for `best` checkpoint strategy, "
                     "and no primary metric is set in the configuration. "
                     "Skipping strategy."
@@ -360,7 +360,7 @@ def _resolve_checkpoint(
                 ),
             ]
             if not candidates:
-                log.
+                log.warning(
                     "No checkpoint candidates found for `last` checkpoint strategy."
                 )
                 continue
```
{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/_checkpoint/metadata.py

```diff
@@ -4,7 +4,7 @@ import logging
 import shutil
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast

 import nshconfig as C
 import numpy as np
@@ -20,10 +20,11 @@ log = logging.getLogger(__name__)


 METADATA_PATH_SUFFIX = ".metadata.json"
-HPARAMS_PATH_SUFFIX = ".hparams.json"


 class CheckpointMetadata(C.Config):
+    PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
+
     checkpoint_path: Path
     checkpoint_filename: str

@@ -39,6 +40,8 @@ class CheckpointMetadata(C.Config):
     metrics: dict[str, Any]
     environment: EnvironmentConfig

+    hparams: dict[str, Any] | None
+
     @classmethod
     def from_file(cls, path: Path):
         return cls.model_validate_json(path.read_text())
@@ -55,7 +58,10 @@


 def _generate_checkpoint_metadata(
-    config: "BaseConfig",
+    config: "BaseConfig",
+    trainer: "Trainer",
+    checkpoint_path: Path,
+    metadata_path: Path,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -70,7 +76,11 @@ def _generate_checkpoint_metadata(
             metrics[name] = metric

     return CheckpointMetadata(
-        checkpoint_path=checkpoint_path,
+        # checkpoint_path=checkpoint_path,
+        # We should store the path as a relative path
+        # to the metadata file to avoid issues with
+        # moving the checkpoint directory
+        checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
         run_id=config.id,
         name=config.run_name,
@@ -84,6 +94,7 @@ def _generate_checkpoint_metadata(
         training_time=training_time,
         metrics=metrics,
         environment=config.environment,
+        hparams=config.model_dump(mode="json"),
     )


@@ -93,36 +104,28 @@ def _write_checkpoint_metadata(
     checkpoint_path: Path,
 ):
     config = cast("BaseConfig", model.config)
-
+    metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    metadata = _generate_checkpoint_metadata(
+        config, trainer, checkpoint_path, metadata_path
+    )

     # Write the metadata to the checkpoint directory
     try:
-        metadata_path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
         metadata_path.write_text(metadata.model_dump_json(indent=4))
     except Exception as e:
         log.warning(f"Failed to write metadata to {checkpoint_path}: {e}")
     else:
         log.debug(f"Checkpoint metadata written to {checkpoint_path}")

-
+
+def _remove_checkpoint_metadata(checkpoint_path: Path):
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
     try:
-
-        hparams_path.write_text(config.model_dump_json(indent=4))
+        path.unlink(missing_ok=True)
     except Exception as e:
-        log.warning(f"Failed to
+        log.warning(f"Failed to remove {path}: {e}")
     else:
-        log.debug(f"
-
-
-def _remove_checkpoint_metadata(checkpoint_path: Path):
-    for suffix in (METADATA_PATH_SUFFIX, HPARAMS_PATH_SUFFIX):
-        path = checkpoint_path.with_suffix(suffix)
-        try:
-            path.unlink(missing_ok=True)
-        except Exception as e:
-            log.warning(f"Failed to remove {path}: {e}")
-        else:
-            log.debug(f"Removed {path}")
+        log.debug(f"Removed {path}")


 def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Path):
@@ -130,20 +133,19 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
     _remove_checkpoint_metadata(linked_checkpoint_path)

     # Link the metadata files to the new checkpoint
-
-
-
+    path = checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    linked_path = linked_checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+    try:
         try:
-
-
-
-
-
-
-
-
-
-            log.debug(f"Linked {path} to {linked_path}")
+            linked_path.symlink_to(path)
+        except OSError:
+            # on Windows, special permissions are required to create symbolic links as a regular user
+            # fall back to copying the file
+            shutil.copy(path, linked_path)
+    except Exception as e:
+        log.warning(f"Failed to link {path} to {linked_path}: {e}")
+    else:
+        log.debug(f"Linked {path} to {linked_path}")


 def _sort_ckpts_by_metadata(
```
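In this version, `CheckpointMetadata.checkpoint_path` is written relative to the metadata file (see the comment added in the hunk above). A minimal sketch of how a reader of the metadata could resolve it back to an absolute path; the `resolve_checkpoint_file` helper is hypothetical and not part of nshtrainer:

```python
# Hypothetical helper: assumes only CheckpointMetadata.from_file and the
# relative checkpoint_path field shown in the diff above.
from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata


def resolve_checkpoint_file(metadata_path: Path) -> Path:
    meta = CheckpointMetadata.from_file(metadata_path)
    # checkpoint_path is stored relative to the metadata file's directory,
    # so resolving against that directory keeps working even after the
    # checkpoint directory is moved as a whole.
    return (metadata_path.parent / meta.checkpoint_path).resolve()
```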
{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/callbacks/__init__.py

```diff
@@ -6,12 +6,8 @@ from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
 from .checkpoint import BestCheckpoint as BestCheckpoint
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
-from .checkpoint import
-from .checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .checkpoint import ModelCheckpoint as ModelCheckpoint
-from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+from .checkpoint import LastCheckpoint as LastCheckpoint
+from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
@@ -46,8 +42,7 @@ CallbackConfig = Annotated[
     | GradientSkippingConfig
     | EMAConfig
     | BestCheckpointCallbackConfig
-
-    | LatestEpochCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
     | WandbWatchConfig,
     C.Field(discriminator="name"),
```
nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/__init__.py

```diff
@@ -0,0 +1,12 @@
+from .best_checkpoint import BestCheckpoint as BestCheckpoint
+from .best_checkpoint import (
+    BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
+)
+from .last_checkpoint import LastCheckpoint as LastCheckpoint
+from .last_checkpoint import (
+    LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
+)
+from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .on_exception_checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
```
nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/_base.py

```diff
@@ -0,0 +1,175 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import TypeVar, override
+
+from ..._checkpoint.metadata import CheckpointMetadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        root_config: "BaseConfig",
+        dirpath: Path,
+    ) -> "CheckpointBase": ...
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = Path(
+            self.dirpath
+            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+        )
+
+        yield self.create_checkpoint(root_config, dirpath)
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.symlink_dirpath = dirpath
+
+        self._last_global_step_saved = 0
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+        filename = filename.format(**current_metrics)
+        return self.dirpath / f"{filename}{self.extension()}"
+
+    def remove_old_checkpoints(self, trainer: Trainer):
+        if (topk := self.config.topk) == "all":
+            return
+
+        # Get all the checkpoint metadata
+        metas = [
+            CheckpointMetadata.from_file(p)
+            for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        ]
+
+        # Sort by the topk sort key
+        metas = sorted(metas, key=self.topk_sort_key)
+
+        # Now, the metas are sorted from the best to the worst,
+        # so we can remove the worst checkpoints
+        for meta in metas[topk:]:
+            if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
+                log.warning(
+                    f"Checkpoint file not found: {old_ckpt_path}\n"
+                    "Skipping removal of the checkpoint metadata."
+                )
+                continue
+
+            _remove_checkpoint(trainer, old_ckpt_path, metadata=True)
+            log.debug(f"Removed old checkpoint: {old_ckpt_path}")
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        # Save the new checkpoint
+        filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        if trainer.is_global_zero:
+            # Remove old checkpoints
+            self.remove_old_checkpoints(trainer)
+
+            # Create the latest symlink
+            if (symlink_filename := self.symlink_path()) is not None:
+                symlink_path = self.dirpath / symlink_filename
+                _link_checkpoint(filepath, symlink_path, metadata=True)
+                log.debug(f"Created latest symlink: {symlink_path}")
+
+        # Barrier to ensure all processes have saved the checkpoint,
+        # deleted the old checkpoints, and created the symlink before continuing
+        trainer.strategy.barrier()
+
+        # Set the last global step saved
+        self._last_global_step_saved = trainer.global_step
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+            or self._last_global_step_saved
+            == trainer.global_step  # already saved at the last step
+        )
```
nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py

```diff
@@ -0,0 +1,70 @@
+import logging
+from pathlib import Path
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+from ...metrics._config import MetricConfig
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+log = logging.getLogger(__name__)
+
+
+@final
+class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["best_checkpoint"] = "best_checkpoint"
+
+    metric: MetricConfig | None = None
+    """Metric to monitor, or `None` to use the default metric."""
+
+    @override
+    def create_checkpoint(self, root_config, dirpath):
+        # Resolve metric
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "No metric provided and no primary metric found in the root config"
+            )
+
+        return BestCheckpoint(self, dirpath, metric)
+
+
+@final
+class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):
+    @property
+    def _metric_name_normalized(self):
+        return self.metric.name.replace("/", "_").replace(" ", "_").replace(".", "_")
+
+    @override
+    def __init__(
+        self,
+        config: BestCheckpointCallbackConfig,
+        dirpath: Path,
+        metric: MetricConfig,
+    ):
+        super().__init__(config, dirpath)
+        self.metric = metric
+
+    @override
+    def name(self):
+        return f"best_{self._metric_name_normalized}"
+
+    @override
+    def default_filename(self):
+        return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
+
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.metrics.get(
+            self.metric.validation_monitor,
+            float("-inf" if self.metric.mode == "max" else "inf"),
+        )
+
+    # Events
+    @override
+    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
```
nshtrainer-0.11.9/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py

```diff
@@ -0,0 +1,39 @@
+import logging
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from typing_extensions import final, override
+
+from nshtrainer._checkpoint.metadata import CheckpointMetadata
+
+from ._base import BaseCheckpointCallbackConfig, CheckpointBase
+
+log = logging.getLogger(__name__)
+
+
+@final
+class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["last_checkpoint"] = "last_checkpoint"
+
+    @override
+    def create_checkpoint(self, root_config, dirpath):
+        return LastCheckpoint(self, dirpath)
+
+
+@final
+class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
+    @override
+    def name(self):
+        return "last"
+
+    @override
+    def default_filename(self):
+        return "epoch{epoch:03d}-step{step:07d}"
+
+    @override
+    def topk_sort_key(self, metadata: CheckpointMetadata):
+        return metadata.checkpoint_timestamp
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)
```
{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/__init__.py

```diff
@@ -5,17 +5,15 @@ from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
 from .config import BaseLoggerConfig as BaseLoggerConfig
 from .config import BaseProfilerConfig as BaseProfilerConfig
+from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
 from .config import EarlyStoppingConfig as EarlyStoppingConfig
 from .config import GradientClippingConfig as GradientClippingConfig
-from .config import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
+from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
 from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
 from .config import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
```
{nshtrainer-0.11.7 → nshtrainer-0.11.9}/src/nshtrainer/model/config.py

```diff
@@ -39,8 +39,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
-
-    ModelCheckpointCallbackConfig,
+    LastCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
     WandbWatchConfig,
 )
@@ -771,9 +770,8 @@ class ReproducibilityConfig(C.Config):


 CheckpointCallbackConfig: TypeAlias = Annotated[
-
-
-    | LatestEpochCheckpointCallbackConfig
+    BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
 ]
@@ -784,9 +782,8 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""

     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        # ModelCheckpointCallbackConfig(),
         BestCheckpointCallbackConfig(),
-
+        LastCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
     """Checkpoint callback configurations."""
@@ -804,36 +801,6 @@ class CheckpointSavingConfig(CallbackConfigBase):

         return True

-    @property
-    def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, ModelCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, LatestEpochCheckpointCallbackConfig)
-            ),
-        )
-
-    @property
-    def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
-        return next(
-            (
-                callback
-                for callback in self.checkpoint_callbacks
-                if isinstance(callback, OnExceptionCheckpointCallbackConfig)
-            ),
-        )
-
     @override
     def create_callbacks(self, root_config: "BaseConfig"):
         if not self.should_save_checkpoints(root_config):
```
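With `ModelCheckpointCallbackConfig` and `LatestEpochCheckpointCallbackConfig` removed, an explicit `CheckpointSavingConfig` in 0.11.9 would be built from the renamed classes shown above. A hedged sketch; the specific field values (such as `topk=3`) are illustrative assumptions rather than values taken from the diff:

```python
# Sketch only: the class names and the topk/checkpoint_callbacks fields come
# from the diff above; everything else assumes the remaining defaults apply.
from nshtrainer.model.config import (
    BestCheckpointCallbackConfig,
    CheckpointSavingConfig,
    LastCheckpointCallbackConfig,
    OnExceptionCheckpointCallbackConfig,
)

saving = CheckpointSavingConfig(
    checkpoint_callbacks=[
        BestCheckpointCallbackConfig(),
        # Replaces LatestEpochCheckpointCallbackConfig; topk controls how many
        # checkpoints this callback keeps (the default is 1).
        LastCheckpointCallbackConfig(topk=3),
        OnExceptionCheckpointCallbackConfig(),
    ],
)
```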
nshtrainer-0.11.7/src/nshtrainer/callbacks/checkpoint/__init__.py

```diff
@@ -1,16 +0,0 @@
-from .best_checkpoint import BestCheckpoint as BestCheckpoint
-from .best_checkpoint import (
-    BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
-)
-from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .latest_epoch_checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
-from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
-from .model_checkpoint import (
-    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
-)
-from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
-from .on_exception_checkpoint import (
-    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
-)
```