nshtrainer 1.1.1b1__tar.gz → 1.2.0__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/PKG-INFO +1 -1
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/pyproject.toml +4 -3
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_directory.py +3 -3
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/__init__.py +6 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/base.py +22 -3
- nshtrainer-1.2.0/src/nshtrainer/callbacks/distributed_prediction_writer.py +166 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/__init__.py +28 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/__init__.py +6 -0
- nshtrainer-1.2.0/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +19 -0
- nshtrainer-1.2.0/src/nshtrainer/configs/optimizer/__init__.py +39 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/__init__.py +4 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/_config/__init__.py +4 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/base.py +60 -2
- nshtrainer-1.2.0/src/nshtrainer/optimizer.py +626 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/_config.py +10 -4
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/trainer.py +21 -2
- nshtrainer-1.1.1b1/src/nshtrainer/configs/optimizer/__init__.py +0 -15
- nshtrainer-1.1.1b1/src/nshtrainer/optimizer.py +0 -68
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/README.md +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/.gitattributes +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/base.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/base.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/nn/rng.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/accelerator.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.1.1b1 → nshtrainer-1.2.0}/src/nshtrainer/util/typing_utils.py +0 -0
pyproject.toml:

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "nshtrainer"
-version = "1.1.1b1"
+version = "1.2.0"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"
@@ -33,8 +33,9 @@ basedpyright = "*"
 ruff = "*"
 ipykernel = "*"
 ipywidgets = "*"
-pytest = "
-pytest-cov = "
+pytest = "*"
+pytest-cov = "*"
+pytest-forked = "*"
 
 [build-system]
 requires = ["poetry-core"]
```
src/nshtrainer/_directory.py:

```diff
@@ -65,9 +65,9 @@ class DirectoryConfig(C.Config):
     ) -> Path:
         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
         if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(
-
-            )
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
             return subdir
 
         dir = self.resolve_run_root_directory(run_id)
```
src/nshtrainer/callbacks/__init__.py:

```diff
@@ -23,6 +23,12 @@ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
 from .directory_setup import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from .distributed_prediction_writer import (
+    DistributedPredictionWriter as DistributedPredictionWriter,
+)
+from .distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
 from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
 from .ema import EMACallback as EMACallback
```
src/nshtrainer/callbacks/base.py:

```diff
@@ -23,6 +23,10 @@ class CallbackMetadataConfig(TypedDict, total=False):
     """Priority of the callback. Callbacks with higher priority will be loaded first.
     Default is `0`."""
 
+    enabled_for_barebones: bool
+    """Whether this callback is enabled for barebones mode.
+    Default is `False`."""
+
 
 @dataclass(frozen=True)
 class CallbackWithMetadata:
```
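The new `enabled_for_barebones` metadata key is what the barebones filtering in the next hunks keys on: when `trainer_config.barebones` is true, only callbacks whose config metadata sets it survive. A minimal sketch of a custom callback config opting in (the config class, its `name`, and the `LambdaCallback` body are hypothetical illustrations, not part of the package):

```python
from typing import ClassVar, Literal

from lightning.pytorch.callbacks import LambdaCallback

from nshtrainer.callbacks.base import (
    CallbackConfigBase,
    CallbackMetadataConfig,
    callback_registry,
)


@callback_registry.register
class MyBarebonesFriendlyCallbackConfig(CallbackConfigBase):
    # Opt in to barebones mode; without this flag, the callback is filtered
    # out whenever trainer_config.barebones is True.
    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
        enabled_for_barebones=True
    )

    name: Literal["my_barebones_friendly"] = "my_barebones_friendly"

    def create_callbacks(self, trainer_config):
        # Yield any Lightning callback; this one just prints when fitting starts.
        yield LambdaCallback(on_fit_start=lambda trainer, pl_module: print("fit start"))
```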
src/nshtrainer/callbacks/base.py (continued):

```diff
@@ -91,10 +95,20 @@ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
 
 
 def _process_and_filter_callbacks(
+    trainer_config: TrainerConfig,
     callbacks: Iterable[CallbackWithMetadata],
 ) -> list[Callback]:
     callbacks = list(callbacks)
 
+    # If we're in barebones mode, used the callback metadata
+    # to decide to keep/remove the callback.
+    if trainer_config.barebones:
+        callbacks = [
+            callback
+            for callback in callbacks
+            if callback.metadata.get("enabled_for_barebones", False)
+        ]
+
     # Sort by priority (higher priority first)
     callbacks.sort(
         key=lambda callback: callback.metadata.get("priority", 0),
@@ -114,9 +128,14 @@ def resolve_all_callbacks(trainer_config: TrainerConfig):
         if config is not None
     ]
     callbacks = _process_and_filter_callbacks(
-
-
-
+        trainer_config,
+        (
+            callback
+            for callback_config in callback_configs
+            for callback in _create_callbacks_with_metadata(
+                callback_config, trainer_config
+            )
+        ),
     )
     return callbacks
 
```
New file src/nshtrainer/callbacks/distributed_prediction_writer.py:

```diff
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+import functools
+import logging
+from collections.abc import Iterator, Sequence
+from pathlib import Path
+from typing import Any, ClassVar, Literal, overload
+
+import torch
+from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.pytorch.callbacks import BasePredictionWriter
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, CallbackMetadataConfig, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class DistributedPredictionWriterConfig(CallbackConfigBase):
+    metadata: ClassVar[CallbackMetadataConfig] = CallbackMetadataConfig(
+        enabled_for_barebones=True
+    )
+    """Metadata for the callback."""
+
+    name: Literal["distributed_prediction_writer"] = "distributed_prediction_writer"
+
+    dirpath: Path | None = None
+    """Directory to save the predictions to. If None, will use the default directory."""
+
+    move_to_cpu_on_save: bool = True
+    """Whether to move the predictions to CPU before saving. Default is True."""
+
+    save_raw: bool = True
+    """Whether to save the raw predictions."""
+
+    save_processed: bool = True
+    """Whether to process and save the predictions.
+
+    "Processing" means that the model's batched predictions are split into individual predictions
+    and saved as a list of tensors.
+    """
+
+    @override
+    def create_callbacks(self, trainer_config):
+        if (dirpath := self.dirpath) is None:
+            dirpath = trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "predictions"
+            )
+
+        yield DistributedPredictionWriter(self, dirpath)
+
+
+def _move_and_save(data, path: Path, move_to_cpu: bool):
+    if move_to_cpu:
+        data = move_data_to_device(data, "cpu")
+
+    # Save the data to the specified path
+    torch.save(data, path)
+
+
+class DistributedPredictionWriter(BasePredictionWriter):
+    def __init__(
+        self,
+        config: DistributedPredictionWriterConfig,
+        output_dir: Path,
+    ):
+        self.config = config
+
+        super().__init__(write_interval="batch")
+
+        self.output_dir = output_dir
+
+    @override
+    def write_on_batch_end(
+        self,
+        trainer,
+        pl_module,
+        prediction,
+        batch_indices,
+        batch,
+        batch_idx,
+        dataloader_idx,
+    ):
+        save = functools.partial(
+            _move_and_save,
+            move_to_cpu=self.config.move_to_cpu_on_save,
+        )
+
+        # Regular, unstructured writing.
+        if self.config.save_raw:
+            output_dir = (
+                self.output_dir
+                / "raw"
+                / f"dataloader_{dataloader_idx}"
+                / f"rank_{trainer.global_rank}"
+                / f"batch_{batch_idx}"
+            )
+            output_dir.mkdir(parents=True, exist_ok=True)
+            save(prediction, output_dir / "predictions.pt")
+            save(batch, output_dir / "batch.pt")
+            save(batch_indices, output_dir / "batch_indices.pt")
+
+        if self.config.save_processed:
+            # Processed writing.
+            from ..model.base import LightningModuleBase
+
+            if not isinstance(pl_module, LightningModuleBase):
+                raise ValueError(
+                    "The model must be a subclass of LightningModuleBase to use the distributed prediction writer."
+                )
+
+            output_dir = self.output_dir / "processed" / f"dataloader_{dataloader_idx}"
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+            # Split into individual predictions
+            assert batch_indices is not None, (
+                "Batch indices must be provided for processed writing."
+            )
+            for sample in pl_module.split_batched_predictions(
+                batch, prediction, batch_indices
+            ):
+                sample = {
+                    **sample,
+                    "global_rank": trainer.global_rank,
+                    "world_size": trainer.world_size,
+                    "is_global_zero": trainer.is_global_zero,
+                }
+                save(sample, output_dir / f"{sample['index']}.pt")
+
+
+class DistributedPredictionReader(Sequence[tuple[Any, Any]]):
+    def __init__(self, output_dir: Path):
+        self.output_dir = output_dir
+
+    @override
+    def __len__(self) -> int:
+        return len(list(self.output_dir.glob("*.pt")))
+
+    @overload
+    def __getitem__(self, index: int) -> tuple[Any, Any]: ...
+
+    @overload
+    def __getitem__(self, index: slice) -> list[tuple[Any, Any]]: ...
+
+    @override
+    def __getitem__(
+        self, index: int | slice
+    ) -> tuple[Any, Any] | list[tuple[Any, Any]]:
+        if isinstance(index, slice):
+            # Handle slice indexing
+            indices = range(*index.indices(len(self)))
+            return [self.__getitem__(i) for i in indices]
+
+        # Handle integer indexing
+        path = self.output_dir / f"{index}.pt"
+        if not path.exists():
+            raise FileNotFoundError(f"File {path} does not exist.")
+        sample = torch.load(path)
+        return sample["batch"], sample["prediction"]
+
+    @override
+    def __iter__(self) -> Iterator[tuple[Any, Any]]:
+        for i in range(len(self)):
+            yield self[i]
```
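With `save_processed=True`, the writer lays out one file per sample under `<dirpath>/processed/dataloader_<i>/<index>.pt`, and `DistributedPredictionReader` reads them back as `(batch, prediction)` pairs. A usage sketch; the predictions path below is hypothetical and should point wherever `dirpath` resolved to:

```python
from pathlib import Path

from nshtrainer.callbacks.distributed_prediction_writer import (
    DistributedPredictionReader,
)

# Each saved file is a dict with "index", "batch", "prediction", plus rank info.
predictions_dir = Path("path/to/predictions")  # hypothetical resolved `dirpath`
reader = DistributedPredictionReader(predictions_dir / "processed" / "dataloader_0")

print(len(reader))             # number of samples written for this dataloader
batch, prediction = reader[0]  # integer indexing returns one (batch, prediction) pair
first_ten = reader[:10]        # slice indexing returns a list of pairs
```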
src/nshtrainer/configs/__init__.py:

```diff
@@ -21,6 +21,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -95,9 +98,21 @@ from nshtrainer.nn.nonlinearity import (
     SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
 )
 from nshtrainer.nn.nonlinearity import nonlinearity_registry as nonlinearity_registry
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
 from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
 from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
 from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -225,11 +240,17 @@ from . import trainer as trainer
 from . import util as util
 
 __all__ = [
+    "ASGDConfig",
     "AcceleratorConfig",
     "AcceleratorConfigBase",
     "ActSaveConfig",
     "ActSaveLoggerConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
     "AdamWConfig",
+    "AdamaxConfig",
     "AdvancedProfilerConfig",
     "AsyncCheckpointIOPlugin",
     "BaseCheckpointCallbackConfig",
@@ -249,6 +270,7 @@ __all__ = [
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "DurationConfig",
     "ELUNonlinearityConfig",
@@ -294,6 +316,7 @@ __all__ = [
     "MetricValidationCallbackConfig",
     "MishNonlinearityConfig",
     "MixedPrecisionPluginConfig",
+    "NAdamConfig",
     "NonlinearityConfig",
     "NonlinearityConfigBase",
     "NormLoggingCallbackConfig",
@@ -306,10 +329,14 @@ __all__ = [
     "PrintTableMetricsCallbackConfig",
     "ProfilerConfig",
     "PyTorchProfilerConfig",
+    "RAdamConfig",
     "RLPSanityChecksCallbackConfig",
+    "RMSpropConfig",
     "RNGConfig",
     "ReLUNonlinearityConfig",
     "ReduceLROnPlateauConfig",
+    "RpropConfig",
+    "SGDConfig",
     "SLURMEnvironmentPlugin",
     "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
@@ -331,6 +358,7 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
+    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
```
src/nshtrainer/configs/callbacks/__init__.py:

```diff
@@ -12,6 +12,9 @@ from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackCon
 from nshtrainer.callbacks import (
     DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
 )
+from nshtrainer.callbacks import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.callbacks import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -62,6 +65,7 @@ from . import base as base
 from . import checkpoint as checkpoint
 from . import debug_flag as debug_flag
 from . import directory_setup as directory_setup
+from . import distributed_prediction_writer as distributed_prediction_writer
 from . import early_stopping as early_stopping
 from . import ema as ema
 from . import finite_checks as finite_checks
@@ -86,6 +90,7 @@ __all__ = [
     "CheckpointMetadata",
     "DebugFlagCallbackConfig",
     "DirectorySetupCallbackConfig",
+    "DistributedPredictionWriterConfig",
     "EMACallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EpochTimerCallbackConfig",
@@ -109,6 +114,7 @@ __all__ = [
     "checkpoint",
     "debug_flag",
     "directory_setup",
+    "distributed_prediction_writer",
     "early_stopping",
     "ema",
     "finite_checks",
```
New file src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py:

```diff
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    CallbackConfigBase as CallbackConfigBase,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
+from nshtrainer.callbacks.distributed_prediction_writer import (
+    callback_registry as callback_registry,
+)
+
+__all__ = [
+    "CallbackConfigBase",
+    "DistributedPredictionWriterConfig",
+    "callback_registry",
+]
```
New file src/nshtrainer/configs/optimizer/__init__.py:

```diff
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from nshtrainer.optimizer import AdadeltaConfig as AdadeltaConfig
+from nshtrainer.optimizer import AdafactorConfig as AdafactorConfig
+from nshtrainer.optimizer import AdagradConfig as AdagradConfig
+from nshtrainer.optimizer import AdamaxConfig as AdamaxConfig
+from nshtrainer.optimizer import AdamConfig as AdamConfig
+from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import ASGDConfig as ASGDConfig
+from nshtrainer.optimizer import NAdamConfig as NAdamConfig
+from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
+from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.optimizer import RAdamConfig as RAdamConfig
+from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
+from nshtrainer.optimizer import RpropConfig as RpropConfig
+from nshtrainer.optimizer import SGDConfig as SGDConfig
+from nshtrainer.optimizer import Union as Union
+from nshtrainer.optimizer import optimizer_registry as optimizer_registry
+
+__all__ = [
+    "ASGDConfig",
+    "AdadeltaConfig",
+    "AdafactorConfig",
+    "AdagradConfig",
+    "AdamConfig",
+    "AdamWConfig",
+    "AdamaxConfig",
+    "NAdamConfig",
+    "OptimizerConfig",
+    "OptimizerConfigBase",
+    "RAdamConfig",
+    "RMSpropConfig",
+    "RpropConfig",
+    "SGDConfig",
+    "Union",
+    "optimizer_registry",
+]
```
src/nshtrainer/configs/trainer/__init__.py:

```diff
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -149,6 +152,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
```
src/nshtrainer/configs/trainer/_config/__init__.py:

```diff
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DistributedPredictionWriterConfig as DistributedPredictionWriterConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -70,6 +73,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DistributedPredictionWriterConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
```
src/nshtrainer/model/base.py:

```diff
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Mapping
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from pathlib import Path
-from typing import Any, Generic, Literal, cast
+from typing import Any, Generic, Literal, TypedDict, cast
 
 import nshconfig as C
 import torch
@@ -53,6 +53,47 @@ VALID_REDUCE_OPS = (
 )
 
 
+class IndividualSample(TypedDict):
+    """
+    A dictionary that contains the individual sample.
+    This is used to split the batched predictions into individual predictions.
+    """
+
+    index: int
+    """The index of the sample in the batch."""
+
+    batch: Any
+    """The batch to split."""
+
+    prediction: Any
+    """The batched prediction to split."""
+
+
+def default_split_batched_predictions(
+    batch: Any,
+    prediction: Any,
+    batch_indices: Sequence[Any],
+) -> Iterable[IndividualSample]:
+    """
+    Splits the batched predictions into a list of individual predictions.
+    Args:
+        batch: The batch to split.
+        prediction: The batched prediction to split.
+        batch_indices: The indices of the batches.
+    Returns:
+        A tuple of two sequences: the corresponding batches and the individual predictions.
+    """
+    import torch.utils._pytree as tree
+
+    for sample_idx, batch_idx in enumerate(batch_indices):
+        # Create a dictionary for each sample
+        yield IndividualSample(
+            index=batch_idx,
+            batch=tree.tree_map(lambda x: x[sample_idx], batch),
+            prediction=tree.tree_map(lambda x: x[sample_idx], prediction),
+        )
+
+
 class LightningModuleBase(
     DebugModuleMixin,
     RLPSanityCheckModuleMixin,
```
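A toy illustration of what `default_split_batched_predictions` does (the batch and prediction dicts below are made up for the example): each tensor leaf of the two pytrees is indexed per sample, and every yielded `IndividualSample` carries the dataset index taken from `batch_indices`.

```python
import torch

from nshtrainer.model.base import default_split_batched_predictions

batch = {"x": torch.randn(4, 3)}        # a batch of 4 inputs
prediction = {"y": torch.randn(4, 1)}   # the batched model output
batch_indices = [10, 11, 12, 13]        # dataset indices of these 4 samples

for sample in default_split_batched_predictions(batch, prediction, batch_indices):
    # Each sample is an IndividualSample with per-sample slices of both pytrees.
    print(sample["index"], sample["batch"]["x"].shape, sample["prediction"]["y"].shape)
    # e.g. 10 torch.Size([3]) torch.Size([1])
```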
src/nshtrainer/model/base.py (continued):

```diff
@@ -171,6 +212,23 @@ class LightningModuleBase(
         loss = cast(torch.Tensor, loss)
         return loss
 
+    def split_batched_predictions(
+        self,
+        batch: Any,
+        prediction: Any,
+        batch_indices: Sequence[Any],
+    ) -> Iterable[IndividualSample]:
+        """
+        Splits the batched predictions into a list of individual predictions.
+        Args:
+            batch: The batch to split.
+            prediction: The batched prediction to split.
+            batch_indices: The indices of the batches.
+        Returns:
+            A tuple of two sequences: the corresponding batches and the individual predictions.
+        """
+        return default_split_batched_predictions(batch, prediction, batch_indices)
+
     @override
     @classmethod
     def load_from_checkpoint(cls, *args, **kwargs) -> Never:
```
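Models whose batches contain leaves that cannot simply be indexed by `tree_map` (for example, lists of string ids) can override `split_batched_predictions` on their `LightningModuleBase` subclass instead of relying on the default. A hedged sketch of what such an override might delegate to, written here as a standalone function with a hypothetical batch layout:

```python
from collections.abc import Iterable, Sequence
from typing import Any

from nshtrainer.model.base import IndividualSample


def split_with_string_ids(
    batch: Any, prediction: Any, batch_indices: Sequence[Any]
) -> Iterable[IndividualSample]:
    # Hypothetical layout: batch = {"id": list[str], "x": Tensor}; prediction = {"y": Tensor}.
    for sample_idx, dataset_idx in enumerate(batch_indices):
        yield IndividualSample(
            index=dataset_idx,
            batch={"id": batch["id"][sample_idx], "x": batch["x"][sample_idx]},
            prediction={"y": prediction["y"][sample_idx]},
        )
```

Returning such a generator from a module's `split_batched_predictions` override lets `DistributedPredictionWriter` save the extra per-sample fields alongside each prediction.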
|