nshtrainer 1.4.0__tar.gz → 1.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/PKG-INFO +2 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/pyproject.toml +8 -3
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_callback.py +50 -3
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/__init__.py +1 -1
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/_base.py +2 -2
- nshtrainer-1.5.0/src/nshtrainer/callbacks/log_epoch.py +136 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/print_table.py +2 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/__init__.py +0 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/optimizer/__init__.py +0 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/__init__.py +1 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/actsave.py +7 -1
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/wandb.py +5 -5
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/base.py +1 -1
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/mixins/callback.py +0 -17
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/mixins/logger.py +1 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/module_dict.py +4 -4
- nshtrainer-1.5.0/src/nshtrainer/nn/module_list.py +52 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/nonlinearity.py +15 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/optimizer.py +2 -4
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_config.py +1 -1
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/accelerator.py +1 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/__init__.py +1 -2
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/code_upload.py +1 -1
- nshtrainer-1.4.0/src/nshtrainer/callbacks/log_epoch.py +0 -49
- nshtrainer-1.4.0/src/nshtrainer/nn/module_list.py +0 -52
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/README.md +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/distributed_prediction_writer.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/.gitattributes +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/base.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/nn/rng.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_distributed_prediction_result.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.4.0
+Version: 1.5.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -14,7 +14,7 @@ Provides-Extra: extra
 Requires-Dist: GitPython ; extra == "extra"
 Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
-Requires-Dist: nshconfig (
+Requires-Dist: nshconfig (>=0.43)
 Requires-Dist: nshrunner ; extra == "extra"
 Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/pyproject.toml

@@ -1,13 +1,13 @@
 [project]
 name = "nshtrainer"
-version = "1.4.0"
+version = "1.5.0"
 description = ""
 authors = [{ name = "Nima Shoghi", email = "nimashoghi@gmail.com" }]
 requires-python = ">=3.10,<4.0"
 readme = "README.md"

 dependencies = [
-    "nshconfig
+    "nshconfig>=0.43",
     "psutil",
     "numpy",
     "torch",
@@ -47,7 +47,12 @@ deprecateTypingAliases = true
 strictListInference = true
 strictDictionaryInference = true
 strictSetInference = true
-reportPrivateImportUsage =
+reportPrivateImportUsage = "none"
+reportMatchNotExhaustive = "warning"
+reportOverlappingOverload = "warning"
+reportUnnecessaryTypeIgnoreComment = "warning"
+reportImplicitOverride = "warning"
+reportIncompatibleMethodOverride = "information"

 [tool.ruff.lint]
 select = ["FA102", "FA100"]
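The batch of new pyright settings pairs with the `@override` decorators added throughout this release: `reportImplicitOverride` makes pyright flag any method that overrides a base-class method without declaring it, and `typing_extensions.override` is that declaration. A minimal illustration (the classes below are ours, not the package's):

    from typing_extensions import override


    class Base:
        def setup(self) -> None: ...


    class Child(Base):
        @override
        def setup(self) -> None:  # OK: the override is declared
            ...

        def helper(self) -> None:  # OK: overrides nothing, no decorator needed
            ...

`reportIncompatibleMethodOverride = "information"` likewise explains the pervasive `# pyright: ignore[reportIncompatibleMethodOverride]` comments in `_callback.py` below: its hooks narrow `trainer` from Lightning's `Trainer` to nshtrainer's own subclass, which pyright would otherwise report.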
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/_callback.py

@@ -8,38 +8,46 @@ from lightning.pytorch import LightningModule
 from lightning.pytorch.callbacks import Callback as _LightningCallback
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch.optim import Optimizer
+from typing_extensions import override

 if TYPE_CHECKING:
     from .trainer import Trainer


 class NTCallbackBase(_LightningCallback):
+    @override
     def setup(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune begins."""

+    @override
     def teardown(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, stage: str
     ) -> None:
         """Called when fit, validate, test, predict, or tune ends."""

+    @override
     def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit begins."""

+    @override
     def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when fit ends."""

+    @override
     def on_sanity_check_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check starts."""

+    @override
     def on_sanity_check_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation sanity check ends."""

+    @override
     def on_train_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -49,6 +57,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the train batch begins."""

+    @override
     def on_train_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -65,11 +74,13 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_train_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the train epoch begins."""

+    @override
     def on_train_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
@@ -81,10 +92,12 @@ class NTCallbackBase(_LightningCallback):
         .. code-block:: python

             class MyLightningModule(L.LightningModule):
+                @override
                 def __init__(self):
                     super().__init__()  # pyright: ignore[reportIncompatibleMethodOverride]
                     self.training_step_outputs = []

+                @override
                 def training_step(self):
                     loss = ...  # pyright: ignore[reportIncompatibleMethodOverride]
                     self.training_step_outputs.append(loss)
@@ -92,6 +105,7 @@ class NTCallbackBase(_LightningCallback):


             class MyCallback(L.Callback):
+                @override
                 def on_train_epoch_end(self, trainer, pl_module):
                     # do something with all training_step outputs, for example:  # pyright: ignore[reportIncompatibleMethodOverride]
                     epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
@@ -101,36 +115,43 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_validation_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch begins."""

+    @override
     def on_validation_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the val epoch ends."""

+    @override
     def on_test_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch begins."""

+    @override
     def on_test_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the test epoch ends."""

+    @override
     def on_predict_epoch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch begins."""

+    @override
     def on_predict_epoch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict epoch ends."""

+    @override
     def on_validation_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -141,6 +162,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the validation batch begins."""

+    @override
     def on_validation_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -152,6 +174,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the validation batch ends."""

+    @override
     def on_test_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -162,6 +185,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the test batch begins."""

+    @override
     def on_test_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -173,6 +197,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the test batch ends."""

+    @override
     def on_predict_batch_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -183,6 +208,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch begins."""

+    @override
     def on_predict_batch_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -194,36 +220,45 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when the predict batch ends."""

+    @override
     def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train begins."""

+    @override
     def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the train ends."""

+    @override
     def on_validation_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop begins."""

+    @override
     def on_validation_end(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the validation loop ends."""

+    @override
     def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test begins."""

+    @override
     def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when the test ends."""

+    @override
     def on_predict_start(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called when the predict begins."""

+    @override
     def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pyright: ignore[reportIncompatibleMethodOverride]
         """Called when predict ends."""

+    @override
     def on_exception(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -232,7 +267,8 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called when any trainer execution is interrupted by an exception."""

-    def state_dict(self) -> dict[str, Any]:
+    @override
+    def state_dict(self) -> dict[str, Any]:
         """Called when saving a checkpoint, implement to generate callback's ``state_dict``.

         Returns:
@@ -241,7 +277,8 @@ class NTCallbackBase(_LightningCallback):
         """
         return {}

-    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.

         Args:
@@ -250,6 +287,7 @@ class NTCallbackBase(_LightningCallback):
         """
         pass

+    @override
     def on_save_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -265,6 +303,7 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_load_checkpoint(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -280,16 +319,19 @@ class NTCallbackBase(_LightningCallback):

         """

+    @override
     def on_before_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule, loss: torch.Tensor
     ) -> None:
         """Called before ``loss.backward()``."""

+    @override
     def on_after_backward(  # pyright: ignore[reportIncompatibleMethodOverride]
         self, trainer: Trainer, pl_module: LightningModule
     ) -> None:
         """Called after ``loss.backward()`` and before optimizers are stepped."""

+    @override
     def on_before_optimizer_step(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -298,6 +340,7 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called before ``optimizer.step()``."""

+    @override
     def on_before_zero_grad(  # pyright: ignore[reportIncompatibleMethodOverride]
         self,
         trainer: Trainer,
@@ -306,7 +349,10 @@ class NTCallbackBase(_LightningCallback):
     ) -> None:
         """Called before ``optimizer.zero_grad()``."""

-    def on_checkpoint_saved(
+    # =================================================================
+    # Our own new callbacks
+    # =================================================================
+    def on_checkpoint_saved(
         self,
         ckpt_path: Path,
         metadata_path: Path | None,
@@ -317,6 +363,7 @@ class NTCallbackBase(_LightningCallback):
         pass


+@override
 def _call_on_checkpoint_saved(
     trainer: Trainer,
     ckpt_path: str | Path,
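Aside from the decorator additions, the new banner marks `on_checkpoint_saved` as an nshtrainer-specific hook that stock Lightning callbacks lack. A hedged sketch of consuming it; the subclass is hypothetical, and parameters not visible in the hunk are forwarded untyped:

    from nshtrainer._callback import NTCallbackBase


    class AnnounceCheckpoint(NTCallbackBase):
        def on_checkpoint_saved(self, ckpt_path, metadata_path, *args, **kwargs):
            # Invoked by nshtrainer after a checkpoint (and, when present,
            # its sidecar metadata file) has been written to disk.
            print(f"saved {ckpt_path} (metadata: {metadata_path})")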
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/__init__.py

@@ -75,5 +75,5 @@ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig

 CallbackConfig = TypeAliasType(
     "CallbackConfig",
-    Annotated[CallbackConfigBase, callback_registry.DynamicResolution()],
+    Annotated[CallbackConfigBase, callback_registry],
 )
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/checkpoint/_base.py

@@ -5,13 +5,13 @@ import string
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Literal

 import numpy as np
 import torch
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import override
+from typing_extensions import TypeVar, override

 from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
 from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
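Swapping `TypeVar` from `typing` to `typing_extensions` is the usual prerequisite for PEP 696 type-variable defaults: `typing.TypeVar` only accepts `default=` on Python 3.13+, while the `typing_extensions` backport covers the `>=3.10` range this package pins. A sketch of the pattern; the variable name is invented here:

    from typing_extensions import TypeVar

    # `default=` is PEP 696; typing.TypeVar rejects it before Python 3.13.
    # A Generic[TMeta] base then binds TMeta to dict when unparameterized.
    TMeta = TypeVar("TMeta", default=dict)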
nshtrainer-1.5.0/src/nshtrainer/callbacks/log_epoch.py

@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import logging
+import math
+from typing import Any, Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import final, override
+
+from .base import CallbackConfigBase, callback_registry
+
+log = logging.getLogger(__name__)
+
+
+@final
+@callback_registry.register
+class LogEpochCallbackConfig(CallbackConfigBase):
+    name: Literal["log_epoch"] = "log_epoch"
+
+    metric_name: str = "computed_epoch"
+    """The name of the metric to log the epoch as."""
+
+    train: bool = True
+    """Whether to log the epoch during training."""
+
+    val: bool = True
+    """Whether to log the epoch during validation."""
+
+    test: bool = True
+    """Whether to log the epoch during testing."""
+
+    @override
+    def create_callbacks(self, trainer_config):
+        yield LogEpochCallback(self)
+
+
+def _worker_fn(
+    trainer: Trainer,
+    pl_module: LightningModule,
+    num_batches_prop: str,
+    dataloader_idx: int | None = None,
+    *,
+    metric_name: str,
+):
+    if trainer.logger is None:
+        return
+
+    # If trainer.num_{training/val/test}_batches is not set or is nan/inf, we cannot calculate the epoch
+    if not (num_batches := getattr(trainer, num_batches_prop, None)):
+        log.warning(f"Trainer has no valid `{num_batches_prop}`. Cannot log epoch.")
+        return
+
+    # If the trainer has a dataloader_idx, num_batches is a list of num_batches for each dataloader.
+    if dataloader_idx is not None:
+        assert isinstance(num_batches, list), (
+            f"Expected num_batches to be a list, got {type(num_batches)}"
+        )
+        assert 0 <= dataloader_idx < len(num_batches), (
+            f"Expected dataloader_idx to be between 0 and {len(num_batches)}, got {dataloader_idx}"
+        )
+        num_batches = num_batches[dataloader_idx]
+
+    if (
+        not isinstance(num_batches, (int, float))
+        or math.isnan(num_batches)
+        or math.isinf(num_batches)
+    ):
+        log.warning(
+            f"Trainer has no valid `{num_batches_prop}` (got {num_batches=}). Cannot log epoch."
+        )
+        return
+
+    epoch = pl_module.global_step / num_batches
+    pl_module.log(metric_name, epoch, on_step=True, on_epoch=False)
+
+
+class LogEpochCallback(Callback):
+    def __init__(self, config: LogEpochCallbackConfig):
+        super().__init__()
+
+        self.config = config
+
+    @override
+    def on_train_batch_start(
+        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
+    ):
+        if trainer.logger is None or not self.config.train:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_training_batches",
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_validation_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.val:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_val_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
+
+    @override
+    def on_test_batch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if trainer.logger is None or not self.config.test:
+            return
+
+        _worker_fn(
+            trainer,
+            pl_module,
+            "num_test_batches",
+            dataloader_idx=dataloader_idx,
+            metric_name=self.config.metric_name,
+        )
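This replaces the 49-line 1.4.0 implementation (removed at the bottom of the file list) with a registered, config-driven callback that logs a fractional epoch as `global_step / num_batches`; at step 150 with 100 training batches per epoch, `computed_epoch` is 1.5. A usage sketch using only the API shown above; passing `None` for `trainer_config` works here because this `create_callbacks` ignores it:

    from nshtrainer.callbacks.log_epoch import LogEpochCallbackConfig

    config = LogEpochCallbackConfig(
        metric_name="computed_epoch",
        train=True,
        val=False,  # skip the metric during validation
        test=False,
    )
    # create_callbacks() is a generator of Lightning callback instances.
    callbacks = list(config.create_callbacks(None))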
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/callbacks/print_table.py

@@ -67,14 +67,14 @@ class PrintTableMetricsCallback(Callback):
         }
         self.metrics.append(metrics_dict)

-        from rich.console import Console  #
+        from rich.console import Console  # pyright: ignore[reportMissingImports] # noqa

         console = Console()
         table = self.create_metrics_table()
         console.print(table)

     def create_metrics_table(self):
-        from rich.table import Table  #
+        from rich.table import Table  # pyright: ignore[reportMissingImports] # noqa

         table = Table(show_header=True, header_style="bold magenta")
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/__init__.py

@@ -111,7 +111,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
 from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
 from nshtrainer.optimizer import RpropConfig as RpropConfig
 from nshtrainer.optimizer import SGDConfig as SGDConfig
-from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry
 from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
 from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
@@ -355,7 +354,6 @@ __all__ = [
     "TorchSyncBatchNormPlugin",
     "TrainerConfig",
     "TransformerEnginePluginConfig",
-    "Union",
     "WandbLoggerConfig",
     "WandbUploadCodeCallbackConfig",
     "WandbWatchCallbackConfig",
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/configs/optimizer/__init__.py

@@ -16,7 +16,6 @@ from nshtrainer.optimizer import RAdamConfig as RAdamConfig
 from nshtrainer.optimizer import RMSpropConfig as RMSpropConfig
 from nshtrainer.optimizer import RpropConfig as RpropConfig
 from nshtrainer.optimizer import SGDConfig as SGDConfig
-from nshtrainer.optimizer import Union as Union
 from nshtrainer.optimizer import optimizer_registry as optimizer_registry

 __all__ = [
@@ -34,6 +33,5 @@ __all__ = [
     "RMSpropConfig",
     "RpropConfig",
     "SGDConfig",
-    "Union",
     "optimizer_registry",
 ]
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/__init__.py

@@ -12,6 +12,5 @@ from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
 from .wandb import WandbLoggerConfig as WandbLoggerConfig

 LoggerConfig = TypeAliasType(
-    "LoggerConfig",
-    Annotated[LoggerConfigBase, logger_registry.DynamicResolution()],
+    "LoggerConfig", Annotated[LoggerConfigBase, logger_registry]
 )
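The same change lands here as in `callbacks/__init__.py`: the alias now carries the registry object itself as `Annotated` metadata instead of a `DynamicResolution()` wrapper, presumably relying on behavior in the new `nshconfig>=0.43` floor pinned above (an assumption on our part). Generically, `Annotated` metadata is just an arbitrary object that downstream consumers can retrieve:

    from typing import Annotated, get_type_hints


    class Registry:  # stand-in; not nshconfig's registry type
        pass


    registry = Registry()
    ConfigT = Annotated[int, registry]


    def f(x: ConfigT) -> None: ...


    hints = get_type_hints(f, include_extras=True)
    print(hints["x"].__metadata__)  # (<__main__.Registry object at ...>,)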
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/actsave.py

@@ -5,7 +5,7 @@ from typing import Any, Literal

 import numpy as np
 from lightning.pytorch.loggers import Logger
-from typing_extensions import final
+from typing_extensions import final, override

 from .base import LoggerConfigBase, logger_registry

@@ -15,6 +15,7 @@ from .base import LoggerConfigBase, logger_registry
 class ActSaveLoggerConfig(LoggerConfigBase):
     name: Literal["actsave"] = "actsave"

+    @override
     def create_logger(self, trainer_config):
         if not self.enabled:
             return None
@@ -24,10 +25,12 @@ class ActSaveLoggerConfig(LoggerConfigBase):

 class ActSaveLogger(Logger):
     @property
+    @override
     def name(self):
         return None

     @property
+    @override
     def version(self):
         from nshutils import ActSave

@@ -37,6 +40,7 @@ class ActSaveLogger(Logger):
         return ActSave._saver._id

     @property
+    @override
     def save_dir(self):
         from nshutils import ActSave

@@ -45,6 +49,7 @@ class ActSaveLogger(Logger):

         return str(ActSave._saver._save_dir)

+    @override
     def log_hyperparams(
         self,
         params: dict[str, Any] | Namespace,
@@ -56,6 +61,7 @@ class ActSaveLogger(Logger):
         # Wrap the hparams as a object-dtype np array
         return ActSave.save({"hyperparameters": np.array(params, dtype=object)})

+    @override
     def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
         from nshutils import ActSave
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/loggers/wandb.py

@@ -63,7 +63,7 @@ class FinishWandbOnTeardownCallback(Callback):
         stage: str,
     ):
         try:
-            import wandb
+            import wandb
         except ImportError:
             return

@@ -139,7 +139,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
         # If `wandb-core` is enabled, we should use the new backend.
         if self.use_wandb_core:
             try:
-                import wandb
+                import wandb

                 # The minimum version that supports the new backend is 0.17.5
                 wandb_version = version.parse(importlib.metadata.version("wandb"))
@@ -151,7 +151,7 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
                     )
                 # W&B versions 0.18.0 use wandb-core by default
                 elif wandb_version < version.parse("0.18.0"):
-                    wandb.require("core")
+                    wandb.require("core")
                     log.critical("Using the `wandb-core` backend for WandB.")
             except ImportError:
                 pass
@@ -166,9 +166,9 @@ class WandbLoggerConfig(CallbackConfigBase, LoggerConfigBase):
                 "If you want to use the new `wandb-core` backend, set `use_wandb_core=True`."
             )
             try:
-                import wandb
+                import wandb

-                wandb.require("legacy-service")
+                wandb.require("legacy-service")
             except ImportError:
                 pass
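The context lines document the backend policy: 0.17.5 is the earliest wandb release supporting `wandb-core`, and 0.18.0 makes it the default, so the explicit opt-in only matters in between. Restated as a standalone sketch; the version constants come from the comments above, and the too-old branch is our placeholder:

    import importlib.metadata

    from packaging import version

    wandb_version = version.parse(importlib.metadata.version("wandb"))
    if wandb_version >= version.parse("0.18.0"):
        pass  # wandb-core is already the default backend
    elif wandb_version >= version.parse("0.17.5"):
        import wandb

        wandb.require("core")  # opt in to the new backend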
{nshtrainer-1.4.0 → nshtrainer-1.5.0}/src/nshtrainer/lr_scheduler/base.py

@@ -81,7 +81,7 @@ class LRSchedulerConfigBase(C.Config, ABC):
             scheduler["monitor"] = metadata["monitor"]
         # - `strict`
         if scheduler.get("strict") is None and "strict" in metadata:
-            scheduler["strict"] = metadata["strict"]
+            scheduler["strict"] = metadata["strict"]

         return scheduler