nshtrainer 0.44.0__tar.gz → 1.0.0b9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/PKG-INFO +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/pyproject.toml +10 -3
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/__init__.py +6 -3
- nshtrainer-1.0.0b9/src/nshtrainer/_callback.py +337 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/loader.py +23 -30
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer-1.0.0b9/src/nshtrainer/_experimental/__init__.py +1 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_hf_hub.py +25 -26
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/__init__.py +1 -3
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/actsave.py +22 -20
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/base.py +7 -7
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/_base.py +8 -5
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/debug_flag.py +14 -19
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/directory_setup.py +6 -11
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/early_stopping.py +3 -3
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/ema.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/finite_checks.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/gradient_skipping.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/log_epoch.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/norm_logging.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/print_table.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/shared_parameters.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/timer.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/wandb_upload_code.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/wandb_watch.py +1 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/__init__.py +189 -189
- nshtrainer-1.0.0b9/src/nshtrainer/config/_checkpoint/__init__.py +70 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer-1.0.0b9/src/nshtrainer/config/callbacks/log_epoch/__init__.py +31 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer-1.0.0b9/src/nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer-1.0.0b9/src/nshtrainer/config/trainer/__init__.py +180 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer-1.0.0b9/src/nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer-1.0.0b9/src/nshtrainer/config/util/__init__.py +109 -0
- nshtrainer-1.0.0b9/src/nshtrainer/data/datamodule.py +56 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/__init__.py +2 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/_base.py +5 -2
- nshtrainer-1.0.0b9/src/nshtrainer/loggers/actsave.py +59 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/csv.py +5 -5
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/tensorboard.py +5 -5
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/loggers/wandb.py +17 -16
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/_base.py +2 -1
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer-1.0.0b9/src/nshtrainer/model/__init__.py +3 -0
- nshtrainer-1.0.0b9/src/nshtrainer/model/base.py +243 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer-1.0.0b9/src/nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer-1.0.0b9/src/nshtrainer/model/mixins/logger.py +163 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/_base.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/advanced.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/pytorch.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/simple.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/__init__.py +1 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/_config.py +164 -17
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/checkpoint_connector.py +23 -8
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/trainer.py +194 -76
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/_environment_info.py +21 -13
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/config/dtype.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/typing_utils.py +1 -1
- nshtrainer-0.44.0/src/nshtrainer/_callback.py +0 -42
- nshtrainer-0.44.0/src/nshtrainer/_experimental/__init__.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer-0.44.0/src/nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer-0.44.0/src/nshtrainer/config/callbacks/throughput_monitor/__init__.py +0 -33
- nshtrainer-0.44.0/src/nshtrainer/config/model/__init__.py +0 -41
- nshtrainer-0.44.0/src/nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer-0.44.0/src/nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer-0.44.0/src/nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer-0.44.0/src/nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer-0.44.0/src/nshtrainer/data/datamodule.py +0 -7
- nshtrainer-0.44.0/src/nshtrainer/ll/__init__.py +0 -59
- nshtrainer-0.44.0/src/nshtrainer/ll/_experimental.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/actsave.py +0 -6
- nshtrainer-0.44.0/src/nshtrainer/ll/callbacks.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/config.py +0 -6
- nshtrainer-0.44.0/src/nshtrainer/ll/data.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/log.py +0 -5
- nshtrainer-0.44.0/src/nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/model.py +0 -21
- nshtrainer-0.44.0/src/nshtrainer/ll/nn.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/optimizer.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/runner.py +0 -5
- nshtrainer-0.44.0/src/nshtrainer/ll/snapshot.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/snoop.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/trainer.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/typecheck.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/ll/util.py +0 -3
- nshtrainer-0.44.0/src/nshtrainer/model/__init__.py +0 -7
- nshtrainer-0.44.0/src/nshtrainer/model/base.py +0 -526
- nshtrainer-0.44.0/src/nshtrainer/model/config.py +0 -218
- nshtrainer-0.44.0/src/nshtrainer/model/mixins/logger.py +0 -166
- nshtrainer-0.44.0/src/nshtrainer/runner.py +0 -101
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/README.md +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_directory/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/_hf_hub/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/base/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/ema/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/timer/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/_base/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/csv/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/loggers/wandb/__init__.py +6 -6
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/_base/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/metrics/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/metrics/_config/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/nn/__init__.py +18 -18
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/nn/mlp/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/optimizer/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/_base/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/advanced/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/profiler/simple/__init__.py +4 -4
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/trainer/checkpoint_connector/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/_environment_info/__init__.py +20 -20
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/config/__init__.py +2 -2
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/config/dtype/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/config/util/config/duration/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/util/typed.py +0 -0
{nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.44.0
+Version: 1.0.0b9
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
-Requires-Dist: nshutils
+Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
 Requires-Dist: packaging
 Requires-Dist: psutil
```
{nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.44.0"
+version = "1.0.0-beta9"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -9,7 +9,7 @@ readme = "README.md"
 python = "^3.10"
 nshrunner = "*"
 nshconfig = "*"
-nshutils = "*"
+nshutils = { version = "*", optional = true }
 psutil = "*"
 numpy = "*"
 torch = "*"
@@ -50,4 +50,11 @@ ignore = ["F722", "F821", "E731", "E741"]
 required-imports = ["from __future__ import annotations"]
 
 [tool.poetry.extras]
-extra = [
+extra = [
+    "wrapt",
+    "GitPython",
+    "wandb",
+    "tensorboard",
+    "huggingface-hub",
+    "nshutils",
+]
```
{nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/__init__.py

```diff
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from . import _experimental as _experimental
 from . import callbacks as callbacks
-from . import config as config
 from . import data as data
 from . import lr_scheduler as lr_scheduler
 from . import metrics as metrics
@@ -12,7 +11,11 @@ from . import optimizer as optimizer
 from . import profiler as profiler
 from .data import LightningDataModuleBase as LightningDataModuleBase
 from .metrics import MetricConfig as MetricConfig
-from .model import BaseConfig as BaseConfig
 from .model import LightningModuleBase as LightningModuleBase
-from .runner import Runner as Runner
 from .trainer import Trainer as Trainer
+from .trainer import TrainerConfig as TrainerConfig
+
+try:
+    from . import config as config
+except BaseException:
+    pass
```
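Not part of the diff: a minimal migration sketch, for illustration only, based on the re-exports shown above. `BaseConfig` and `Runner` are no longer exported from the package root, `TrainerConfig` now is, and the `nshtrainer.config` subpackage is imported best-effort; the alias name below is arbitrary.

```python
# 0.44.0-style imports that disappear in 1.0.0b9:
#   from nshtrainer import BaseConfig, Runner

# 1.0.0b9-style imports, per the re-exports shown in this diff:
from nshtrainer import LightningModuleBase, Trainer, TrainerConfig

# The root package now guards its `config` re-export with try/except, so
# downstream code that relies on it may want to import it defensively too:
try:
    from nshtrainer import config as nshtrainer_config
except ImportError:
    nshtrainer_config = None
```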
nshtrainer-1.0.0b9/src/nshtrainer/_callback.py

```diff
@@ -0,0 +1,337 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import torch
+from lightning.pytorch.callbacks import Callback as _LightningCallback
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch.optim import Optimizer
+
+if TYPE_CHECKING:
+    from .model import LightningModuleBase
+    from .trainer import Trainer
+
+
+class NTCallbackBase(_LightningCallback):
+    def setup(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase, stage: str
+    ) -> None:
+        """Called when fit, validate, test, predict, or tune begins."""
+
+    def teardown(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase, stage: str
+    ) -> None:
+        """Called when fit, validate, test, predict, or tune ends."""
+
+    def on_fit_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when fit begins."""
+
+    def on_fit_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when fit ends."""
+
+    def on_sanity_check_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation sanity check starts."""
+
+    def on_sanity_check_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation sanity check ends."""
+
+    def on_train_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        """Called when the train batch begins."""
+
+    def on_train_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        """Called when the train batch ends.
+
+        Note:
+            The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
+            loss returned from ``training_step``.
+
+        """
+
+    def on_train_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the train epoch begins."""
+
+    def on_train_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the train epoch ends.
+
+        To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
+        :class:`lightning.pytorch.core.LightningModule` and access them in this hook:
+
+        .. code-block:: python
+
+            class MyLightningModule(L.LightningModule):
+                def __init__(self):
+                    super().__init__()
+                    self.training_step_outputs = []
+
+                def training_step(self):
+                    loss = ...
+                    self.training_step_outputs.append(loss)
+                    return loss
+
+
+            class MyCallback(L.Callback):
+                def on_train_epoch_end(self, trainer, pl_module):
+                    # do something with all training_step outputs, for example:
+                    epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
+                    pl_module.log("training_epoch_mean", epoch_mean)
+                    # free up the memory
+                    pl_module.training_step_outputs.clear()
+
+        """
+
+    def on_validation_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the val epoch begins."""
+
+    def on_validation_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the val epoch ends."""
+
+    def on_test_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the test epoch begins."""
+
+    def on_test_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the test epoch ends."""
+
+    def on_predict_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the predict epoch begins."""
+
+    def on_predict_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the predict epoch ends."""
+
+    def on_validation_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the validation batch begins."""
+
+    def on_validation_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the validation batch ends."""
+
+    def on_test_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the test batch begins."""
+
+    def on_test_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the test batch ends."""
+
+    def on_predict_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the predict batch begins."""
+
+    def on_predict_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Called when the predict batch ends."""
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the train begins."""
+
+    def on_train_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the train ends."""
+
+    def on_validation_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation loop begins."""
+
+    def on_validation_end(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the validation loop ends."""
+
+    def on_test_start(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the test begins."""
+
+    def on_test_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when the test ends."""
+
+    def on_predict_start(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called when the predict begins."""
+
+    def on_predict_end(self, trainer: Trainer, pl_module: LightningModuleBase) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when predict ends."""
+
+    def on_exception(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        exception: BaseException,
+    ) -> None:
+        """Called when any trainer execution is interrupted by an exception."""
+
+    def state_dict(self) -> dict[str, Any]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
+
+        Returns:
+            A dictionary containing callback state.
+
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
+        """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
+
+        Args:
+            state_dict: the callback state returned by ``state_dict``.
+
+        """
+        pass
+
+    def on_save_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        checkpoint: dict[str, Any],
+    ) -> None:
+        r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save.
+
+        Args:
+            trainer: the current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: the current :class:`~lightning.pytorch.core.LightningModule` instance.
+            checkpoint: the checkpoint dictionary that will be saved.
+
+        """
+
+    def on_load_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        checkpoint: dict[str, Any],
+    ) -> None:
+        r"""Called when loading a model checkpoint, use to reload state.
+
+        Args:
+            trainer: the current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: the current :class:`~lightning.pytorch.core.LightningModule` instance.
+            checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
+
+        """
+
+    def on_before_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase, loss: torch.Tensor
+    ) -> None:
+        """Called before ``loss.backward()``."""
+
+    def on_after_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self, trainer: Trainer, pl_module: LightningModuleBase
+    ) -> None:
+        """Called after ``loss.backward()`` and before optimizers are stepped."""
+
+    def on_before_optimizer_step(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        optimizer: Optimizer,
+    ) -> None:
+        """Called before ``optimizer.step()``."""
+
+    def on_before_zero_grad(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        trainer: Trainer,
+        pl_module: LightningModuleBase,
+        optimizer: Optimizer,
+    ) -> None:
+        """Called before ``optimizer.zero_grad()``."""
+
+    def on_checkpoint_saved(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        ckpt_path: Path,
+        metadata_path: Path | None,
+        trainer: "Trainer",
+        pl_module: "LightningModuleBase",
+    ) -> None:
+        """Called after a checkpoint is saved."""
+        pass
+
+
+def _call_on_checkpoint_saved(
+    trainer: "Trainer",
+    ckpt_path: str | Path,
+    metadata_path: str | Path | None,
+):
+    ckpt_path = Path(ckpt_path)
+    metadata_path = Path(metadata_path) if metadata_path else None
+
+    for callback in trainer.callbacks:
+        if not isinstance(callback, NTCallbackBase):
+            continue
+
+        callback.on_checkpoint_saved(
+            ckpt_path,
+            metadata_path,
+            trainer,
+            trainer._base_module,
+        )
```
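Not part of the diff: a minimal sketch, for illustration only, of hooking into the new `on_checkpoint_saved` callback API defined in `_callback.py` above. Only `NTCallbackBase`, the module path, and the hook signature come from the diff; the class name and its body are hypothetical.

```python
from pathlib import Path

from nshtrainer._callback import NTCallbackBase  # module added in 1.0.0b9


class PrintCheckpointPaths(NTCallbackBase):  # hypothetical example callback
    def on_checkpoint_saved(
        self, ckpt_path: Path, metadata_path: Path | None, trainer, pl_module
    ) -> None:
        # Invoked by _call_on_checkpoint_saved after a checkpoint (and its
        # optional ".metadata.json" sidecar) has been written.
        print(f"checkpoint saved to {ckpt_path}")
        if metadata_path is not None:
            print(f"metadata saved to {metadata_path}")
```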
{nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/loader.py

```diff
@@ -7,7 +7,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias, overload
 
 import nshconfig as C
-from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.trainer.states import TrainerFn
 from typing_extensions import assert_never
 
@@ -15,7 +14,8 @@ from ..metrics._config import MetricConfig
 from .metadata import CheckpointMetadata
 
 if TYPE_CHECKING:
-    from ..
+    from ..trainer import Trainer
+    from ..trainer._config import TrainerConfig
 
 log = logging.getLogger(__name__)
 
@@ -228,22 +228,22 @@ class _CkptCandidate:
 @overload
 def _load_ckpt_meta(
     path: Path,
-
+    trainer_config: TrainerConfig,
     on_error: Literal["warn"] = "warn",
 ) -> _CkptCandidate | None: ...
 @overload
 def _load_ckpt_meta(
     path: Path,
-
+    trainer_config: TrainerConfig,
     on_error: Literal["raise"],
 ) -> _CkptCandidate: ...
 def _load_ckpt_meta(
     path: Path,
-
+    trainer_config: TrainerConfig,
     on_error: Literal["warn", "raise"] = "warn",
 ):
     meta = CheckpointMetadata.from_file(path)
-    if
+    if trainer_config.id != meta.run_id:
         error_msg = f"Skipping checkpoint {path} because it belongs to a different run"
         match on_error:
             case "warn":
@@ -256,16 +256,13 @@
     return _CkptCandidate(meta, path)
 
 
-def _checkpoint_candidates(
-    root_config: "BaseConfig",
-    trainer: LightningTrainer,
-    *,
-    include_hpc: bool = True,
-):
+def _checkpoint_candidates(trainer: Trainer, *, include_hpc: bool = True):
     # Load the checkpoint directory, and throw if it doesn't exist.
     # This indicates a non-standard setup, and we don't want to guess
     # where the checkpoints are.
-    ckpt_dir =
+    ckpt_dir = trainer.hparams.directory.resolve_subdirectory(
+        trainer.hparams.id, "checkpoint"
+    )
     if not ckpt_dir.is_dir():
         raise FileNotFoundError(
             f"Checkpoint directory {ckpt_dir} not found. "
@@ -275,46 +272,40 @@
     # Load all checkpoints in the directory.
     # We can do this by looking for metadata files.
     for path in ckpt_dir.glob(f"*{CheckpointMetadata.PATH_SUFFIX}"):
-        if (meta := _load_ckpt_meta(path,
+        if (meta := _load_ckpt_meta(path, trainer.hparams)) is not None:
             yield meta
 
     # If we have a pre-empted checkpoint, load it
     if include_hpc and (hpc_path := trainer._checkpoint_connector._hpc_resume_path):
         hpc_meta_path = Path(hpc_path).with_suffix(CheckpointMetadata.PATH_SUFFIX)
-        if (meta := _load_ckpt_meta(hpc_meta_path,
+        if (meta := _load_ckpt_meta(hpc_meta_path, trainer.hparams)) is not None:
            yield meta
 
 
 def _additional_candidates(
-    additional_candidates: Iterable[Path],
+    additional_candidates: Iterable[Path], trainer_config: TrainerConfig
 ):
     for path in additional_candidates:
         if (
             meta := _load_ckpt_meta(
-                path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
+                path.with_suffix(CheckpointMetadata.PATH_SUFFIX), trainer_config
             )
         ) is None:
             continue
         yield meta
 
 
-def _resolve_checkpoint(
-    config: CheckpointLoadingConfig,
-    root_config: "BaseConfig",
-    trainer: LightningTrainer,
-):
+def _resolve_checkpoint(config: CheckpointLoadingConfig, trainer: Trainer):
     # We lazily load the checkpoint candidates to avoid loading them
     # if they are not needed.
     _ckpt_candidates: list[_CkptCandidate] | None = None
 
     def ckpt_candidates():
-        nonlocal _ckpt_candidates,
+        nonlocal _ckpt_candidates, trainer
 
         if _ckpt_candidates is None:
             _ckpt_candidates = list(
-                _checkpoint_candidates(
-                    root_config, trainer, include_hpc=config.include_hpc
-                )
+                _checkpoint_candidates(trainer, include_hpc=config.include_hpc)
             )
         return _ckpt_candidates
 
@@ -324,7 +315,7 @@ def _resolve_checkpoint(
         case UserProvidedPathCheckpointStrategyConfig():
             meta = _load_ckpt_meta(
                 strategy.path.with_suffix(CheckpointMetadata.PATH_SUFFIX),
-
+                trainer.hparams,
                 on_error=strategy.on_error,
             )
             if meta is None:
@@ -334,7 +325,7 @@ def _resolve_checkpoint(
             candidates = [
                 *ckpt_candidates(),
                 *_additional_candidates(
-                    strategy.additional_candidates,
+                    strategy.additional_candidates, trainer.hparams
                 ),
             ]
             if not candidates:
@@ -343,7 +334,9 @@ def _resolve_checkpoint(
                 )
                 continue
 
-            if (
+            if (
+                metric := strategy.metric or trainer.hparams.primary_metric
+            ) is None:
                 log.warning(
                     "No metric specified for `best` checkpoint strategy, "
                     "and no primary metric is set in the configuration. "
@@ -369,7 +362,7 @@ def _resolve_checkpoint(
             candidates = [
                 *ckpt_candidates(),
                 *_additional_candidates(
-                    strategy.additional_candidates,
+                    strategy.additional_candidates, trainer.hparams
                 ),
             ]
             if not candidates:
```
{nshtrainer-0.44.0 → nshtrainer-1.0.0b9}/src/nshtrainer/_checkpoint/metadata.py

```diff
@@ -5,7 +5,7 @@ import datetime
 import logging
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
 
 import nshconfig as C
 import numpy as np
@@ -15,7 +15,6 @@ from ..util._environment_info import EnvironmentConfig
 from ..util.path import compute_file_checksum, try_symlink_or_copy
 
 if TYPE_CHECKING:
-    from ..model import BaseConfig, LightningModuleBase
     from ..trainer.trainer import Trainer
 
 log = logging.getLogger(__name__)
@@ -24,6 +23,19 @@ log = logging.getLogger(__name__)
 METADATA_PATH_SUFFIX = ".metadata.json"
 
 
+def _full_hparams_dict(trainer: Trainer):
+    hparams = {}
+    hparams["trainer"] = trainer.hparams.model_dump(mode="json")
+
+    if trainer.lightning_module is not None:
+        from ..model import LightningModuleBase
+
+        if isinstance(trainer.lightning_module, LightningModuleBase):
+            hparams["model"] = trainer.lightning_module.hparams.model_dump(mode="json")
+
+    return hparams
+
+
 class CheckpointMetadata(C.Config):
     PATH_SUFFIX: ClassVar[str] = METADATA_PATH_SUFFIX
 
@@ -59,8 +71,7 @@ class CheckpointMetadata(C.Config):
 
 
 def _generate_checkpoint_metadata(
-
-    trainer: "Trainer",
+    trainer: Trainer,
     checkpoint_path: Path,
     metadata_path: Path,
 ):
@@ -84,9 +95,9 @@ def _generate_checkpoint_metadata(
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
         checkpoint_checksum=compute_file_checksum(checkpoint_path),
-        run_id=
-        name=
-        project=
+        run_id=trainer.hparams.id,
+        name=trainer.hparams.full_name,
+        project=trainer.hparams.project,
         checkpoint_timestamp=checkpoint_timestamp,
         start_timestamp=start_timestamp.datetime
         if start_timestamp is not None
@@ -95,8 +106,8 @@ def _generate_checkpoint_metadata(
         global_step=trainer.global_step,
         training_time=training_time,
         metrics=metrics,
-        environment=
-        hparams=
+        environment=trainer.hparams.environment,
+        hparams=_full_hparams_dict(trainer),
     )
 
 
@@ -104,16 +115,9 @@ def _metadata_path(checkpoint_path: Path):
     return checkpoint_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)
 
 
-def _write_checkpoint_metadata(
-    trainer: "Trainer",
-    model: "LightningModuleBase",
-    checkpoint_path: Path,
-):
-    config = cast("BaseConfig", model.config)
+def _write_checkpoint_metadata(trainer: Trainer, checkpoint_path: Path):
     metadata_path = _metadata_path(checkpoint_path)
-    metadata = _generate_checkpoint_metadata(
-        config, trainer, checkpoint_path, metadata_path
-    )
+    metadata = _generate_checkpoint_metadata(trainer, checkpoint_path, metadata_path)
 
     # Write the metadata to the checkpoint directory
     try:
```
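Not part of the diff: a small sketch, for illustration only, of reading the `.metadata.json` sidecar that `_write_checkpoint_metadata` produces, using `CheckpointMetadata.from_file` and `PATH_SUFFIX` as they appear in the loader and metadata diffs above. The checkpoint path is hypothetical.

```python
from pathlib import Path

from nshtrainer._checkpoint.metadata import CheckpointMetadata

ckpt_path = Path("checkpoints/last.ckpt")  # hypothetical checkpoint path
meta_path = ckpt_path.with_suffix(CheckpointMetadata.PATH_SUFFIX)  # ".metadata.json" sidecar
meta = CheckpointMetadata.from_file(meta_path)
print(meta.run_id, meta.global_step, meta.metrics)
```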
nshtrainer-1.0.0b9/src/nshtrainer/_experimental/__init__.py

```diff
@@ -0,0 +1 @@
+from __future__ import annotations
```