nshtrainer 1.0.0b33__tar.gz → 1.0.0b37__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/pyproject.toml +1 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/__init__.py +1 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/_directory.py +3 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/_hf_hub.py +8 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/__init__.py +10 -23
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/actsave.py +6 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/base.py +3 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -4
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer-1.0.0b33/src/nshtrainer/callbacks/checkpoint/time_checkpoint.py → nshtrainer-1.0.0b37/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +31 -31
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/debug_flag.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/directory_setup.py +23 -21
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/early_stopping.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/ema.py +29 -27
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/finite_checks.py +21 -19
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/gradient_skipping.py +29 -27
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/log_epoch.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/lr_monitor.py +6 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/norm_logging.py +36 -34
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/print_table.py +20 -18
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/shared_parameters.py +9 -7
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/timer.py +12 -10
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/wandb_upload_code.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/wandb_watch.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/__init__.py +16 -12
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/_hf_hub/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/__init__.py +4 -8
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/base/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/ema/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/timer/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/loggers/__init__.py +6 -4
- nshtrainer-1.0.0b37/src/nshtrainer/configs/loggers/actsave/__init__.py +13 -0
- nshtrainer-1.0.0b37/src/nshtrainer/configs/loggers/base/__init__.py +11 -0
- nshtrainer-1.0.0b37/src/nshtrainer/configs/loggers/csv/__init__.py +13 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/loggers/wandb/__init__.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/lr_scheduler/__init__.py +4 -2
- nshtrainer-1.0.0b37/src/nshtrainer/configs/lr_scheduler/base/__init__.py +11 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +4 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/nn/__init__.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/nn/mlp/__init__.py +2 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/optimizer/__init__.py +2 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/__init__.py +4 -6
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/_config/__init__.py +2 -10
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/loggers/__init__.py +3 -8
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/loggers/actsave.py +5 -2
- nshtrainer-1.0.0b33/src/nshtrainer/loggers/_base.py → nshtrainer-1.0.0b37/src/nshtrainer/loggers/base.py +4 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/loggers/csv.py +5 -3
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/loggers/tensorboard.py +5 -3
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/loggers/wandb.py +5 -3
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/lr_scheduler/__init__.py +2 -2
- nshtrainer-1.0.0b33/src/nshtrainer/lr_scheduler/_base.py → nshtrainer-1.0.0b37/src/nshtrainer/lr_scheduler/base.py +3 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +56 -54
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +4 -2
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/nn/__init__.py +1 -1
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/nn/mlp.py +4 -4
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/nn/nonlinearity.py +37 -33
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/optimizer.py +8 -2
- nshtrainer-1.0.0b37/src/nshtrainer/trainer/__init__.py +7 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/_config.py +6 -44
- nshtrainer-1.0.0b33/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -44
- nshtrainer-1.0.0b33/src/nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
- nshtrainer-1.0.0b33/src/nshtrainer/configs/loggers/_base/__init__.py +0 -9
- nshtrainer-1.0.0b33/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -11
- nshtrainer-1.0.0b33/src/nshtrainer/configs/loggers/csv/__init__.py +0 -11
- nshtrainer-1.0.0b33/src/nshtrainer/configs/lr_scheduler/_base/__init__.py +0 -9
- nshtrainer-1.0.0b33/src/nshtrainer/trainer/__init__.py +0 -6
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/README.md +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/accelerator.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/plugin/base.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/plugin/environment.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/plugin/io.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/plugin/precision.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -15,6 +15,7 @@ from .model import LightningModuleBase as LightningModuleBase
 from .trainer import Trainer as Trainer
 from .trainer import TrainerConfig as TrainerConfig
 from .trainer import accelerator_registry as accelerator_registry
+from .trainer import callback_registry as callback_registry
 from .trainer import plugin_registry as plugin_registry

 try:

@@ -81,7 +81,9 @@ class DirectoryConfig(C.Config):

         # Save to nshtrainer/{id}/log/{logger name}
         log_dir = self.resolve_subdirectory(run_id, "log")
-        log_dir = log_dir / logger
+        log_dir = log_dir / getattr(logger, "name")
+        # ^ NOTE: Logger must have a `name` attribute, as this is
+        # the discriminator for the logger registry
         log_dir.mkdir(exist_ok=True)

         return log_dir

@@ -14,7 +14,11 @@ from nshrunner._env import SNAPSHOT_DIR
 from typing_extensions import assert_never, override

 from ._callback import NTCallbackBase
-from .callbacks.base import
+from .callbacks.base import (
+    CallbackConfigBase,
+    CallbackMetadataConfig,
+    callback_registry,
+)

 if TYPE_CHECKING:
     from huggingface_hub import HfApi  # noqa: F401

@@ -39,9 +43,12 @@ class HuggingFaceHubAutoCreateConfig(C.Config):
         return self.enabled


+@callback_registry.register
 class HuggingFaceHubConfig(CallbackConfigBase):
     """Configuration options for Hugging Face Hub integration."""

+    name: Literal["hf_hub"] = "hf_hub"
+
     metadata: ClassVar[CallbackMetadataConfig] = {"ignore_if_exists": True}

     enabled: bool = False

@@ -2,10 +2,13 @@ from __future__ import annotations

 from typing import Annotated

-
+from typing_extensions import TypeAliasType

 from . import checkpoint as checkpoint
+from .actsave import ActSaveCallback as ActSaveCallback
+from .actsave import ActSaveConfig as ActSaveConfig
 from .base import CallbackConfigBase as CallbackConfigBase
+from .base import callback_registry as callback_registry
 from .checkpoint import BestCheckpointCallback as BestCheckpointCallback
 from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
 from .checkpoint import LastCheckpointCallback as LastCheckpointCallback

@@ -14,8 +17,6 @@ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCallback
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from .checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
-from .checkpoint import TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig
 from .debug_flag import DebugFlagCallback as DebugFlagCallback
 from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
 from .directory_setup import DirectorySetupCallback as DirectorySetupCallback

@@ -37,6 +38,8 @@ from .interval import IntervalCallback as IntervalCallback
 from .interval import StepIntervalCallback as StepIntervalCallback
 from .log_epoch import LogEpochCallback as LogEpochCallback
 from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
+from .lr_monitor import LearningRateMonitor as LearningRateMonitor
+from .lr_monitor import LearningRateMonitorConfig as LearningRateMonitorConfig
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback

@@ -60,23 +63,7 @@ from .wandb_upload_code import (
 from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig

-CallbackConfig =
-
-
-
-    | PrintTableMetricsCallbackConfig
-    | FiniteChecksCallbackConfig
-    | NormLoggingCallbackConfig
-    | GradientSkippingCallbackConfig
-    | LogEpochCallbackConfig
-    | EMACallbackConfig
-    | BestCheckpointCallbackConfig
-    | LastCheckpointCallbackConfig
-    | OnExceptionCheckpointCallbackConfig
-    | TimeCheckpointCallbackConfig
-    | SharedParametersCallbackConfig
-    | RLPSanityChecksCallbackConfig
-    | WandbWatchCallbackConfig
-    | WandbUploadCodeCallbackConfig,
-    C.Field(discriminator="name"),
-]
+CallbackConfig = TypeAliasType(
+    "CallbackConfig",
+    Annotated[CallbackConfigBase, callback_registry.DynamicResolution()],
+)

@@ -4,15 +4,19 @@ import contextlib
 from pathlib import Path
 from typing import Literal

-from typing_extensions import TypeAliasType, override
+from typing_extensions import TypeAliasType, final, override

 from .._callback import NTCallbackBase
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])


+@final
+@callback_registry.register
 class ActSaveConfig(CallbackConfigBase):
+    name: Literal["act_save"] = "act_save"
+
     enabled: bool = True
     """Enable activation saving."""


@@ -55,6 +55,9 @@ class CallbackConfigBase(C.Config, ABC):
     ) -> Iterable[Callback | CallbackWithMetadata]: ...


+callback_registry = C.Registry(CallbackConfigBase, discriminator="name")
+
+
 # region Config resolution helpers
 def _create_callbacks_with_metadata(
     config: CallbackConfigBase, trainer_config: TrainerConfig

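The `callback_registry` introduced above replaces the old hard-coded `CallbackConfig` union: every callback config now registers itself and is discriminated by its `name` field, and the `CallbackConfig` alias (see the callbacks/__init__.py hunk earlier) resolves registered members dynamically. A minimal, hypothetical sketch of a downstream config plugging into this registry, assuming `Registry.register` can be used as a plain class decorator the way it is used throughout this diff:

    from typing import Literal

    from lightning.pytorch.callbacks import LearningRateMonitor

    from nshtrainer.callbacks import CallbackConfigBase, callback_registry


    # Hypothetical example config; "my_lr_monitor" is the `name` discriminator
    # the registry uses to resolve serialized configs back to this class.
    @callback_registry.register
    class MyLRMonitorConfig(CallbackConfigBase):
        name: Literal["my_lr_monitor"] = "my_lr_monitor"

        def create_callbacks(self, trainer_config):
            # Yield any number of Lightning callbacks for the trainer to use.
            yield LearningRateMonitor()

Because resolution is registry-driven, such a config should be accepted anywhere `CallbackConfig` is, without touching nshtrainer's own modules.
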
@@ -14,7 +14,3 @@ from .on_exception_checkpoint import (
 from .on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
-from .time_checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
-from .time_checkpoint import (
-    TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
-)

{nshtrainer-1.0.0b33 → nshtrainer-1.0.0b37}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
@@ -9,12 +9,14 @@ from typing_extensions import final, override

 from ..._checkpoint.metadata import CheckpointMetadata
 from ...metrics._config import MetricConfig
+from ..base import callback_registry
 from ._base import BaseCheckpointCallbackConfig, CheckpointBase

 log = logging.getLogger(__name__)


 @final
+@callback_registry.register
 class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
     name: Literal["best_checkpoint"] = "best_checkpoint"


@@ -9,36 +9,41 @@ from typing import Any, Literal
 from lightning.pytorch import LightningModule, Trainer
 from typing_extensions import final, override

-from
-
+from ..._checkpoint.metadata import CheckpointMetadata
+from ..base import callback_registry
 from ._base import BaseCheckpointCallbackConfig, CheckpointBase

 log = logging.getLogger(__name__)


 @final
-
-
+@callback_registry.register
+class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
+    name: Literal["last_checkpoint"] = "last_checkpoint"
+
+    save_on_time_interval: bool = True
+    """Whether to save checkpoints based on time interval."""

     interval: timedelta = timedelta(hours=12)
-    """Time interval between checkpoints."""
+    """Time interval between checkpoints when save_on_time_interval is True."""

     @override
     def create_checkpoint(self, trainer_config, dirpath):
-        return
+        return LastCheckpointCallback(self, dirpath)


 @final
-class
-    def __init__(self, config:
+class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
+    def __init__(self, config: LastCheckpointCallbackConfig, dirpath: Path):
         super().__init__(config, dirpath)
         self.start_time = time.time()
         self.last_checkpoint_time = self.start_time
         self.interval_seconds = config.interval.total_seconds()
+        self.save_on_time_interval = config.save_on_time_interval

     @override
     def name(self):
-        return "
+        return "last"

     @override
     def default_filename(self):

@@ -53,6 +58,8 @@ class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
         return True

     def _should_checkpoint(self) -> bool:
+        if not self.save_on_time_interval:
+            return False
         current_time = time.time()
         elapsed_time = current_time - self.last_checkpoint_time
         return elapsed_time >= self.interval_seconds

@@ -85,30 +92,23 @@ class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):

     @override
     def on_train_batch_end(
-        self,
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        *args,
+        **kwargs,
     ):
-        if self._should_checkpoint():
-
-
+        if not self._should_checkpoint():
+            return
+        self.save_checkpoints(trainer)

     @override
-    def
-
-
-        Returns:
-            Dictionary containing the start time and last checkpoint time.
-        """
-        return {
-            "start_time": self.start_time,
-            "last_checkpoint_time": self.last_checkpoint_time,
-        }
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        self.save_checkpoints(trainer)

     @override
-    def
-
-
-
-
-        """
-        self.start_time = state_dict["start_time"]
-        self.last_checkpoint_time = state_dict["last_checkpoint_time"]
+    def save_checkpoints(self, trainer):
+        super().save_checkpoints(trainer)
+
+        if self.save_on_time_interval:
+            self.last_checkpoint_time = time.time()

@@ -9,9 +9,9 @@ from typing import Any, Literal

 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
-from typing_extensions import override
+from typing_extensions import final, override

-from ..base import CallbackConfigBase
+from ..base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)

@@ -44,6 +44,8 @@ def _monkey_patch_disable_barrier(trainer: LightningTrainer):
         log.warning("Reverted monkey-patched barrier.")


+@final
+@callback_registry.register
 class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
     name: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"


@@ -3,14 +3,16 @@ from __future__ import annotations
 import logging
 from typing import Literal

-from typing_extensions import override
+from typing_extensions import final, override

 from .._callback import NTCallbackBase
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class DebugFlagCallbackConfig(CallbackConfigBase):
     name: Literal["debug_flag"] = "debug_flag"


@@ -5,14 +5,35 @@ import os
 from pathlib import Path
 from typing import Literal

-from typing_extensions import override
+from typing_extensions import final, override

 from .._callback import NTCallbackBase
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class DirectorySetupCallbackConfig(CallbackConfigBase):
+    name: Literal["directory_setup"] = "directory_setup"
+
+    enabled: bool = True
+    """Whether to enable the directory setup callback."""
+
+    create_symlink_to_nshrunner_root: bool = True
+    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
+    def __bool__(self):
+        return self.enabled
+
+    def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
+        yield DirectorySetupCallback(self)
+
+
 def _create_symlink_to_nshrunner(base_dir: Path):
     # Resolve the current nshrunner session directory
     if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):

@@ -43,25 +64,6 @@ def _create_symlink_to_nshrunner(base_dir: Path):
     symlink_path.symlink_to(session_dir)


-class DirectorySetupCallbackConfig(CallbackConfigBase):
-    name: Literal["directory_setup"] = "directory_setup"
-
-    enabled: bool = True
-    """Whether to enable the directory setup callback."""
-
-    create_symlink_to_nshrunner_root: bool = True
-    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
-
-    def __bool__(self):
-        return self.enabled
-
-    def create_callbacks(self, trainer_config):
-        if not self:
-            return
-
-        yield DirectorySetupCallback(self)
-
-
 class DirectorySetupCallback(NTCallbackBase):
     @override
     def __init__(self, config: DirectorySetupCallbackConfig):

@@ -8,14 +8,16 @@ from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
-from typing_extensions import override
+from typing_extensions import final, override

 from ..metrics._config import MetricConfig
-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class EarlyStoppingCallbackConfig(CallbackConfigBase):
     name: Literal["early_stopping"] = "early_stopping"


@@ -10,9 +10,36 @@ import lightning.pytorch as pl
 import torch
 from lightning.pytorch import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
+
+
+@final
+@callback_registry.register
+class EMACallbackConfig(CallbackConfigBase):
+    name: Literal["ema"] = "ema"
+
+    decay: float
+    """The exponential decay used when calculating the moving average. Has to be between 0-1."""
+
+    validate_original_weights: bool = False
+    """Validate the original weights, as apposed to the EMA weights."""
+
+    every_n_steps: int = 1
+    """Apply EMA every N steps."""
+
+    cpu_offload: bool = False
+    """Offload weights to CPU."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield EMACallback(
+            decay=self.decay,
+            validate_original_weights=self.validate_original_weights,
+            every_n_steps=self.every_n_steps,
+            cpu_offload=self.cpu_offload,
+        )


 class EMACallback(Callback):

@@ -358,28 +385,3 @@ class EMAOptimizer(torch.optim.Optimizer):
     def add_param_group(self, param_group):
         self.optimizer.add_param_group(param_group)
         self.rebuild_ema_params = True
-
-
-class EMACallbackConfig(CallbackConfigBase):
-    name: Literal["ema"] = "ema"
-
-    decay: float
-    """The exponential decay used when calculating the moving average. Has to be between 0-1."""
-
-    validate_original_weights: bool = False
-    """Validate the original weights, as apposed to the EMA weights."""
-
-    every_n_steps: int = 1
-    """Apply EMA every N steps."""
-
-    cpu_offload: bool = False
-    """Offload weights to CPU."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield EMACallback(
-            decay=self.decay,
-            validate_original_weights=self.validate_original_weights,
-            every_n_steps=self.every_n_steps,
-            cpu_offload=self.cpu_offload,
-        )

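EMACallbackConfig itself is unchanged apart from moving above the callback it creates and being registered; `decay` remains its only required field. A short hedged example (0.999 is a common choice, but the right decay is model- and schedule-dependent):

    from nshtrainer.callbacks import EMACallbackConfig

    # Only `decay` is required; the other fields keep the defaults shown above.
    ema = EMACallbackConfig(decay=0.999)
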
@@ -5,13 +5,32 @@ from typing import Literal

 import torch
 from lightning.pytorch import Callback, LightningModule, Trainer
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class FiniteChecksCallbackConfig(CallbackConfigBase):
+    name: Literal["finite_checks"] = "finite_checks"
+
+    nonfinite_grads: bool = True
+    """Whether to check for non-finite (i.e. NaN or Inf) gradients"""
+
+    none_grads: bool = True
+    """Whether to check for None gradients"""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield FiniteChecksCallback(
+            nonfinite_grads=self.nonfinite_grads,
+            none_grads=self.none_grads,
+        )
+
+
 def finite_checks(
     module: LightningModule,
     nonfinite_grads: bool = True,

@@ -58,20 +77,3 @@ class FiniteChecksCallback(Callback):
             nonfinite_grads=self._nonfinite_grads,
             none_grads=self._none_grads,
         )
-
-
-class FiniteChecksCallbackConfig(CallbackConfigBase):
-    name: Literal["finite_checks"] = "finite_checks"
-
-    nonfinite_grads: bool = True
-    """Whether to check for non-finite (i.e. NaN or Inf) gradients"""
-
-    none_grads: bool = True
-    """Whether to check for None gradients"""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield FiniteChecksCallback(
-            nonfinite_grads=self.nonfinite_grads,
-            none_grads=self.none_grads,
-        )

@@ -7,21 +7,47 @@ import torch
 import torchmetrics
 from lightning.pytorch import Callback, LightningModule, Trainer
 from torch.optim import Optimizer
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry
 from .norm_logging import compute_norm

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
+class GradientSkippingCallbackConfig(CallbackConfigBase):
+    name: Literal["gradient_skipping"] = "gradient_skipping"
+
+    threshold: float
+    """Threshold to use for gradient skipping."""
+
+    norm_type: str | float = 2.0
+    """Norm type to use for gradient skipping."""
+
+    start_after_n_steps: int | None = 100
+    """Number of steps to wait before starting gradient skipping."""
+
+    skip_non_finite: bool = False
+    """
+    If False, it doesn't skip steps with non-finite norms. This is useful when using AMP, as AMP checks for NaN/Inf grads to adjust the loss scale. Otherwise, skips steps with non-finite norms.
+
+    Should almost always be False, especially when using AMP (unless you know what you're doing!).
+    """
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield GradientSkippingCallback(self)
+
+
 @runtime_checkable
 class HasGradSkippedSteps(Protocol):
     grad_skipped_steps: Any


 class GradientSkippingCallback(Callback):
-    def __init__(self, config:
+    def __init__(self, config: GradientSkippingCallbackConfig):
         super().__init__()
         self.config = config

@@ -73,27 +99,3 @@ class GradientSkippingCallback(Callback):
             on_step=True,
             on_epoch=False,
         )
-
-
-class GradientSkippingCallbackConfig(CallbackConfigBase):
-    name: Literal["gradient_skipping"] = "gradient_skipping"
-
-    threshold: float
-    """Threshold to use for gradient skipping."""
-
-    norm_type: str | float = 2.0
-    """Norm type to use for gradient skipping."""
-
-    start_after_n_steps: int | None = 100
-    """Number of steps to wait before starting gradient skipping."""
-
-    skip_non_finite: bool = False
-    """
-    If False, it doesn't skip steps with non-finite norms. This is useful when using AMP, as AMP checks for NaN/Inf grads to adjust the loss scale. Otherwise, skips steps with non-finite norms.
-
-    Should almost always be False, especially when using AMP (unless you know what you're doing!).
-    """
-
-    @override
-    def create_callbacks(self, trainer_config):
-        yield GradientSkippingCallback(self)

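Likewise for gradient skipping: `threshold` is the only required field, and `skip_non_finite` stays False by default so that AMP's loss scaler still sees NaN/Inf gradients. A hedged sketch (the threshold value is purely illustrative):

    from nshtrainer.callbacks import GradientSkippingCallbackConfig

    # Skip optimizer steps whose gradient 2-norm exceeds 10.0, starting only
    # after the first 100 steps; non-finite norms are left to AMP.
    grad_skip = GradientSkippingCallbackConfig(threshold=10.0)
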
@@ -6,13 +6,15 @@ from typing import Any, Literal

 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from typing_extensions import override
+from typing_extensions import final, override

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry

 log = logging.getLogger(__name__)


+@final
+@callback_registry.register
 class LogEpochCallbackConfig(CallbackConfigBase):
     name: Literal["log_epoch"] = "log_epoch"


@@ -3,11 +3,16 @@ from __future__ import annotations
 from typing import Literal

 from lightning.pytorch.callbacks import LearningRateMonitor
+from typing_extensions import final

-from .base import CallbackConfigBase
+from .base import CallbackConfigBase, callback_registry


+@final
+@callback_registry.register
 class LearningRateMonitorConfig(CallbackConfigBase):
+    name: Literal["learning_rate_monitor"] = "learning_rate_monitor"
+
     logging_interval: Literal["step", "epoch"] | None = None
     """
     Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
|