nshtrainer 1.0.0b12__tar.gz → 1.0.0b13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/pyproject.toml +1 -1
- nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py +31 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py +5 -13
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py +8 -0
- nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +31 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py +19 -15
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py +19 -15
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/datamodule.py +0 -2
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_config.py +95 -146
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/trainer.py +10 -13
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/README.md +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/loader/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/checkpoint_connector/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/util/typing_utils.py +0 -0

nshtrainer-1.0.0b13/src/nshtrainer/callbacks/lr_monitor.py
ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.callbacks import LearningRateMonitor
+
+from .base import CallbackConfigBase
+
+
+class LearningRateMonitorConfig(CallbackConfigBase):
+    logging_interval: Literal["step", "epoch"] | None = None
+    """
+    Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+    """
+
+    log_momentum: bool = False
+    """
+    Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+    """
+
+    log_weight_decay: bool = False
+    """
+    Option to also log the weight decay values of the optimizer. Defaults to False.
+    """
+
+    def create_callbacks(self, trainer_config):
+        yield LearningRateMonitor(
+            logging_interval=self.logging_interval,
+            log_momentum=self.log_momentum,
+            log_weight_decay=self.log_weight_decay,
+        )
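The new config simply forwards its fields to Lightning's `LearningRateMonitor` callback. A minimal usage sketch (not from the package itself; `create_callbacks` ignores its `trainer_config` argument in the implementation shown above, and the `lr_monitor` field it plugs into on `TrainerConfig` appears later in this diff):

```python
from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

# Log every optimizer's learning rate (and momentum) once per step.
cfg = LearningRateMonitorConfig(logging_interval="step", log_momentum=True)

# create_callbacks() is a generator yielding lightning.pytorch LearningRateMonitor instances.
callbacks = list(cfg.create_callbacks(None))
```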

{nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/__init__.py
RENAMED
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
     from nshtrainer.trainer._config import (
-
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
     )
     from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
     from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
             ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LinearWarmupCosineDecayLRSchedulerConfig":
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MLPConfig":
             return importlib.import_module("nshtrainer.nn").MLPConfig
         if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).OptimizationConfig
         if name == "OptimizerConfigBase":
             return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
         if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
            ).ReduceLROnPlateauConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"

{nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/callbacks/__init__.py
RENAMED
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
         CheckpointMetadata as CheckpointMetadata,
     )
     from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
 else:
 
     def __getattr__(name):
@@ -115,6 +118,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
 from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
+from . import lr_monitor as lr_monitor
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks

nshtrainer-1.0.0b13/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py
ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CallbackConfigBase":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).CallbackConfigBase
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports
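These generated `nshtrainer.configs.*` modules defer their real imports through a module-level `__getattr__` (PEP 562): type checkers see the `if TYPE_CHECKING` aliases, while at runtime the target module is only imported on first attribute access. A self-contained sketch of the same pattern (module and class names below are made up for illustration, not part of nshtrainer):

```python
# lazy_config.py -- illustrative shim mirroring the generated module above.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to static analysis only; no runtime import cost.
    from heavy_package.configs import ExpensiveConfig as ExpensiveConfig
else:

    def __getattr__(name):
        # PEP 562: called when `name` is not found as a module attribute.
        import importlib

        if name == "ExpensiveConfig":
            return importlib.import_module("heavy_package.configs").ExpensiveConfig
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
```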

{nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/__init__.py
RENAMED
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer.trainer import TrainerConfig as TrainerConfig
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
{nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/configs/trainer/_config/__init__.py
RENAMED
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"

{nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/data/datamodule.py
RENAMED
@@ -8,8 +8,6 @@ from typing import Any, Generic, cast
 import nshconfig as C
 import torch
 from lightning.pytorch import LightningDataModule
-from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..model.mixins.callback import CallbackRegistrarModuleMixin

{nshtrainer-1.0.0b12 → nshtrainer-1.0.0b13}/src/nshtrainer/trainer/_config.py
RENAMED
@@ -40,11 +40,13 @@ from ..callbacks import (
     CallbackConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
+    NormLoggingCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
+from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
 from ..callbacks.shared_parameters import SharedParametersCallbackConfig
 from ..loggers import (
@@ -53,6 +55,7 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers._base import BaseLoggerConfig
 from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
@@ -61,103 +64,6 @@ from ..util._environment_info import EnvironmentConfig
 log = logging.getLogger(__name__)
 
 
-class LoggingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable experiment tracking."""
-
-    loggers: Sequence[LoggerConfig] = [
-        WandbLoggerConfig(),
-        CSVLoggerConfig(),
-        TensorboardLoggerConfig(),
-    ]
-    """Loggers to use for experiment tracking."""
-
-    log_lr: bool | Literal["step", "epoch"] = True
-    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
-    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
-    """If enabled, will log the fractional epoch number to the logger."""
-
-    actsave_logger: ActSaveLoggerConfig | None = None
-    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
-
-    @property
-    def wandb(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, WandbLoggerConfig)
-            ),
-            None,
-        )
-
-    @property
-    def csv(self):
-        return next(
-            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
-            None,
-        )
-
-    @property
-    def tensorboard(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, TensorboardLoggerConfig)
-            ),
-            None,
-        )
-
-    def create_loggers(self, trainer_config: TrainerConfig):
-        """
-        Constructs and returns a list of loggers based on the provided root configuration.
-
-        Args:
-            trainer_config (TrainerConfig): The root configuration object.
-
-        Returns:
-            list[Logger]: A list of constructed loggers.
-        """
-        if not self.enabled:
-            return
-
-        for logger_config in sorted(
-            self.loggers,
-            key=lambda x: x.priority,
-            reverse=True,
-        ):
-            if not logger_config.enabled:
-                continue
-            if (logger := logger_config.create_logger(trainer_config)) is None:
-                continue
-            yield logger
-
-        # If the actsave_metrics is enabled, add the ActSave logger
-        if self.actsave_logger:
-            yield self.actsave_logger.create_logger(trainer_config)
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if self.log_lr:
-            from lightning.pytorch.callbacks import LearningRateMonitor
-
-            logging_interval: str | None = None
-            if isinstance(self.log_lr, str):
-                logging_interval = self.log_lr
-
-            yield LearningRateMonitor(logging_interval=logging_interval)
-
-        if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(trainer_config)
-
-        for logger in self.loggers:
-            if not logger or not isinstance(logger, CallbackConfigBase):
-                continue
-
-            yield from logger.create_callbacks(trainer_config)
-
-
 class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
@@ -167,32 +73,6 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""
 
 
-class OptimizationConfig(CallbackConfigBase):
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    gradient_clipping: GradientClippingConfig | None = None
-    """Gradient clipping configuration, or None to disable gradient clipping."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        from ..callbacks.norm_logging import NormLoggingCallbackConfig
-
-        yield from NormLoggingCallbackConfig(
-            log_grad_norm=self.log_grad_norm,
-            log_grad_norm_per_param=self.log_grad_norm_per_param,
-            log_param_norm=self.log_param_norm,
-            log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(trainer_config)
-
-
 TPlugin = TypeVar(
     "TPlugin",
     Precision,
@@ -252,15 +132,6 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class ReproducibilityConfig(C.Config):
-    deterministic: bool | Literal["warn"] | None = None
-    """
-    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
@@ -634,14 +505,34 @@ class TrainerConfig(C.Config):
     hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
     """Hugging Face Hub configuration options."""
 
-
-
+    loggers: Sequence[LoggerConfig] = [
+        WandbLoggerConfig(),
+        CSVLoggerConfig(),
+        TensorboardLoggerConfig(),
+    ]
+    """Loggers to use for experiment tracking."""
 
-
-    """
+    actsave_logger: ActSaveLoggerConfig | None = None
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
-
-    """
+    lr_monitor: LearningRateMonitorConfig | None = LearningRateMonitorConfig()
+    """Learning rate monitoring configuration options."""
+
+    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
+    """If enabled, will log the fractional epoch number to the logger."""
+
+    gradient_clipping: GradientClippingConfig | None = None
+    """Gradient clipping configuration, or None to disable gradient clipping."""
+
+    log_norms: NormLoggingCallbackConfig | None = None
+    """Norm logging configuration options."""
+
+    deterministic: bool | Literal["warn"] | None = None
+    """
+    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
 
     reduce_lr_on_plateau_sanity_checking: RLPSanityChecksCallbackConfig | None = (
         RLPSanityChecksCallbackConfig()
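Net effect of this hunk: the options that previously lived on the removed `LoggingConfig`, `OptimizationConfig`, and `ReproducibilityConfig` sub-configs become flat fields on `TrainerConfig`. A hedged migration sketch (the 1.0.0b12 attribute names below are inferred from the removed classes, not taken from the old source; `config` is assumed to be an existing `TrainerConfig` instance):

```python
from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig
from nshtrainer.callbacks.norm_logging import NormLoggingCallbackConfig

# 1.0.0b12-style (nested sub-configs, now removed) -- approximate:
#   config.logging.log_lr = "step"
#   config.optimization.log_grad_norm = True
#   config.reproducibility.deterministic = "warn"

# 1.0.0b13-style (flat fields added in this hunk):
config.lr_monitor = LearningRateMonitorConfig(logging_interval="step")
config.log_norms = NormLoggingCallbackConfig(log_grad_norm=True)
config.deterministic = "warn"
```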
@@ -856,27 +747,87 @@ class TrainerConfig(C.Config):
     set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
+    @property
+    def wandb_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, WandbLoggerConfig)
+            ),
+            None,
+        )
+
+    @property
+    def csv_logger(self):
+        return next(
+            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
+        )
+
+    @property
+    def tensorboard_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, TensorboardLoggerConfig)
+            ),
+            None,
+        )
+
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.early_stopping
         yield self.checkpoint_saving
-        yield self.
-        yield
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.loggers
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
         yield self.hf_hub
         yield self.shared_parameters
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
 
+    def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+        yield from self.loggers
+        yield self.actsave_logger
+
     # region Helper Methods
+    def fast_dev_run_(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        self.fast_dev_run = value
+        return self
+
     def with_fast_dev_run(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
         This will run the training loop for a specified number of batches,
         if an integer is provided, or for a single batch if True is provided.
         """
-
-
-
+        return copy.deepcopy(self).fast_dev_run_(value)
+
+    def project_root_(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        self.directory.project_root = Path(project_root)
+        return self
 
     def with_project_root(self, project_root: str | Path | os.PathLike):
         """
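The helper methods now come in pairs: a trailing-underscore variant that mutates the config in place and returns `self`, and a `with_*` variant that deep-copies first. A short usage sketch (assumes `config` is an existing `TrainerConfig`; the path below is a placeholder):

```python
# In-place: mutates `config` and returns it, so calls can be chained.
config.fast_dev_run_(5).project_root_("/tmp/example-project")

# Copying: returns a modified deep copy and leaves `config` untouched.
debug_config = config.with_fast_dev_run(True)
assert debug_config is not config
```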
@@ -888,9 +839,7 @@ class TrainerConfig(C.Config):
         Returns:
             self: The current instance of the class.
         """
-
-        config.directory.project_root = Path(project_root)
-        return config
+        return copy.deepcopy(self).project_root_(project_root)
 
     def reset_run(
         self,