nshtrainer 0.41.1__tar.gz → 0.43.0__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/PKG-INFO +1 -1
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/pyproject.toml +9 -5
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_callback.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_checkpoint/loader.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_checkpoint/metadata.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_checkpoint/saver.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_directory.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_experimental/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_hf_hub.py +2 -0
- nshtrainer-0.43.0/src/nshtrainer/callbacks/__init__.py +81 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/actsave.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/base.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +6 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/_base.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/debug_flag.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/directory_setup.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/early_stopping.py +6 -4
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/ema.py +5 -3
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/finite_checks.py +3 -1
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/gradient_skipping.py +6 -4
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/interval.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/log_epoch.py +13 -1
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/norm_logging.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/print_table.py +3 -1
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/shared_parameters.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/throughput_monitor.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/timer.py +5 -3
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/wandb_upload_code.py +4 -2
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer-0.43.0/src/nshtrainer/config/__init__.py +465 -0
- nshtrainer-0.43.0/src/nshtrainer/config/_checkpoint/loader/__init__.py +62 -0
- nshtrainer-0.43.0/src/nshtrainer/config/_checkpoint/metadata/__init__.py +29 -0
- nshtrainer-0.43.0/src/nshtrainer/config/_directory/__init__.py +32 -0
- nshtrainer-0.43.0/src/nshtrainer/config/_hf_hub/__init__.py +32 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/__init__.py +176 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/actsave/__init__.py +27 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/base/__init__.py +24 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/__init__.py +73 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +40 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +47 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +40 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/debug_flag/__init__.py +31 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/directory_setup/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/early_stopping/__init__.py +38 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/ema/__init__.py +27 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/finite_checks/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/norm_logging/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/print_table/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/timer/__init__.py +31 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +33 -0
- nshtrainer-0.43.0/src/nshtrainer/config/loggers/__init__.py +58 -0
- nshtrainer-0.43.0/src/nshtrainer/config/loggers/_base/__init__.py +22 -0
- nshtrainer-0.43.0/src/nshtrainer/config/loggers/csv/__init__.py +25 -0
- nshtrainer-0.43.0/src/nshtrainer/config/loggers/tensorboard/__init__.py +31 -0
- nshtrainer-0.43.0/src/nshtrainer/config/loggers/wandb/__init__.py +44 -0
- nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/__init__.py +59 -0
- nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/_base/__init__.py +26 -0
- nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +40 -0
- nshtrainer-0.43.0/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +40 -0
- nshtrainer-0.43.0/src/nshtrainer/config/metrics/__init__.py +24 -0
- nshtrainer-0.43.0/src/nshtrainer/config/metrics/_config/__init__.py +22 -0
- nshtrainer-0.43.0/src/nshtrainer/config/model/__init__.py +41 -0
- nshtrainer-0.43.0/src/nshtrainer/config/model/base/__init__.py +25 -0
- nshtrainer-0.43.0/src/nshtrainer/config/model/config/__init__.py +37 -0
- nshtrainer-0.43.0/src/nshtrainer/config/model/mixins/logger/__init__.py +22 -0
- nshtrainer-0.43.0/src/nshtrainer/config/nn/__init__.py +77 -0
- nshtrainer-0.43.0/src/nshtrainer/config/nn/mlp/__init__.py +28 -0
- nshtrainer-0.43.0/src/nshtrainer/config/nn/nonlinearity/__init__.py +125 -0
- nshtrainer-0.43.0/src/nshtrainer/config/optimizer/__init__.py +28 -0
- nshtrainer-0.43.0/src/nshtrainer/config/profiler/__init__.py +39 -0
- nshtrainer-0.43.0/src/nshtrainer/config/profiler/_base/__init__.py +24 -0
- nshtrainer-0.43.0/src/nshtrainer/config/profiler/advanced/__init__.py +31 -0
- nshtrainer-0.43.0/src/nshtrainer/config/profiler/pytorch/__init__.py +31 -0
- nshtrainer-0.43.0/src/nshtrainer/config/profiler/simple/__init__.py +29 -0
- nshtrainer-0.43.0/src/nshtrainer/config/runner/__init__.py +22 -0
- nshtrainer-0.43.0/src/nshtrainer/config/trainer/_config/__init__.py +153 -0
- nshtrainer-0.43.0/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +26 -0
- nshtrainer-0.43.0/src/nshtrainer/config/util/_environment_info/__init__.py +94 -0
- nshtrainer-0.43.0/src/nshtrainer/config/util/config/__init__.py +34 -0
- nshtrainer-0.43.0/src/nshtrainer/config/util/config/dtype/__init__.py +22 -0
- nshtrainer-0.43.0/src/nshtrainer/config/util/config/duration/__init__.py +34 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/balanced_batch_sampler.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/datamodule.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/data/transform.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/_experimental.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/actsave.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/callbacks.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/config.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/data.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/log.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/lr_scheduler.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/model.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/nn.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/optimizer.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/runner.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/snapshot.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/snoop.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/trainer.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/typecheck.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/ll/util.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/_base.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/csv.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/tensorboard.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/loggers/wandb.py +6 -4
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/_base.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/metrics/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/metrics/_config.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/base.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/config.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/mixins/callback.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/model/mixins/logger.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/mlp.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/module_dict.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/module_list.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/nn/nonlinearity.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/optimizer.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/_base.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/advanced.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/pytorch.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/profiler/simple.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/runner.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/scripts/find_packages.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/_config.py +16 -13
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/_runtime_callback.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/checkpoint_connector.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/signal_connector.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/trainer/trainer.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/_environment_info.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/bf16.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/config/__init__.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/config/dtype.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/config/duration.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/environment.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/path.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/seed.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/slurm.py +3 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/typed.py +2 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/typing_utils.py +2 -0
- nshtrainer-0.41.1/src/nshtrainer/callbacks/__init__.py +0 -65
- nshtrainer-0.41.1/src/nshtrainer/config/__init__.py +0 -114
- nshtrainer-0.41.1/src/nshtrainer/config/_checkpoint/loader/__init__.py +0 -18
- nshtrainer-0.41.1/src/nshtrainer/config/_checkpoint/metadata/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/_directory/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/_hf_hub/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/__init__.py +0 -51
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/actsave/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/base/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/__init__.py +0 -22
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +0 -15
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/debug_flag/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/directory_setup/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/early_stopping/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/ema/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/finite_checks/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/norm_logging/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/print_table/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/timer/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/loggers/__init__.py +0 -23
- nshtrainer-0.41.1/src/nshtrainer/config/loggers/_base/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/loggers/csv/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/loggers/tensorboard/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/loggers/wandb/__init__.py +0 -16
- nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/__init__.py +0 -20
- nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/_base/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/metrics/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/metrics/_config/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/model/__init__.py +0 -20
- nshtrainer-0.41.1/src/nshtrainer/config/model/base/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/model/config/__init__.py +0 -17
- nshtrainer-0.41.1/src/nshtrainer/config/model/mixins/logger/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/nn/__init__.py +0 -30
- nshtrainer-0.41.1/src/nshtrainer/config/nn/mlp/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/nn/nonlinearity/__init__.py +0 -27
- nshtrainer-0.41.1/src/nshtrainer/config/optimizer/__init__.py +0 -14
- nshtrainer-0.41.1/src/nshtrainer/config/profiler/__init__.py +0 -20
- nshtrainer-0.41.1/src/nshtrainer/config/profiler/_base/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/profiler/advanced/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/profiler/pytorch/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/profiler/simple/__init__.py +0 -13
- nshtrainer-0.41.1/src/nshtrainer/config/runner/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/trainer/_config/__init__.py +0 -35
- nshtrainer-0.41.1/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/util/_environment_info/__init__.py +0 -22
- nshtrainer-0.41.1/src/nshtrainer/config/util/config/__init__.py +0 -17
- nshtrainer-0.41.1/src/nshtrainer/config/util/config/dtype/__init__.py +0 -12
- nshtrainer-0.41.1/src/nshtrainer/config/util/config/duration/__init__.py +0 -14
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/README.md +0 -0
- {nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/util/_useful_types.py +0 -0
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/pyproject.toml
RENAMED
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.41.1"
+version = "0.43.0"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -25,10 +25,10 @@ tensorboard = { version = "*", optional = true }
 huggingface-hub = { version = "*", optional = true }
 
 [tool.poetry.group.dev.dependencies]
-pyright = "
-ruff = "
-ipykernel = "
-ipywidgets = "
+pyright = "*"
+ruff = "*"
+ipykernel = "*"
+ipywidgets = "*"
 
 [build-system]
 requires = ["poetry-core"]
@@ -43,7 +43,11 @@ strictSetInference = true
 reportPrivateImportUsage = false
 
 [tool.ruff.lint]
+select = ["FA102", "FA100"]
 ignore = ["F722", "F821", "E731", "E741"]
 
+[tool.ruff.lint.isort]
+required-imports = ["from __future__ import annotations"]
+
 [tool.poetry.extras]
 extra = ["wrapt", "GitPython", "wandb", "tensorboard", "huggingface-hub"]
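The new Ruff settings explain the pattern that repeats through the rest of this diff: `select = ["FA102", "FA100"]` enables the flake8-future-annotations rules, and the `[tool.ruff.lint.isort]` `required-imports` setting makes the import sorter insert `from __future__ import annotations` at the top of every module, which is why nearly every changed file below gains that two-line header. A minimal sketch of what the future import buys (the function and names below are illustrative, not taken from the package):

```python
# With the future import, annotations are stored as strings and never evaluated
# at runtime, so PEP 604 (`X | Y`) and PEP 585 (`list[X]`) syntax in annotations
# also works on Python versions older than 3.10.
from __future__ import annotations

from pathlib import Path


def resolve(path: str | Path, candidates: list[Path] | None = None) -> Path:
    return Path(path)


print(resolve.__annotations__["path"])  # prints the string "str | Path"
```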
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/_directory.py
RENAMED
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import logging
 from pathlib import Path
 
 import nshconfig as C
 
-from .callbacks.directory_setup import
+from .callbacks.directory_setup import DirectorySetupCallbackConfig
 from .loggers import LoggerConfig
 
 log = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ class DirectoryConfig(C.Config):
     profile: Path | None = None
     """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
 
-    setup_callback:
+    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
     """Configuration for the directory setup PyTorch Lightning callback."""
 
     def resolve_run_root_directory(self, run_id: str) -> Path:
nshtrainer-0.43.0/src/nshtrainer/callbacks/__init__.py
ADDED
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Annotated
+
+import nshconfig as C
+
+from . import checkpoint as checkpoint
+from .base import CallbackConfigBase as CallbackConfigBase
+from .checkpoint import BestCheckpoint as BestCheckpoint
+from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+from .checkpoint import LastCheckpointCallback as LastCheckpointCallback
+from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
+from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCallback
+from .checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
+from .debug_flag import DebugFlagCallback as DebugFlagCallback
+from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
+from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
+from .directory_setup import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
+from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
+from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
+from .ema import EMACallback as EMACallback
+from .ema import EMACallbackConfig as EMACallbackConfig
+from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
+from .finite_checks import FiniteChecksCallbackConfig as FiniteChecksCallbackConfig
+from .gradient_skipping import GradientSkippingCallback as GradientSkippingCallback
+from .gradient_skipping import (
+    GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
+)
+from .interval import EpochIntervalCallback as EpochIntervalCallback
+from .interval import IntervalCallback as IntervalCallback
+from .interval import StepIntervalCallback as StepIntervalCallback
+from .log_epoch import LogEpochCallback as LogEpochCallback
+from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
+from .norm_logging import NormLoggingCallback as NormLoggingCallback
+from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
+from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
+from .print_table import (
+    PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
+)
+from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
+from .rlp_sanity_checks import (
+    RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
+)
+from .shared_parameters import SharedParametersCallback as SharedParametersCallback
+from .shared_parameters import (
+    SharedParametersCallbackConfig as SharedParametersCallbackConfig,
+)
+from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
+from .timer import EpochTimerCallback as EpochTimerCallback
+from .timer import EpochTimerCallbackConfig as EpochTimerCallbackConfig
+from .wandb_upload_code import WandbUploadCodeCallback as WandbUploadCodeCallback
+from .wandb_upload_code import (
+    WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
+)
+from .wandb_watch import WandbWatchCallback as WandbWatchCallback
+from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
+
+CallbackConfig = Annotated[
+    DebugFlagCallbackConfig
+    | EarlyStoppingCallbackConfig
+    | ThroughputMonitorConfig
+    | EpochTimerCallbackConfig
+    | PrintTableMetricsCallbackConfig
+    | FiniteChecksCallbackConfig
+    | NormLoggingCallbackConfig
+    | GradientSkippingCallbackConfig
+    | LogEpochCallbackConfig
+    | EMACallbackConfig
+    | BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
+    | OnExceptionCheckpointCallbackConfig
+    | SharedParametersCallbackConfig
+    | RLPSanityChecksCallbackConfig
+    | WandbWatchCallbackConfig
+    | WandbUploadCodeCallbackConfig,
+    C.Field(discriminator="name"),
+]
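The regenerated `callbacks/__init__.py` also introduces the `CallbackConfig` union, discriminated on each config's `name: Literal[...]` field via `C.Field(discriminator="name")`. The sketch below shows how such a discriminated union resolves the concrete config class during validation; it uses plain pydantic with toy classes and fields, on the assumption that nshconfig follows the same pattern (nothing below is nshtrainer's actual API):

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class EMAConfig(BaseModel):
    name: Literal["ema"] = "ema"
    decay: float = 0.999


class EarlyStoppingConfig(BaseModel):
    name: Literal["early_stopping"] = "early_stopping"
    patience: int = 3


# The "name" field decides which class a plain dict is parsed into.
CallbackConfig = Annotated[
    EMAConfig | EarlyStoppingConfig,
    Field(discriminator="name"),
]

adapter = TypeAdapter(CallbackConfig)
print(adapter.validate_python({"name": "ema", "decay": 0.99}))
# -> EMAConfig(name='ema', decay=0.99)
```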
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py
RENAMED
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import time
 from collections import deque
 from typing import (
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/__init__.py
RENAMED
@@ -1,12 +1,16 @@
+from __future__ import annotations
+
 from .best_checkpoint import BestCheckpoint as BestCheckpoint
 from .best_checkpoint import (
     BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
 )
-from .last_checkpoint import
+from .last_checkpoint import LastCheckpointCallback as LastCheckpointCallback
 from .last_checkpoint import (
     LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
 )
-from .on_exception_checkpoint import
+from .on_exception_checkpoint import (
+    OnExceptionCheckpointCallback as OnExceptionCheckpointCallback,
+)
 from .on_exception_checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal
 
@@ -17,11 +19,11 @@ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
 
     @override
     def create_checkpoint(self, root_config, dirpath):
-        return
+        return LastCheckpointCallback(self, dirpath)
 
 
 @final
-class
+class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
     @override
     def name(self):
         return "last"
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import datetime
 import logging
@@ -59,10 +61,12 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
 
         if not (filename := self.filename):
             filename = f"on_exception_{root_config.id}"
-        yield
+        yield OnExceptionCheckpointCallback(
+            self, dirpath=Path(dirpath), filename=filename
+        )
 
 
-class
+class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
     @override
     def __init__(
         self,
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/directory_setup.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import os
 from pathlib import Path
@@ -41,7 +43,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
     symlink_path.symlink_to(session_dir)
 
 
-class
+class DirectorySetupCallbackConfig(CallbackConfigBase):
     name: Literal["directory_setup"] = "directory_setup"
 
     enabled: bool = True
@@ -62,7 +64,7 @@ class DirectorySetupConfig(CallbackConfigBase):
 
 class DirectorySetupCallback(Callback):
     @override
-    def __init__(self, config:
+    def __init__(self, config: DirectorySetupCallbackConfig):
         super().__init__()
 
         self.config = config
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/early_stopping.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import math
 from typing import Literal
@@ -14,7 +16,7 @@ from .base import CallbackConfigBase
 log = logging.getLogger(__name__)
 
 
-class
+class EarlyStoppingCallbackConfig(CallbackConfigBase):
     name: Literal["early_stopping"] = "early_stopping"
 
     metric: MetricConfig | None = None
@@ -54,11 +56,11 @@ class EarlyStoppingConfig(CallbackConfigBase):
                 "Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
             )
 
-        yield
+        yield EarlyStoppingCallback(self, metric)
 
 
-class
-    def __init__(self, config:
+class EarlyStoppingCallback(_EarlyStopping):
+    def __init__(self, config: EarlyStoppingCallbackConfig, metric: MetricConfig):
         self.config = config
         self.metric = metric
         del config, metric
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/ema.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import copy
 import threading
@@ -13,7 +15,7 @@ from typing_extensions import override
 from .base import CallbackConfigBase
 
 
-class
+class EMACallback(Callback):
     """
     Implements Exponential Moving Averaging (EMA).
 
@@ -358,7 +360,7 @@ class EMAOptimizer(torch.optim.Optimizer):
         self.rebuild_ema_params = True
 
 
-class
+class EMACallbackConfig(CallbackConfigBase):
     name: Literal["ema"] = "ema"
 
     decay: float
@@ -375,7 +377,7 @@ class EMAConfig(CallbackConfigBase):
 
     @override
     def create_callbacks(self, root_config):
-        yield
+        yield EMACallback(
             decay=self.decay,
             validate_original_weights=self.validate_original_weights,
             every_n_steps=self.every_n_steps,
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/finite_checks.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal
 
@@ -58,7 +60,7 @@ class FiniteChecksCallback(Callback):
         )
 
 
-class
+class FiniteChecksCallbackConfig(CallbackConfigBase):
     name: Literal["finite_checks"] = "finite_checks"
 
     nonfinite_grads: bool = True
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/gradient_skipping.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Any, Literal, Protocol, runtime_checkable
 
@@ -18,8 +20,8 @@ class HasGradSkippedSteps(Protocol):
     grad_skipped_steps: Any
 
 
-class
-    def __init__(self, config: "
+class GradientSkippingCallback(Callback):
+    def __init__(self, config: "GradientSkippingCallbackConfig"):
         super().__init__()
         self.config = config
 
@@ -73,7 +75,7 @@ class GradientSkipping(Callback):
         )
 
 
-class
+class GradientSkippingCallbackConfig(CallbackConfigBase):
     name: Literal["gradient_skipping"] = "gradient_skipping"
 
     threshold: float
@@ -94,4 +96,4 @@ class GradientSkippingConfig(CallbackConfigBase):
 
     @override
     def create_callbacks(self, root_config):
-        yield
+        yield GradientSkippingCallback(self)
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/log_epoch.py
RENAMED
@@ -1,14 +1,26 @@
+from __future__ import annotations
+
 import logging
 import math
-from typing import Any
+from typing import Any, Literal
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
 from typing_extensions import override
 
+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)
 
 
+class LogEpochCallbackConfig(CallbackConfigBase):
+    name: Literal["log_epoch"] = "log_epoch"
+
+    @override
+    def create_callbacks(self, root_config):
+        yield LogEpochCallback()
+
+
 class LogEpochCallback(Callback):
     def __init__(self, metric_name: str = "computed_epoch"):
         super().__init__()
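The new `LogEpochCallbackConfig` follows the same shape as the other configs touched in this release: a `CallbackConfigBase` subclass whose `create_callbacks(root_config)` yields ready-to-use Lightning callbacks. A rough, hypothetical sketch of how a list of such configs could be flattened into callbacks; `gather_callbacks` and its signature are illustrative only, not nshtrainer's actual internals:

```python
# Hypothetical helper: each *CallbackConfig yields zero or more Lightning
# Callback instances from create_callbacks(), and the trainer collects them.
def gather_callbacks(callback_configs, root_config):
    callbacks = []
    for config in callback_configs:
        callbacks.extend(config.create_callbacks(root_config))
    return callbacks
```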
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/norm_logging.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal, cast
 
@@ -96,7 +98,7 @@ def compute_norm(
 
 
 class NormLoggingCallback(Callback):
-    def __init__(self, config: "
+    def __init__(self, config: "NormLoggingCallbackConfig"):
         super().__init__()
 
         self.config = config
@@ -155,7 +157,7 @@ class NormLoggingCallback(Callback):
         )
 
 
-class
+class NormLoggingCallbackConfig(CallbackConfigBase):
     name: Literal["norm_logging"] = "norm_logging"
 
     log_grad_norm: bool | str | float = False
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/print_table.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 import fnmatch
 import importlib.util
@@ -74,7 +76,7 @@ class PrintTableMetricsCallback(Callback):
         return table
 
 
-class
+class PrintTableMetricsCallbackConfig(CallbackConfigBase):
     """Configuration class for PrintTableMetricsCallback."""
 
     name: Literal["print_table_metrics"] = "print_table_metrics"
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from collections.abc import Mapping
 from typing import Literal, cast
@@ -16,7 +18,7 @@ from .base import CallbackConfigBase
 log = logging.getLogger(__name__)
 
 
-class
+class RLPSanityChecksCallbackConfig(CallbackConfigBase):
     """
     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
     - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
@@ -43,7 +45,7 @@ class RLPSanityChecksConfig(CallbackConfigBase):
 
 class RLPSanityChecksCallback(Callback):
     @override
-    def __init__(self, config:
+    def __init__(self, config: RLPSanityChecksCallbackConfig):
         super().__init__()
 
         self.config = config
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/shared_parameters.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from collections.abc import Iterable
 from typing import Literal, Protocol, TypeAlias, runtime_checkable
@@ -17,7 +19,7 @@ def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
     return [mapping[id(p)] for p in parameters]
 
 
-class
+class SharedParametersCallbackConfig(CallbackConfigBase):
     """A callback that allows scaling the gradients of shared parameters that
     are registered in the ``self.shared_parameters`` list of the root module.
 
@@ -43,7 +45,7 @@ class ModuleWithSharedParameters(Protocol):
 
 class SharedParametersCallback(Callback):
     @override
-    def __init__(self, config:
+    def __init__(self, config: SharedParametersCallbackConfig):
         super().__init__()
 
         self.config = config
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/timer.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import time
 from typing import Any, Literal
@@ -12,7 +14,7 @@ from .base import CallbackConfigBase
 log = logging.getLogger(__name__)
 
 
-class
+class EpochTimerCallback(Callback):
     def __init__(self):
         super().__init__()
 
@@ -149,9 +151,9 @@ class EpochTimer(Callback):
         self._total_batches = state_dict["total_batches"]
 
 
-class
+class EpochTimerCallbackConfig(CallbackConfigBase):
     name: Literal["epoch_timer"] = "epoch_timer"
 
     @override
     def create_callbacks(self, root_config):
-        yield
+        yield EpochTimerCallback()
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/wandb_upload_code.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import os
 from pathlib import Path
@@ -14,7 +16,7 @@ from .base import CallbackConfigBase
 log = logging.getLogger(__name__)
 
 
-class
+class WandbUploadCodeCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_upload_code"] = "wandb_upload_code"
 
     enabled: bool = True
@@ -32,7 +34,7 @@ class WandbUploadCodeConfig(CallbackConfigBase):
 
 
 class WandbUploadCodeCallback(Callback):
-    def __init__(self, config:
+    def __init__(self, config: WandbUploadCodeCallbackConfig):
         super().__init__()
 
         self.config = config
{nshtrainer-0.41.1 → nshtrainer-0.43.0}/src/nshtrainer/callbacks/wandb_watch.py
RENAMED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from typing import Literal, Protocol, cast, runtime_checkable
 
@@ -12,7 +14,7 @@ from .base import CallbackConfigBase
 log = logging.getLogger(__name__)
 
 
-class
+class WandbWatchCallbackConfig(CallbackConfigBase):
     name: Literal["wandb_watch"] = "wandb_watch"
 
     enabled: bool = True
@@ -41,7 +43,7 @@ class _HasWandbLogModuleProtocol(Protocol):
 
 
 class WandbWatchCallback(Callback):
-    def __init__(self, config:
+    def __init__(self, config: WandbWatchCallbackConfig):
        super().__init__()
 
         self.config = config