nshtrainer 1.0.0b54__tar.gz → 1.0.0b56__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/pyproject.toml +4 -4
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/__init__.py +2 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/lr_scheduler/__init__.py +2 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +2 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/nn/__init__.py +4 -0
- nshtrainer-1.0.0b56/src/nshtrainer/configs/nn/rng/__init__.py +9 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -2
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -2
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -2
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/nn/__init__.py +2 -1
- nshtrainer-1.0.0b56/src/nshtrainer/nn/rng.py +23 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/__init__.py +9 -1
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/base.py +1 -8
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/environment.py +83 -25
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/io.py +32 -33
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/layer_sync.py +3 -4
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/precision.py +8 -9
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/README.md +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/metric_validation.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/.gitattributes +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/loggers/base.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/lr_scheduler/base.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/_config.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/accelerator.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/pyproject.toml
RENAMED
@@ -1,13 +1,13 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "1.0.0-beta54"
+version = "1.0.0-beta56"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
 
 [tool.poetry.dependencies]
 python = "^3.10"
-nshrunner = { version = "*"
+nshrunner = { version = "*" }
 nshconfig = "*"
 nshutils = { version = "*", optional = true }
 psutil = "*"
@@ -25,7 +25,7 @@ tensorboard = { version = "*", optional = true }
 huggingface-hub = { version = "*", optional = true }
 
 [tool.poetry.group.dev.dependencies]
-
+basedpyright = "*"
 ruff = "*"
 ipykernel = "*"
 ipywidgets = "*"
@@ -36,7 +36,7 @@ pytest-cov = "^6.0.0"
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
 
-[tool.
+[tool.basedpyright]
 typeCheckingMode = "standard"
 deprecateTypingAliases = true
 strictListInference = true
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/__init__.py
RENAMED
@@ -85,6 +85,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
 from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
 from nshtrainer.nn import PReLUConfig as PReLUConfig
 from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+from nshtrainer.nn import RNGConfig as RNGConfig
 from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
 from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
 from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
@@ -306,6 +307,7 @@ __all__ = [
     "ProfilerConfig",
     "PyTorchProfilerConfig",
     "RLPSanityChecksCallbackConfig",
+    "RNGConfig",
     "ReLUNonlinearityConfig",
     "ReduceLROnPlateauConfig",
     "SLURMEnvironmentPlugin",
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/lr_scheduler/__init__.py
RENAMED
@@ -12,6 +12,7 @@ from nshtrainer.lr_scheduler.base import lr_scheduler_registry as lr_scheduler_r
 from nshtrainer.lr_scheduler.linear_warmup_cosine import (
     DurationConfig as DurationConfig,
 )
+from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
 from nshtrainer.lr_scheduler.reduce_lr_on_plateau import MetricConfig as MetricConfig
 
 from . import base as base
@@ -20,6 +21,7 @@ from . import reduce_lr_on_plateau as reduce_lr_on_plateau
 
 __all__ = [
     "DurationConfig",
+    "EpochsConfig",
     "LRSchedulerConfig",
     "LRSchedulerConfigBase",
     "LinearWarmupCosineDecayLRSchedulerConfig",
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py
RENAMED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 __codegen__ = True
 
+from nshtrainer.lr_scheduler.reduce_lr_on_plateau import EpochsConfig as EpochsConfig
 from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
     LRSchedulerConfigBase as LRSchedulerConfigBase,
 )
@@ -14,6 +15,7 @@ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import (
 )
 
 __all__ = [
+    "EpochsConfig",
     "LRSchedulerConfigBase",
     "MetricConfig",
     "ReduceLROnPlateauConfig",
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/nn/__init__.py
RENAMED
@@ -11,6 +11,7 @@ from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
 from nshtrainer.nn import NonlinearityConfigBase as NonlinearityConfigBase
 from nshtrainer.nn import PReLUConfig as PReLUConfig
 from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+from nshtrainer.nn import RNGConfig as RNGConfig
 from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
 from nshtrainer.nn import SiLUNonlinearityConfig as SiLUNonlinearityConfig
 from nshtrainer.nn import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
@@ -25,6 +26,7 @@ from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_reg
 
 from . import mlp as mlp
 from . import nonlinearity as nonlinearity
+from . import rng as rng
 
 __all__ = [
     "ELUNonlinearityConfig",
@@ -35,6 +37,7 @@ __all__ = [
     "NonlinearityConfig",
     "NonlinearityConfigBase",
     "PReLUConfig",
+    "RNGConfig",
     "ReLUNonlinearityConfig",
     "SiLUNonlinearityConfig",
     "SigmoidNonlinearityConfig",
@@ -47,4 +50,5 @@ __all__ = [
     "mlp",
     "nonlinearity",
     "nonlinearity_registry",
+    "rng",
 ]
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/base/__init__.py
RENAMED
@@ -2,12 +2,10 @@ from __future__ import annotations
 
 __codegen__ = True
 
-from nshtrainer.trainer.plugin.base import PluginConfig as PluginConfig
 from nshtrainer.trainer.plugin.base import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.plugin.base import plugin_registry as plugin_registry
 
 __all__ = [
-    "PluginConfig",
     "PluginConfigBase",
     "plugin_registry",
 ]
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py
RENAMED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 __codegen__ = True
 
-from nshtrainer.trainer.plugin.environment import DTypeConfig as DTypeConfig
 from nshtrainer.trainer.plugin.environment import (
     KubeflowEnvironmentPlugin as KubeflowEnvironmentPlugin,
 )
@@ -28,7 +27,6 @@ from nshtrainer.trainer.plugin.environment import (
 from nshtrainer.trainer.plugin.environment import plugin_registry as plugin_registry
 
 __all__ = [
-    "DTypeConfig",
    "KubeflowEnvironmentPlugin",
    "LSFEnvironmentPlugin",
    "LightningEnvironmentPlugin",
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/configs/trainer/plugin/io/__init__.py
RENAMED
@@ -5,7 +5,6 @@ __codegen__ = True
 from nshtrainer.trainer.plugin.io import (
     AsyncCheckpointIOPlugin as AsyncCheckpointIOPlugin,
 )
-from nshtrainer.trainer.plugin.io import PluginConfig as PluginConfig
 from nshtrainer.trainer.plugin.io import PluginConfigBase as PluginConfigBase
 from nshtrainer.trainer.plugin.io import (
     TorchCheckpointIOPlugin as TorchCheckpointIOPlugin,
@@ -15,7 +14,6 @@ from nshtrainer.trainer.plugin.io import plugin_registry as plugin_registry
 
 __all__ = [
     "AsyncCheckpointIOPlugin",
-    "PluginConfig",
     "PluginConfigBase",
     "TorchCheckpointIOPlugin",
     "XLACheckpointIOPlugin",
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/nn/__init__.py
RENAMED
@@ -3,7 +3,6 @@ from __future__ import annotations
 from .mlp import MLP as MLP
 from .mlp import MLPConfig as MLPConfig
 from .mlp import ResidualSequential as ResidualSequential
-from .mlp import custom_seed_context as custom_seed_context
 from .module_dict import TypedModuleDict as TypedModuleDict
 from .module_list import TypedModuleList as TypedModuleList
 from .nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
@@ -21,3 +20,5 @@ from .nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConf
 from .nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
 from .nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
 from .nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
+from .rng import RNGConfig as RNGConfig
+from .rng import rng_context as rng_context
nshtrainer-1.0.0b56/src/nshtrainer/nn/rng.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import contextlib
+
+import nshconfig as C
+import torch
+
+
+@contextlib.contextmanager
+def rng_context(config: RNGConfig | None):
+    with contextlib.ExitStack() as stack:
+        if config is not None:
+            stack.enter_context(
+                torch.random.fork_rng(devices=range(torch.cuda.device_count()))
+            )
+            torch.manual_seed(config.seed)
+
+        yield
+
+
+class RNGConfig(C.Config):
+    seed: int
+    """Random seed to use for initialization."""
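The new `nshtrainer/nn/rng.py` module is small and self-contained: `rng_context` forks the CPU and CUDA RNG state, seeds the fork from the config, and restores the caller's state on exit, while a `None` config makes it a no-op. A minimal usage sketch based only on the code above (the `nn.Linear` modules are purely illustrative):

    import torch.nn as nn

    from nshtrainer.nn import RNGConfig, rng_context

    # Deterministic initialization: the forked RNG is seeded, the module is built,
    # and the surrounding RNG state is restored when the block exits.
    with rng_context(RNGConfig(seed=0)):
        layer = nn.Linear(16, 16)

    # Passing None skips the fork/seed entirely, so an optional config can be threaded through.
    with rng_context(None):
        other = nn.Linear(16, 16)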
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/__init__.py
RENAMED
@@ -1,10 +1,18 @@
 from __future__ import annotations
 
+from typing import Annotated
+
+from typing_extensions import TypeAliasType
+
 from . import environment as environment
 from . import io as io
 from . import layer_sync as layer_sync
 from . import precision as precision
 from .base import Plugin as Plugin
-from .base import PluginConfig as PluginConfig
 from .base import PluginConfigBase as PluginConfigBase
 from .base import plugin_registry as plugin_registry
+
+PluginConfig = TypeAliasType(
+    "PluginConfig",
+    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
+)
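`PluginConfig` itself is unchanged; it moves out of `base.py` and is now defined in the package `__init__` on top of the registry. A rough sketch of how such an annotated alias is typically consumed (the `RunSettings` class and its field are hypothetical, and the exact nshconfig resolution behavior is assumed rather than shown in this diff):

    import nshconfig as C

    from nshtrainer.trainer.plugin import PluginConfig

    class RunSettings(C.Config):
        # PluginConfig is Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
        # so any plugin config registered under the "name" discriminator can be supplied here.
        plugins: list[PluginConfig] = []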
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/base.py
RENAMED
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-import logging
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import nshconfig as C
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
@@ -12,7 +11,6 @@ from typing_extensions import TypeAliasType
 
 if TYPE_CHECKING:
     from .._config import TrainerConfig
-log = logging.getLogger(__name__)
 
 
 Plugin = TypeAliasType(
@@ -26,8 +24,3 @@ class PluginConfigBase(C.Config, ABC):
 
 
 plugin_registry = C.Registry(PluginConfigBase, discriminator="name")
-
-PluginConfig = TypeAliasType(
-    "PluginConfig",
-    Annotated[PluginConfigBase, plugin_registry.DynamicResolution()],
-)
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/environment.py
RENAMED
@@ -1,27 +1,25 @@
 from __future__ import annotations
 
 import signal
-from typing import
+from typing import Literal
 
-from
-from typing_extensions import override
+from typing_extensions import TypeAliasType, override
 
-from ...util.config.dtype import DTypeConfig
 from .base import PluginConfigBase, plugin_registry
 
 
 @plugin_registry.register
 class KubeflowEnvironmentPlugin(PluginConfigBase):
-    name: Literal["kubeflow_environment"] = "kubeflow_environment"
-
     """Environment for distributed training using the PyTorchJob operator from Kubeflow.
 
     This environment, unlike others, does not get auto-detected and needs to be passed
     to the Fabric/Trainer constructor manually.
     """
 
+    name: Literal["kubeflow_environment"] = "kubeflow_environment"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment
 
         return KubeflowEnvironment()
@@ -29,8 +27,6 @@ class KubeflowEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class LightningEnvironmentPlugin(PluginConfigBase):
-    name: Literal["lightning_environment"] = "lightning_environment"
-
     """The default environment used by Lightning for a single node or free cluster (not managed).
 
     There are two modes the Lightning environment can operate with:
@@ -40,8 +36,10 @@ class LightningEnvironmentPlugin(PluginConfigBase):
     The appropriate environment variables need to be set, and at minimum `LOCAL_RANK`.
     """
 
+    name: Literal["lightning_environment"] = "lightning_environment"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.lightning import LightningEnvironment
 
         return LightningEnvironment()
@@ -49,16 +47,16 @@ class LightningEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class LSFEnvironmentPlugin(PluginConfigBase):
-    name: Literal["lsf_environment"] = "lsf_environment"
-
     """An environment for running on clusters managed by the LSF resource manager.
 
     It is expected that any execution using this ClusterEnvironment was executed
     using the Job Step Manager i.e. `jsrun`.
     """
 
+    name: Literal["lsf_environment"] = "lsf_environment"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 
         return LSFEnvironment()
@@ -66,48 +64,108 @@ class LSFEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class MPIEnvironmentPlugin(PluginConfigBase):
-    name: Literal["mpi_environment"] = "mpi_environment"
-
     """An environment for running on clusters with processes created through MPI.
 
     Requires the installation of the `mpi4py` package.
     """
 
+    name: Literal["mpi_environment"] = "mpi_environment"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.mpi import MPIEnvironment
 
         return MPIEnvironment()
 
 
+SignalAlias = TypeAliasType(
+    "SignalAlias",
+    Literal[
+        "SIGABRT",
+        "SIGFPE",
+        "SIGILL",
+        "SIGINT",
+        "SIGSEGV",
+        "SIGTERM",
+        "SIGBREAK",
+        "CTRL_C_EVENT",
+        "CTRL_BREAK_EVENT",
+        "SIGALRM",
+        "SIGBUS",
+        "SIGCHLD",
+        "SIGCONT",
+        "SIGHUP",
+        "SIGIO",
+        "SIGIOT",
+        "SIGKILL",
+        "SIGPIPE",
+        "SIGPROF",
+        "SIGQUIT",
+        "SIGSTOP",
+        "SIGSYS",
+        "SIGTRAP",
+        "SIGTSTP",
+        "SIGTTIN",
+        "SIGTTOU",
+        "SIGURG",
+        "SIGUSR1",
+        "SIGUSR2",
+        "SIGVTALRM",
+        "SIGWINCH",
+        "SIGXCPU",
+        "SIGXFSZ",
+        "SIGEMT",
+        "SIGINFO",
+        "SIGCLD",
+        "SIGPOLL",
+        "SIGPWR",
+        "SIGRTMAX",
+        "SIGRTMIN",
+        "SIGSTKFLT",
+    ],
+)
+
+
 @plugin_registry.register
 class SLURMEnvironmentPlugin(PluginConfigBase):
+    """An environment for running on clusters managed by the SLURM resource manager."""
+
     name: Literal["slurm_environment"] = "slurm_environment"
 
     auto_requeue: bool = True
     """Whether automatic job resubmission is enabled or not."""
 
-    requeue_signal:
+    requeue_signal: SignalAlias | None = None
     """The signal that SLURM will send to indicate that the job should be requeued."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 
+        requeue_signal = None
+        if self.requeue_signal is not None:
+            try:
+                requeue_signal = signal.Signals[self.requeue_signal]
+            except KeyError:
+                raise ValueError(
+                    f"Invalid signal name: {self.requeue_signal}. "
+                    "Please provide a valid signal name from the signal module."
+                )
+
         return SLURMEnvironment(
             auto_requeue=self.auto_requeue,
-            requeue_signal=
+            requeue_signal=requeue_signal,
         )
 
 
 @plugin_registry.register
 class TorchElasticEnvironmentPlugin(PluginConfigBase):
-    name: Literal["torchelastic_environment"] = "torchelastic_environment"
-
     """Environment for fault-tolerant and elastic training with torchelastic."""
 
+    name: Literal["torchelastic_environment"] = "torchelastic_environment"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.torchelastic import (
             TorchElasticEnvironment,
         )
@@ -117,12 +175,12 @@ class TorchElasticEnvironmentPlugin(PluginConfigBase):
 
 @plugin_registry.register
 class XLAEnvironmentPlugin(PluginConfigBase):
-    name: Literal["xla_environment"] = "xla_environment"
-
     """Cluster environment for training on a TPU Pod with the PyTorch/XLA library."""
 
+    name: Literal["xla_environment"] = "xla_environment"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.fabric.plugins.environments.xla import XLAEnvironment
 
         return XLAEnvironment()
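The substantive change in `environment.py` is that `requeue_signal` is now a signal *name* (the `SignalAlias` literal) rather than a `signal.Signals` member, with lookup and validation moved into `create_plugin`; the remaining edits reorder the `name` fields after the class docstrings and trim the imports. A small sketch using only the class shown above (the chosen signal is just an example):

    from nshtrainer.trainer.plugin.environment import SLURMEnvironmentPlugin

    # The signal is given by name and resolved via signal.Signals[...] inside
    # create_plugin; an unknown name raises ValueError immediately.
    config = SLURMEnvironmentPlugin(auto_requeue=True, requeue_signal="SIGUSR1")
    env = config.create_plugin(trainer_config=None)  # trainer_config is not used by this plugin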
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/io.py
RENAMED
@@ -2,27 +2,52 @@ from __future__ import annotations
 
 from typing import Literal
 
-from lightning.pytorch.plugins.io import CheckpointIO
 from typing_extensions import override
 
-from .base import
+from .base import PluginConfigBase, plugin_registry
 
 
 @plugin_registry.register
-class
-
+class TorchCheckpointIOPlugin(PluginConfigBase):
+    """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
+
+    name: Literal["torch_checkpoint"] = "torch_checkpoint"
+
+    @override
+    def create_plugin(self, trainer_config):
+        from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
+
+        return TorchCheckpointIO()
+
+
+@plugin_registry.register
+class XLACheckpointIOPlugin(PluginConfigBase):
+    """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
 
+    name: Literal["xla_checkpoint"] = "xla_checkpoint"
+
+    @override
+    def create_plugin(self, trainer_config):
+        from lightning.fabric.plugins.io.xla import XLACheckpointIO
+
+        return XLACheckpointIO()
+
+
+@plugin_registry.register
+class AsyncCheckpointIOPlugin(PluginConfigBase):
     """Enables saving the checkpoints asynchronously in a thread.
 
     .. warning:: This is an experimental feature.
     """
 
-
+    name: Literal["async_checkpoint"] = "async_checkpoint"
+
+    checkpoint_io: TorchCheckpointIOPlugin | None = None
     """A checkpoint IO plugin that is used as the basis for async checkpointing."""
 
     @override
-    def create_plugin(self, trainer_config)
-        from lightning.pytorch.plugins.io
+    def create_plugin(self, trainer_config):
+        from lightning.pytorch.plugins.io import AsyncCheckpointIO, CheckpointIO
 
         base_io = (
             self.checkpoint_io.create_plugin(trainer_config)
@@ -34,29 +59,3 @@ class AsyncCheckpointIOPlugin(PluginConfigBase):
             f"Expected `checkpoint_io` to be a `CheckpointIO` instance, but got {type(base_io)}."
         )
         return AsyncCheckpointIO(checkpoint_io=base_io)
-
-
-@plugin_registry.register
-class TorchCheckpointIOPlugin(PluginConfigBase):
-    name: Literal["torch_checkpoint"] = "torch_checkpoint"
-
-    """CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively."""
-
-    @override
-    def create_plugin(self, trainer_config) -> CheckpointIO:
-        from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
-
-        return TorchCheckpointIO()
-
-
-@plugin_registry.register
-class XLACheckpointIOPlugin(PluginConfigBase):
-    name: Literal["xla_checkpoint"] = "xla_checkpoint"
-
-    """CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies."""
-
-    @override
-    def create_plugin(self, trainer_config) -> CheckpointIO:
-        from lightning.fabric.plugins.io.xla import XLACheckpointIO
-
-        return XLACheckpointIO()
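In `io.py` the checkpoint IO plugins are reordered so that `TorchCheckpointIOPlugin` and `XLACheckpointIOPlugin` are defined before `AsyncCheckpointIOPlugin`, which references the former as a field type; their `name` fields also move after the docstrings. A brief composition sketch based on the classes above:

    from nshtrainer.trainer.plugin.io import AsyncCheckpointIOPlugin, TorchCheckpointIOPlugin

    # Async checkpointing wraps a base CheckpointIO: the nested plugin config is
    # instantiated first and then handed to Lightning's AsyncCheckpointIO.
    config = AsyncCheckpointIOPlugin(checkpoint_io=TorchCheckpointIOPlugin())
    plugin = config.create_plugin(trainer_config=None)  # trainer_config is unused by the torch IO plugin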
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/layer_sync.py
RENAMED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from typing import Literal
 
-from lightning.pytorch.plugins.layer_sync import LayerSync
 from typing_extensions import override
 
 from .base import PluginConfigBase, plugin_registry
@@ -10,16 +9,16 @@ from .base import PluginConfigBase, plugin_registry
 
 @plugin_registry.register
 class TorchSyncBatchNormPlugin(PluginConfigBase):
-    name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
-
     """A plugin that wraps all batch normalization layers of a model with synchronization
     logic for multiprocessing.
 
     This plugin has no effect in single-device operation.
     """
 
+    name: Literal["torch_sync_batchnorm"] = "torch_sync_batchnorm"
+
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm
 
         return TorchSyncBatchNorm()
{nshtrainer-1.0.0b54 → nshtrainer-1.0.0b56}/src/nshtrainer/trainer/plugin/precision.py
RENAMED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from typing import Any, Literal
 
-from lightning.pytorch.plugins.precision import Precision
 from typing_extensions import override
 
 from ...util.config.dtype import DTypeConfig
@@ -20,7 +19,7 @@ class MixedPrecisionPluginConfig(PluginConfigBase):
     """The device for ``torch.autocast``."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.amp import MixedPrecision
 
         return MixedPrecision(self.precision, self.device)
@@ -45,7 +44,7 @@ class BitsandbytesPluginConfig(PluginConfigBase):
     """
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.bitsandbytes import (
             BitsandbytesPrecision,
         )
@@ -66,7 +65,7 @@ class DeepSpeedPluginConfig(PluginConfigBase):
     mixed precision (16-mixed, bf16-mixed)."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision
 
         return DeepSpeedPrecision(precision=self.precision)
@@ -80,7 +79,7 @@ class DoublePrecisionPluginConfig(PluginConfigBase):
     """Plugin for training with double (``torch.float64``) precision."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.double import DoublePrecision
 
         return DoublePrecision()
@@ -95,7 +94,7 @@ class FSDPPrecisionPluginConfig(PluginConfigBase):
     mixed precision (16-mixed, bf16-mixed)."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
 
         return FSDPPrecision(precision=self.precision)
@@ -109,7 +108,7 @@ class HalfPrecisionPluginConfig(PluginConfigBase):
     """Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``)."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.half import HalfPrecision
 
         return HalfPrecision(precision=self.precision)
@@ -134,7 +133,7 @@ class TransformerEnginePluginConfig(PluginConfigBase):
     Defaults to the same as weights_dtype."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.transformer_engine import (
             TransformerEnginePrecision,
         )
@@ -157,7 +156,7 @@ class XLAPluginConfig(PluginConfigBase):
     """Full precision (32-true) or half precision (16-true, bf16-true)."""
 
     @override
-    def create_plugin(self, trainer_config)
+    def create_plugin(self, trainer_config):
         from lightning.pytorch.plugins.precision.xla import XLAPrecision
 
         return XLAPrecision(precision=self.precision)
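In `precision.py` only the `create_plugin` signatures change (the removed lines are truncated here, but the dropped `Precision` import suggests the return annotations were removed); behavior is untouched. Usage stays the same, for example (a sketch, with the precision value chosen arbitrarily):

    from nshtrainer.trainer.plugin.precision import HalfPrecisionPluginConfig

    # Builds Lightning's HalfPrecision plugin configured for bf16 training.
    config = HalfPrecisionPluginConfig(precision="bf16-true")
    plugin = config.create_plugin(trainer_config=None)  # trainer_config is unused here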