nshtrainer 1.0.0b13__tar.gz → 1.0.0b14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/pyproject.toml +1 -1
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +3 -3
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/__init__.py +2 -37
- nshtrainer-1.0.0b14/src/nshtrainer/configs/_checkpoint/__init__.py +31 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/__init__.py +0 -8
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -7
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_config.py +0 -7
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/trainer.py +1 -11
- nshtrainer-1.0.0b13/src/nshtrainer/_checkpoint/loader.py +0 -387
- nshtrainer-1.0.0b13/src/nshtrainer/configs/_checkpoint/__init__.py +0 -70
- nshtrainer-1.0.0b13/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -62
- nshtrainer-1.0.0b13/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -26
- nshtrainer-1.0.0b13/src/nshtrainer/trainer/checkpoint_connector.py +0 -86
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/README.md +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py
RENAMED
@@ -96,8 +96,8 @@ class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
     def on_exception(self, trainer: LightningTrainer, *args: Any, **kwargs: Any):
         # Monkey-patch the strategy instance to make the barrier operation a no-op.
         # We do this because `save_checkpoint` calls `barrier`. This is okay in most
-        #
-        #
-        #
+        # cases, but when we want to save a checkpoint in the case of an exception,
+        # `barrier` causes a deadlock. So we monkey-patch the strategy instance to
+        # make the barrier operation a no-op.
         with _monkey_patch_disable_barrier(trainer):
            return super().on_exception(trainer, *args, **kwargs)
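The `_monkey_patch_disable_barrier` helper referenced above is not included in this diff. As a rough sketch of the pattern the comment describes (not the package's actual implementation; the `trainer.strategy.barrier` access is an assumption based on Lightning's public API), such a context manager might look like:

```python
import contextlib


@contextlib.contextmanager
def _monkey_patch_disable_barrier(trainer):
    """Temporarily make the strategy's `barrier` a no-op (illustrative sketch only)."""
    strategy = trainer.strategy
    original_barrier = strategy.barrier
    # `save_checkpoint` calls `barrier`; while handling an exception, other ranks may
    # already be unwinding, so waiting on them would deadlock. Swap in a no-op.
    strategy.barrier = lambda *args, **kwargs: None
    try:
        yield
    finally:
        # Always restore the real barrier, even if saving the checkpoint fails.
        strategy.barrier = original_barrier
```

Restoring the original `barrier` in a `finally` block keeps the patch scoped to the exception-time save.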
{nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/__init__.py
RENAMED
@@ -9,19 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer import MetricConfig as MetricConfig
     from nshtrainer import TrainerConfig as TrainerConfig
-    from nshtrainer._checkpoint.loader import (
-        BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
-    )
-    from nshtrainer._checkpoint.loader import (
-        CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
-    )
-    from nshtrainer._checkpoint.loader import CheckpointMetadata as CheckpointMetadata
-    from nshtrainer._checkpoint.loader import (
-        LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
-    )
-    from nshtrainer._checkpoint.loader import (
-        UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
-    )
+    from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
     from nshtrainer._directory import DirectoryConfig as DirectoryConfig
     from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
     from nshtrainer._hf_hub import (
@@ -122,9 +110,6 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         CheckpointCallbackConfig as CheckpointCallbackConfig,
     )
-    from nshtrainer.trainer._config import (
-        CheckpointLoadingConfig as CheckpointLoadingConfig,
-    )
     from nshtrainer.trainer._config import (
         CheckpointSavingConfig as CheckpointSavingConfig,
     )
@@ -199,21 +184,13 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).BestCheckpointCallbackConfig
-        if name == "BestCheckpointStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).BestCheckpointStrategyConfig
         if name == "CSVLoggerConfig":
             return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
         if name == "CallbackConfigBase":
             return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
-        if name == "CheckpointLoadingConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).CheckpointLoadingConfig
         if name == "CheckpointMetadata":
             return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
+                "nshtrainer._checkpoint.metadata"
             ).CheckpointMetadata
         if name == "CheckpointSavingConfig":
             return importlib.import_module(
@@ -317,10 +294,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LastCheckpointCallbackConfig
-        if name == "LastCheckpointStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
         if name == "LearningRateMonitorConfig":
@@ -403,10 +376,6 @@ else:
             return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
         if name == "TrainerConfig":
             return importlib.import_module("nshtrainer").TrainerConfig
-        if name == "UserProvidedPathCheckpointStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).UserProvidedPathCheckpointStrategyConfig
         if name == "WandbLoggerConfig":
             return importlib.import_module("nshtrainer.loggers").WandbLoggerConfig
         if name == "WandbUploadCodeCallbackConfig":
@@ -423,10 +392,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).CheckpointCallbackConfig
-        if name == "CheckpointLoadingStrategyConfig":
-            return importlib.import_module(
-                "nshtrainer._checkpoint.loader"
-            ).CheckpointLoadingStrategyConfig
         if name == "DurationConfig":
             return importlib.import_module("nshtrainer.util.config").DurationConfig
         if name == "LRSchedulerConfig":
nshtrainer-1.0.0b14/src/nshtrainer/configs/_checkpoint/__init__.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
+    from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CheckpointMetadata":
+            return importlib.import_module(
+                "nshtrainer._checkpoint.metadata"
+            ).CheckpointMetadata
+        if name == "EnvironmentConfig":
+            return importlib.import_module(
+                "nshtrainer._checkpoint.metadata"
+            ).EnvironmentConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+# Submodule exports
+from . import metadata as metadata
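The new module above uses Python's module-level `__getattr__` hook (PEP 562): type checkers resolve the names from the `TYPE_CHECKING` block, while at runtime the underlying `nshtrainer._checkpoint.metadata` import is deferred until an attribute is first accessed. A minimal usage sketch, assuming nshtrainer 1.0.0b14 is installed:

```python
# The aliases below resolve through the module-level __getattr__ defined in
# nshtrainer/configs/_checkpoint/__init__.py, importing the metadata module lazily.
from nshtrainer.configs._checkpoint import CheckpointMetadata, EnvironmentConfig

print(CheckpointMetadata, EnvironmentConfig)
```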
{nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/__init__.py
RENAMED
@@ -18,9 +18,6 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         CheckpointCallbackConfig as CheckpointCallbackConfig,
     )
-    from nshtrainer.trainer._config import (
-        CheckpointLoadingConfig as CheckpointLoadingConfig,
-    )
     from nshtrainer.trainer._config import (
         CheckpointSavingConfig as CheckpointSavingConfig,
     )
@@ -91,10 +88,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).CallbackConfigBase
-        if name == "CheckpointLoadingConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).CheckpointLoadingConfig
         if name == "CheckpointSavingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -180,5 +173,4 @@ else:

 # Submodule exports
 from . import _config as _config
-from . import checkpoint_connector as checkpoint_connector
 from . import trainer as trainer
{nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/configs/trainer/_config/__init__.py
RENAMED
@@ -17,9 +17,6 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         CheckpointCallbackConfig as CheckpointCallbackConfig,
     )
-    from nshtrainer.trainer._config import (
-        CheckpointLoadingConfig as CheckpointLoadingConfig,
-    )
     from nshtrainer.trainer._config import (
         CheckpointSavingConfig as CheckpointSavingConfig,
     )
@@ -91,10 +88,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).CallbackConfigBase
-        if name == "CheckpointLoadingConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).CheckpointLoadingConfig
         if name == "CheckpointSavingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
{nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/_config.py
RENAMED
@@ -32,7 +32,6 @@ from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypedDict, TypeVar, override

-from .._checkpoint.loader import CheckpointLoadingConfig
 from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
@@ -493,12 +492,6 @@ class TrainerConfig(C.Config):
     ckpt_path: Literal["none"] | str | Path | None = None
     """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""

-    checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
-    """Checkpoint loading configuration options.
-    `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
-    `"none"` will disable checkpoint loading.
-    """
-
     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
     """Checkpoint saving configuration options."""

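With the `checkpoint_loading` field removed, `ckpt_path` (kept above) is the remaining checkpoint-resume option on `TrainerConfig`. A hedged sketch of setting the field, shown in isolation (constructing a complete `TrainerConfig` may require other fields not visible in this diff):

```python
from nshtrainer import TrainerConfig

config = TrainerConfig()  # other configuration omitted for brevity

# Resume from an explicit checkpoint file ...
config.ckpt_path = "checkpoints/last.ckpt"

# ... or pass the literal string "none" to skip loading a checkpoint,
# per the field's docstring above.
config.ckpt_path = "none"
```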
{nshtrainer-1.0.0b13 → nshtrainer-1.0.0b14}/src/nshtrainer/trainer/trainer.py
RENAMED
@@ -29,7 +29,6 @@ from ._config import (
     TrainerConfig,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
-from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector

 log = logging.getLogger(__name__)
@@ -314,9 +313,6 @@ class Trainer(LightningTrainer):
         # Replace the signal connector with our own.
         self._signal_connector = _SignalConnector(self)

-        # Replace the checkpoint connector with our own.
-        self._checkpoint_connector = _CheckpointConnector(self)
-
         # Print out the log dir, so that we can easily find it in the logs.
         if log_dir := self.log_dir:
             log_dir = str(Path(log_dir).resolve())
@@ -441,19 +437,13 @@ class Trainer(LightningTrainer):
     ):
         filepath = Path(filepath)

-        # List of files that we should upload to HF
-        written_files: list[Path] = [filepath]
-
         super().save_checkpoint(filepath, weights_only, storage_options)

         # Save the checkpoint metadata
         metadata_path = None
         if self.hparams.save_checkpoint_metadata and self.is_global_zero:
             # Generate the metadata and write to disk
-            if (
-                metadata_path := _write_checkpoint_metadata(self, filepath)
-            ) is not None:
-                written_files.append(metadata_path)
+            metadata_path = _write_checkpoint_metadata(self, filepath)

         # Call the `on_checkpoint_saved` method on all callbacks
         from .. import _callback
nshtrainer-1.0.0b13/src/nshtrainer/_checkpoint/loader.py
@@ -1,387 +0,0 @@
-from __future__ import annotations
-
-import logging
-from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
-from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
-
-import nshconfig as C
-from lightning.pytorch.trainer.states import TrainerFn
-from typing_extensions import assert_never
-
-from ..metrics._config import MetricConfig
-from .metadata import CheckpointMetadata
-
-if TYPE_CHECKING:
-    from ..trainer import Trainer
-    from ..trainer._config import TrainerConfig
-
-log = logging.getLogger(__name__)
-
-
-class BestCheckpointStrategyConfig(C.Config):
-    name: Literal["best"] = "best"
-
-    metric: MetricConfig | None = None
-    """The metric to use for selecting the best checkpoint. If `None`, the primary metric will be used."""
-
-    additional_candidates: Iterable[Path] = []
-    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
-
-
-class UserProvidedPathCheckpointStrategyConfig(C.Config):
-    name: Literal["user_provided_path"] = "user_provided_path"
-
-    path: Path
-    """The path to the checkpoint to load."""
-
-    on_error: Literal["warn", "raise"] = "warn"
-    """The behavior when the checkpoint does not belong to the current run.
-
-    - `warn`: Log a warning and skip the checkpoint.
-    - `raise`: Raise an error.
-    """
-
-
-class LastCheckpointStrategyConfig(C.Config):
-    name: Literal["last"] = "last"
-
-    criterion: Literal["global_step", "runtime"] = "global_step"
-    """The criterion to use for selecting the last checkpoint.
-
-    - `global_step`: The checkpoint with the highest global step will be selected.
-    - `runtime`: The checkpoint with the highest runtime will be selected.
-    """
-
-    additional_candidates: Iterable[Path] = []
-    """Additional checkpoint candidates to consider when selecting the last checkpoint."""
-
-
-CheckpointLoadingStrategyConfig: TypeAlias = Annotated[
-    BestCheckpointStrategyConfig
-    | LastCheckpointStrategyConfig
-    | UserProvidedPathCheckpointStrategyConfig,
-    C.Field(discriminator="name"),
-]
-
-
-class CheckpointLoadingConfig(C.Config):
-    strategies: Sequence[CheckpointLoadingStrategyConfig]
-    """The strategies to use for loading checkpoints.
-
-    The order of the strategies determines the priority of the strategies.
-    The first strategy that resolves a checkpoint will be used.
-    """
-
-    include_hpc: bool
-    """Whether to include checkpoints from HPC pre-emption."""
-
-    @classmethod
-    def none(cls, include_hpc: bool = False):
-        return cls(strategies=[], include_hpc=include_hpc)
-
-    @classmethod
-    def _auto_train(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
-        if ckpt is None:
-            ckpt = "last"
-        match ckpt:
-            case "best":
-                return cls(
-                    strategies=[BestCheckpointStrategyConfig()],
-                    include_hpc=True,
-                )
-            case "last":
-                return cls(
-                    strategies=[LastCheckpointStrategyConfig()],
-                    include_hpc=True,
-                )
-            case "none":
-                return cls.none()
-            case Path() | str():
-                ckpt = Path(ckpt)
-                return cls(
-                    strategies=[
-                        LastCheckpointStrategyConfig(additional_candidates=[ckpt]),
-                        UserProvidedPathCheckpointStrategyConfig(path=ckpt),
-                    ],
-                    include_hpc=True,
-                )
-            case _:
-                assert_never(ckpt)
-
-    @classmethod
-    def _auto_eval(cls, ckpt: Literal["best", "last", "none"] | str | Path | None):
-        if ckpt is None:
-            log.warn("No checkpoint specified for evaluation. Defaulting to `last`.")
-            ckpt = "last"
-
-        match ckpt:
-            case "best":
-                return cls(
-                    strategies=[BestCheckpointStrategyConfig()],
-                    include_hpc=False,
-                )
-            case "last":
-                return cls(
-                    strategies=[LastCheckpointStrategyConfig()],
-                    include_hpc=False,
-                )
-            case "none":
-                return cls.none(include_hpc=False)
-            case Path() | str():
-                ckpt = Path(ckpt)
-                return cls(
-                    strategies=[UserProvidedPathCheckpointStrategyConfig(path=ckpt)],
-                    include_hpc=False,
-                )
-            case _:
-                assert_never(ckpt)
-
-    @classmethod
-    def auto(
-        cls,
-        ckpt: Literal["best", "last", "none"] | str | Path | None,
-        trainer_mode: TrainerFn,
-    ):
-        """
-        Automatically create a CheckpointLoadingConfig based on the provided checkpoint option and trainer mode.
-
-        This method provides a convenient way to generate a checkpoint loading configuration
-        tailored to different training and evaluation scenarios.
-
-        Parameters:
-        -----------
-        ckpt : Literal["best", "last", "none"] | str | Path | None
-            Specifies the checkpoint loading preference:
-            - "best": Use the best checkpoint based on the primary metric.
-            - "last": Use the most recent checkpoint.
-            - str or Path: Path to a specific checkpoint file.
-            - None: Defaults to "last" for training, raises an error for evaluation.
-
-        trainer_mode : TrainerFn
-            The mode in which the trainer is operating. This affects how the configuration is created.
-            - TrainerFn.FITTING: Used for training scenarios.
-            - TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING: Used for evaluation scenarios.
-
-        Returns:
-        --------
-        CheckpointLoadingConfig
-            A configuration object for checkpoint loading based on the given parameters.
-
-        Behavior:
-        ---------
-        1. For training (TrainerFn.FITTING):
-           - Includes HPC pre-emption checkpoints.
-           - If ckpt is None, defaults to "last".
-           - For "best" or "last", creates a single-strategy configuration that loads the best or last checkpoint.
-           - For a specific path, creates a two-strategy configuration:
-             a) Tries to load the checkpoint as the last checkpoint.
-             b) Falls back to loading it as a user-provided path.
-
-        2. For evaluation (VALIDATING, TESTING, PREDICTING):
-           - Does not include HPC pre-emption checkpoints.
-           - Requires ckpt to be specified (raises ValueError if None).
-           - Creates a single-strategy configuration based on the ckpt value.
-
-        Raises:
-        -------
-        ValueError
-            If ckpt is None during evaluation modes.
-
-        Examples:
-        ---------
-        # Training mode, use last checkpoint
-        config = CheckpointLoadingConfig.auto("last", TrainerFn.FITTING)
-
-        # Evaluation mode, use best checkpoint
-        config = CheckpointLoadingConfig.auto("best", TrainerFn.TESTING)
-
-        # Training mode, use specific checkpoint
-        config = CheckpointLoadingConfig.auto("/path/to/checkpoint.ckpt", TrainerFn.FITTING)
-
-        Notes:
-        ------
-        - The method internally calls _auto_train or _auto_eval based on the trainer_mode.
-        - The resulting configuration always includes strategies as a sequence, even if there's only one strategy.
-        """
-        # Implementation remains the same...
-        match trainer_mode:
-            case TrainerFn.FITTING:
-                return cls._auto_train(ckpt)
-            case TrainerFn.VALIDATING | TrainerFn.TESTING | TrainerFn.PREDICTING:
-                return cls._auto_eval(ckpt)
-            case _:
-                assert_never(trainer_mode)
-
-
-@dataclass
-class _CkptCandidate:
-    meta: CheckpointMetadata
-    meta_path: Path
-
-    @property
-    def ckpt_path(self):
-        return self.meta_path.with_name(self.meta.checkpoint_filename)
-
-
-@overload
-def _load_ckpt_meta(
-    path: Path,
-    trainer_config: TrainerConfig,
-    on_error: Literal["warn"] = "warn",
-) -> _CkptCandidate | None: ...
-@overload
-def _load_ckpt_meta(
-    path: Path,
-    trainer_config: TrainerConfig,
-    on_error: Literal["raise"],
-) -> _CkptCandidate: ...
-def _load_ckpt_meta(
-    path: Path,
-    trainer_config: TrainerConfig,
-    on_error: Literal["warn", "raise"] = "warn",
-):
-    meta = CheckpointMetadata.from_file(path)
-    if trainer_config.id != meta.run_id:
-        error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
-        match on_error:
-            case "warn":
-                log.warning(error_msg)
-            case "raise":
-                raise ValueError(error_msg)
-            case _:
-                assert_never(on_error)
-        return None
-    return _CkptCandidate(meta, path)
-
-
-def _checkpoint_candidates(trainer: Trainer, *, include_hpc: bool = True):
-    # Load the checkpoint directory, and throw if it doesn't exist.
-    # This indicates a non-standard setup, and we don't want to guess
-    # where the checkpoints are.
-    ckpt_dir = trainer.hparams.directory.resolve_subdirectory(
-        trainer.hparams.id, "checkpoint"
-    )
-    if not ckpt_dir.is_dir():
-        raise FileNotFoundError(
-            f"Checkpoint directory {ckpt_dir} not found. "
-            "Please ensure that the checkpoint directory exists."
-        )
-
-    # Load all checkpoints in the directory.
-    # We can do this by looking for metadata files.
-    for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
-        if (meta := _load_ckpt_meta(path, trainer.hparams)) is not None:
-            yield meta
-
-    # If we have a pre-empted checkpoint, load it
-    if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
-        hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
-        if (meta := _load_ckpt_meta(hpc_meta_path, trainer.hparams)) is not None:
-            yield meta
-
-
-def _additional_candidates(
-    additional_candidates: Iterable[Path], trainer_config: TrainerConfig
-):
-    for path in additional_candidates:
-        if (
-            meta := _load_ckpt_meta(
-                path.with_suffix(CheckpointMetadata.PATH_SUFFIX), trainer_config
-            )
-        ) is None:
-            continue
-        yield meta
-
-
-def _resolve_checkpoint(config: CheckpointLoadingConfig, trainer: Trainer):
-    # We lazily load the checkpoint candidates to avoid loading them
-    # if they are not needed.
-    _ckpt_candidates: list[_CkptCandidate] | None = None
-
-    def ckpt_candidates():
-        nonlocal _ckpt_candidates, trainer
-
-        if _ckpt_candidates is None:
-            _ckpt_candidates = list(
-                _checkpoint_candidates(trainer, include_hpc=config.include_hpc)
-            )
-        return _ckpt_candidates
-
-    # Iterate over the strategies and try to resolve the checkpoint.
-    for strategy in config.strategies:
-        match strategy:
-            case UserProvidedPathCheckpointStrategyConfig():
-                meta = _load_ckpt_meta(
-                    strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
-                    trainer.hparams,
-                    on_error=strategy.on_error,
-                )
-                if meta is None:
-                    continue
-                return meta.ckpt_path
-            case BestCheckpointStrategyConfig():
-                candidates = [
-                    *ckpt_candidates(),
-                    *_additional_candidates(
-                        strategy.additional_candidates, trainer.hparams
-                    ),
-                ]
-                if not candidates:
-                    log.warning(
-                        "No checkpoint candidates found for `best` checkpoint strategy."
-                    )
-                    continue
-
-                if (
-                    metric := strategy.metric or trainer.hparams.primary_metric
-                ) is None:
-                    log.warning(
-                        "No metric specified for `best` checkpoint strategy, "
-                        "and no primary metric is set in the configuration. "
-                        "Skipping strategy."
-                    )
-                    continue
-
-                # Find the best checkpoint based on the metric.
-                def metric_value(ckpt: _CkptCandidate):
-                    assert metric is not None
-                    if (
-                        value := ckpt.meta.metrics.get(metric.validation_monitor)
-                    ) is None:
-                        raise ValueError(
-                            f"Metric {metric.validation_monitor} not found in checkpoint metadata. "
-                            f"Available metrics: {ckpt.meta.metrics.keys()}"
-                        )
-                    return value
-
-                best_candidate = metric.best(candidates, key=metric_value)
-                return best_candidate.ckpt_path
-            case LastCheckpointStrategyConfig():
-                candidates = [
-                    *ckpt_candidates(),
-                    *_additional_candidates(
-                        strategy.additional_candidates, trainer.hparams
-                    ),
-                ]
-                if not candidates:
-                    log.warning(
-                        "No checkpoint candidates found for `last` checkpoint strategy."
-                    )
-                    continue
-
-                # Find the last checkpoint based on the criterion.
-                def criterion_value(ckpt: _CkptCandidate):
-                    match strategy.criterion:
-                        case "global_step":
-                            return ckpt.meta.global_step
-                        case "runtime":
-                            return ckpt.meta.training_time.total_seconds()
-                        case _:
-                            assert_never(strategy.criterion)
-
-                last_candidate = max(candidates, key=criterion_value)
-                return last_candidate.ckpt_path
-            case _:
-                assert_never(strategy)