nshtrainer 1.0.0b11__tar.gz → 1.0.0b13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/pyproject.toml +1 -1
- nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py +31 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py +5 -13
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py +8 -0
- nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +31 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py +19 -15
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py +19 -15
- nshtrainer-1.0.0b13/src/nshtrainer/data/datamodule.py +124 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/base.py +100 -2
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_config.py +95 -147
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/trainer.py +48 -76
- nshtrainer-1.0.0b11/src/nshtrainer/data/datamodule.py +0 -57
- nshtrainer-1.0.0b11/src/nshtrainer/scripts/find_packages.py +0 -52
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/README.md +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typing_utils.py +0 -0
nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py
ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.callbacks import LearningRateMonitor
+
+from .base import CallbackConfigBase
+
+
+class LearningRateMonitorConfig(CallbackConfigBase):
+    logging_interval: Literal["step", "epoch"] | None = None
+    """
+    Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+    """
+
+    log_momentum: bool = False
+    """
+    Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+    """
+
+    log_weight_decay: bool = False
+    """
+    Option to also log the weight decay values of the optimizer. Defaults to False.
+    """
+
+    def create_callbacks(self, trainer_config):
+        yield LearningRateMonitor(
+            logging_interval=self.logging_interval,
+            log_momentum=self.log_momentum,
+            log_weight_decay=self.log_weight_decay,
+        )
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py
RENAMED
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
     from nshtrainer.trainer._config import (
-
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
     )
     from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
     from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
             ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LinearWarmupCosineDecayLRSchedulerConfig":
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MLPConfig":
             return importlib.import_module("nshtrainer.nn").MLPConfig
         if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).OptimizationConfig
         if name == "OptimizerConfigBase":
             return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
         if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
             ).ReduceLROnPlateauConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
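These generated config modules all rely on the same module-level __getattr__ hook (PEP 562) to defer the real imports until a name is first accessed, while the if TYPE_CHECKING: branch keeps static type checkers happy. A stripped-down sketch of the pattern, with illustrative module and class names rather than nshtrainer's own:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type checkers resolve the name eagerly...
    from some.heavy.module import SomeConfig as SomeConfig
else:
    # ...while at runtime the import only happens on first attribute access.
    def __getattr__(name):
        import importlib

        if name == "SomeConfig":
            return importlib.import_module("some.heavy.module").SomeConfig
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")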
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py
RENAMED
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
         CheckpointMetadata as CheckpointMetadata,
     )
     from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
 else:
 
     def __getattr__(name):
@@ -115,6 +118,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
 from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
+from . import lr_monitor as lr_monitor
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks
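In practice this means the new config is also reachable through the aggregated callbacks config namespace; a one-line sketch mirroring the TYPE_CHECKING import and __getattr__ entry added above:

from nshtrainer.configs.callbacks import LearningRateMonitorConfig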
nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py
ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CallbackConfigBase":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).CallbackConfigBase
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py
RENAMED
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer.trainer import TrainerConfig as TrainerConfig
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py
RENAMED
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
nshtrainer-1.0.0b13/src/nshtrainer/data/datamodule.py
ADDED
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from pathlib import Path
+from typing import Any, Generic, cast
+
+import nshconfig as C
+import torch
+from lightning.pytorch import LightningDataModule
+from typing_extensions import Never, TypeVar, deprecated, override
+
+from ..model.mixins.callback import CallbackRegistrarModuleMixin
+from ..model.mixins.debug import _DebugModuleMixin
+
+THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
+
+
+class LightningDataModuleBase(
+    _DebugModuleMixin,
+    CallbackRegistrarModuleMixin,
+    LightningDataModule,
+    ABC,
+    Generic[THparams],
+):
+    @property
+    @override
+    def hparams(self) -> THparams:  # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(THparams, super().hparams)
+
+    @property
+    @override
+    def hparams_initial(self):  # pyright: ignore[reportIncompatibleMethodOverride]
+        hparams = cast(THparams, super().hparams_initial)
+        return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
+
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
+
+    @classmethod
+    @abstractmethod
+    def hparams_cls(cls) -> type[THparams]: ...
+
+    @override
+    def __init__(self, hparams: THparams | Mapping[str, Any]):
+        super().__init__()
+
+        # Validate and save hyperparameters
+        hparams_cls = self.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise TypeError(
+                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
+            )
+        hparams = hparams.model_deep_validate()
+        self.save_hyperparameters(hparams)
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load datamodule from checkpoint
+        datamodule = cls(hparams)
+        if datamodule.__class__.__qualname__ in ckpt:
+            datamodule.load_state_dict(ckpt[datamodule.__class__.__qualname__])
+
+        return datamodule
{nshtrainer-1.0.0b11 → nshtrainer-1.0.0b13}/src/nshtrainer/model/base.py
RENAMED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
+from pathlib import Path
 from typing import Any, Generic, Literal, cast
 
 import nshconfig as C
@@ -10,11 +11,13 @@ import torch
 import torch.distributed
 from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import _DebugModuleMixin
+from .mixins.debug import _DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
@@ -241,3 +244,98 @@ class LightningModuleBase(
         loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
         loss = cast(torch.Tensor, loss)
         return loss
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load model from checkpoint
+        model = cls(hparams)
+
+        # Load model state from checkpoint
+        if (
+            model._strict_loading is not None
+            and strict is not None
+            and strict != model.strict_loading
+        ):
+            raise ValueError(
+                f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
+                f" `{cls.__name__}.strict_loading={model.strict_loading!r}. Please set the same value for both of them."
+            )
+        strict = model.strict_loading if strict is None else strict
+
+        if is_overridden("configure_model", model):
+            model.configure_model()
+
+        # give model a chance to load something
+        model.on_load_checkpoint(ckpt)
+
+        # load the state_dict on the model automatically
+
+        keys = model.load_state_dict(ckpt["state_dict"], strict=strict)
+
+        if not strict:
+            if keys.missing_keys:
+                rank_zero_warn(
+                    f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
+                )
+            if keys.unexpected_keys:
+                rank_zero_warn(
+                    f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
+                )
+
+        return model
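Taken together with the datamodule additions, the intended replacement for Lightning's load_from_checkpoint on LightningModuleBase subclasses looks roughly like the sketch below; MyLitModel and the checkpoint path are placeholders, while the keyword hooks are the ones introduced in this diff:

# MyLitModel is assumed to be a LightningModuleBase subclass that defines hparams_cls().
model = MyLitModel.from_checkpoint(
    "checkpoints/best.ckpt",
    strict=False,                     # mismatched state_dict keys are warned about, not fatal
    map_location="cpu",
    update_hparams_dict=lambda d: d,  # patch the raw hparams dict before validation
    update_hparams=lambda hp: hp,     # patch the validated hparams object before the model is built
)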