nshtrainer 1.0.0b47__tar.gz → 1.0.0b50__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/PKG-INFO +1 -1
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/pyproject.toml +1 -1
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +3 -3
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/early_stopping.py +1 -1
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/metric_validation.py +3 -3
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/datamodule.py +2 -2
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/__init__.py +0 -1
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +18 -11
- nshtrainer-1.0.0b50/src/nshtrainer/metrics/_config.py +25 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/base.py +4 -4
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/mixins/debug.py +1 -1
- nshtrainer-1.0.0b50/src/nshtrainer/model/mixins/logger.py +275 -0
- nshtrainer-1.0.0b47/src/nshtrainer/metrics/_config.py +0 -42
- nshtrainer-1.0.0b47/src/nshtrainer/model/mixins/logger.py +0 -181
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/README.md +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/lr_monitor.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/.gitattributes +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_directory/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/base.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/base.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/_config.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/accelerator.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/base.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/environment.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/io.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/plugin/precision.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
RENAMED
@@ -51,7 +51,7 @@ class BestCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
 class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):
     @property
     def _metric_name_normalized(self):
-        return self.metric.
+        return self.metric.monitor.replace("/", "_").replace(" ", "_").replace(".", "_")

     @override
     def __init__(
@@ -69,12 +69,12 @@ class BestCheckpointCallback(CheckpointBase[BestCheckpointCallbackConfig]):

     @override
     def default_filename(self):
-        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.
+        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.monitor}}}"

     @override
     def topk_sort_key(self, metadata: CheckpointMetadata):
         return metadata.metrics.get(
-            self.metric.
+            self.metric.monitor,
             float("-inf" if self.metric.mode == "max" else "inf"),
         )

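The normalization above is easier to see with a concrete value. The sketch below is illustrative only; `"val/loss"` is an arbitrary monitor key, not a package default:

```python
# Mirrors BestCheckpointCallback._metric_name_normalized and default_filename
# for a hypothetical metric whose monitor key is "val/loss".
monitor = "val/loss"
normalized = monitor.replace("/", "_").replace(" ", "_").replace(".", "_")
assert normalized == "val_loss"

# The raw monitor key stays inside the braces so Lightning can substitute the
# metric value when the checkpoint is written.
filename = f"epoch{{epoch}}-step{{step}}-{normalized}{{{monitor}}}"
assert filename == "epoch{epoch}-step{step}-val_loss{val/loss}"
```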
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/early_stopping.py
RENAMED
@@ -68,7 +68,7 @@ class EarlyStoppingCallback(_EarlyStopping):
         del config, metric

         super().__init__(
-            monitor=self.metric.
+            monitor=self.metric.monitor,
             mode=self.metric.mode,
             patience=self.config.patience,
             min_delta=self.config.min_delta,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/metric_validation.py
RENAMED
@@ -55,14 +55,14 @@ class MetricValidationCallback(Callback):
         self.metrics = metrics

     def _check_metrics(self, trainer: Trainer):
-        metric_names = ", ".join(metric.
+        metric_names = ", ".join(metric.monitor for metric in self.metrics)
         log.info(f"Validating metrics: {metric_names}...")
         logged_metrics = set(trainer.logged_metrics.keys())

         invalid_metrics: list[str] = []
         for metric in self.metrics:
-            if metric.
-                invalid_metrics.append(metric.
+            if metric.monitor not in logged_metrics:
+                invalid_metrics.append(metric.monitor)

         if invalid_metrics:
             msg = (
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/callbacks/rlp_sanity_checks.py
RENAMED
@@ -171,7 +171,7 @@ class CustomRLPImplementation(Protocol):
     __reduce_lr_on_plateau__: bool


-class
+class RLPSanityCheckModuleMixin(LightningModule):
     def reduce_lr_on_plateau_config(
         self,
         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/data/datamodule.py
RENAMED
@@ -11,13 +11,13 @@ from lightning.pytorch import LightningDataModule
 from typing_extensions import Never, TypeVar, deprecated, override

 from ..model.mixins.callback import CallbackRegistrarModuleMixin
-from ..model.mixins.debug import
+from ..model.mixins.debug import DebugModuleMixin

 THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)


 class LightningDataModuleBase(
-
+    DebugModuleMixin,
     CallbackRegistrarModuleMixin,
     LightningDataModule,
     ABC,
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
RENAMED
@@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from typing_extensions import final, override

 from ..metrics._config import MetricConfig
+from ..util.config import EpochsConfig
 from .base import LRSchedulerConfigBase, LRSchedulerMetadata, lr_scheduler_registry


@@ -21,13 +22,13 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Metric to monitor.
     If not provided, the primary metric of the runner will be used."""

-    patience: int
+    patience: int | EpochsConfig
     r"""Number of epochs with no improvement after which learning rate will be reduced."""

     factor: float
     r"""Factor by which the learning rate will be reduced. new_lr = lr * factor."""

-    cooldown: int = 0
+    cooldown: int | EpochsConfig = 0
     r"""Number of epochs to wait before resuming normal operation after lr has been reduced."""

     min_lr: float | list[float] = 0.0
@@ -49,28 +50,34 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
         if (metric := self.metric) is None:
             from ..trainer import Trainer

-            assert isinstance(
-                trainer
-            )
+            assert isinstance(trainer := lightning_module.trainer, Trainer), (
+                "The trainer must be a `nshtrainer.Trainer` instance."
+            )

-            assert (
-                metric
-            )
+            assert (metric := trainer.hparams.primary_metric) is not None, (
+                "Primary metric must be provided if metric is not specified."
+            )
+
+        if isinstance(patience := self.patience, EpochsConfig):
+            patience = int(patience.value)
+
+        if isinstance(cooldown := self.cooldown, EpochsConfig):
+            cooldown = int(cooldown.value)

         lr_scheduler = ReduceLROnPlateau(
             optimizer,
             mode=metric.mode,
             factor=self.factor,
-            patience=
+            patience=patience,
             threshold=self.threshold,
             threshold_mode=self.threshold_mode,
-            cooldown=
+            cooldown=cooldown,
             min_lr=self.min_lr,
             eps=self.eps,
         )
         return {
             "scheduler": lr_scheduler,
-            "monitor": metric.
+            "monitor": metric.monitor,
         }

     @override
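The new `int | EpochsConfig` union means `patience` and `cooldown` are normalized to plain epoch counts before they reach torch's `ReduceLROnPlateau`. The sketch below mirrors that coercion; the `EpochsConfig` dataclass here is a hypothetical stand-in exposing only the `.value` attribute used in the hunk above (the real class lives in `nshtrainer.util.config`):

```python
from dataclasses import dataclass


@dataclass
class EpochsConfig:  # stand-in for illustration; only `.value` is assumed
    value: int | float


def _as_epoch_count(patience: int | EpochsConfig) -> int:
    # Same shape as the normalization added in the hunk above:
    # duration-style configs collapse to an integer epoch count.
    if isinstance(patience, EpochsConfig):
        return int(patience.value)
    return patience


assert _as_epoch_count(5) == 5
assert _as_epoch_count(EpochsConfig(value=10.0)) == 10
```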
nshtrainer-1.0.0b50/src/nshtrainer/metrics/_config.py
ADDED
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+import builtins
+from typing import Any, Literal
+
+import nshconfig as C
+
+
+class MetricConfig(C.Config):
+    monitor: str
+    """The name of the metric to monitor."""
+
+    mode: Literal["min", "max"]
+    """
+    The mode of the primary metric:
+    - "min" for metrics that should be minimized (e.g., loss)
+    - "max" for metrics that should be maximized (e.g., accuracy)
+    """
+
+    @property
+    def best(self):
+        return builtins.min if self.mode == "min" else builtins.max
+
+    def is_better(self, a: Any, b: Any):
+        return self.best(a, b) == a
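A minimal usage sketch of the new `MetricConfig`; the monitor key `"val/loss"` is an example value chosen for illustration:

```python
from nshtrainer.metrics._config import MetricConfig

# Track a validation loss that should be minimized.
metric = MetricConfig(monitor="val/loss", mode="min")

assert metric.best is min            # `best` resolves to builtins.min for mode="min"
assert metric.is_better(0.12, 0.45)  # a lower loss beats a higher one
```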
{nshtrainer-1.0.0b47 → nshtrainer-1.0.0b50}/src/nshtrainer/model/base.py
RENAMED
@@ -15,9 +15,9 @@ from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override

-from ..callbacks.rlp_sanity_checks import
+from ..callbacks.rlp_sanity_checks import RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import
+from .mixins.debug import DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin

 log = logging.getLogger(__name__)
@@ -54,8 +54,8 @@ VALID_REDUCE_OPS = (


 class LightningModuleBase(
-
-
+    DebugModuleMixin,
+    RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
     CallbackModuleMixin,
     LightningModule,
nshtrainer-1.0.0b50/src/nshtrainer/model/mixins/logger.py
ADDED
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import dataclasses
+from collections import deque
+from collections.abc import Callable, Generator, Mapping
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torchmetrics
+from lightning.pytorch import LightningModule
+from lightning.pytorch.utilities.types import _METRIC
+from lightning_utilities.core.rank_zero import rank_zero_warn
+from typing_extensions import override
+
+from ...util.typing_utils import mixin_base_type
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _LogContextKwargs:
+    __ignore_fields__: ClassVar[set[str]] = {"prefix", "disabled"}
+
+    prefix: str | None = None
+    disabled: bool | None = None
+    prog_bar: bool | None = None
+    logger: bool | None = None
+    on_step: bool | None = None
+    on_epoch: bool | None = None
+    reduce_fx: str | Callable | None = None
+    enable_graph: bool | None = None
+    sync_dist: bool | None = None
+    sync_dist_group: Any | None = None
+    add_dataloader_idx: bool | None = None
+    batch_size: int | None = None
+    rank_zero_only: bool | None = None
+
+    def to_dict(self):
+        d = dataclasses.asdict(self)
+        for field in self.__ignore_fields__:
+            d.pop(field, None)
+
+        # Pop all None values
+        for k in list(d.keys()):
+            if d[k] is None:
+                d.pop(k)
+
+        return d
+
+
+class LoggerLightningModuleMixin(mixin_base_type(LightningModule)):
+    @override
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._logger_prefix_stack = deque[_LogContextKwargs]()
+
+    @property
+    def logging_enabled(self) -> bool:
+        # Logging is disabled in barebones mode.
+        if (trainer := self._trainer) is not None and trainer.barebones:
+            # Warn the user once that logging is disabled in barebones mode.
+            if not hasattr(self, "_barebones_logging_warned"):
+                rank_zero_warn(
+                    "Logging is disabled in barebones mode. "
+                    "This is to reduce the overhead of logging in barebones mode. "
+                    "If you want to enable logging, set `barebones=False` in the Trainer.",
+                )
+                self._barebones_logging_warned = True
+            return False
+
+        # If no loggers are registered, then logging is disabled.
+        if not self.logger:
+            return False
+
+        # Check if the topmost non-null context is disabled
+        for context in reversed(self._logger_prefix_stack):
+            if context.disabled is not None:
+                return not context.disabled
+
+        # Otherwise, logging is enabled.
+        return True
+
+    @contextmanager
+    def log_context(
+        self,
+        prefix: str | None = None,
+        disabled: bool | None = None,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> Generator[None, None, None]:
+        self._logger_prefix_stack.append(
+            _LogContextKwargs(
+                prefix=prefix,
+                disabled=disabled,
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        try:
+            yield
+        finally:
+            _ = self._logger_prefix_stack.pop()
+
+    def _make_prefix_and_kwargs_dict(self, kwargs: _LogContextKwargs):
+        prefix = "".join(c.prefix for c in self._logger_prefix_stack if c.prefix)
+
+        fn_kwargs: dict[str, Any] = {}
+        for c in self._logger_prefix_stack:
+            fn_kwargs.update(c.to_dict())
+
+        fn_kwargs.update(kwargs.to_dict())
+        return prefix, fn_kwargs
+
+    @override
+    def log(
+        self,
+        name: str,
+        value: _METRIC,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        metric_attribute: str | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> None:
+        """Log a key, value pair.
+
+        Example::
+
+            self.log('train_loss', loss)
+
+        The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.
+
+        Args:
+            name: key to log. Must be identical across all processes if using DDP or any other distributed strategy.
+            value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
+            prog_bar: if ``True`` logs to the progress bar.
+            logger: if ``True`` logs to the logger.
+            on_step: if ``True`` logs at this step. The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            on_epoch: if ``True`` logs epoch accumulated metrics. The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
+            enable_graph: if ``True``, will not auto detach the graph.
+            sync_dist: if ``True``, reduces the metric across devices. Use with care as this may lead to a significant
+                communication overhead.
+            sync_dist_group: the DDP group to sync across.
+            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
+                the name (when using multiple dataloaders). If False, user needs to give unique names for
+                each dataloader to not mix the values.
+            batch_size: Current batch_size. This will be directly inferred from the loaded batch,
+                but for some data structures you might need to explicitly provide it.
+            metric_attribute: To restore the metric state, Lightning requires the reference of the
+                :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
+            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
+                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
+                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
+                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
+
+        """
+        # If logging is disabled, then do nothing.
+        if not self.logging_enabled:
+            return
+
+        prefix, fn_kwargs = self._make_prefix_and_kwargs_dict(
+            _LogContextKwargs(
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        name = f"{prefix}{name}"
+        return super().log(name, value, metric_attribute=metric_attribute, **fn_kwargs)
+
+    def log_dict(
+        self,
+        dictionary: Mapping[str, _METRIC] | torchmetrics.MetricCollection,
+        prog_bar: bool | None = None,
+        logger: bool | None = None,
+        on_step: bool | None = None,
+        on_epoch: bool | None = None,
+        reduce_fx: str | Callable | None = None,
+        enable_graph: bool | None = None,
+        sync_dist: bool | None = None,
+        sync_dist_group: Any | None = None,
+        add_dataloader_idx: bool | None = None,
+        batch_size: int | None = None,
+        rank_zero_only: bool | None = None,
+    ) -> None:
+        """Log a dictionary of values at once.
+
+        Example::
+
+            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
+            self.log_dict(values)
+
+        Args:
+            dictionary: key value pairs.
+                Keys must be identical across all processes if using DDP or any other distributed strategy.
+                The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
+            prog_bar: if ``True`` logs to the progress base.
+            logger: if ``True`` logs to the logger.
+            on_step: if ``True`` logs at this step.
+                ``None`` auto-logs for training_step but not validation/test_step.
+                The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            on_epoch: if ``True`` logs epoch accumulated metrics.
+                ``None`` auto-logs for val/test step but not ``training_step``.
+                The default value is determined by the hook.
+                See :ref:`extensions/logging:Automatic Logging` for details.
+            reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
+            enable_graph: if ``True``, will not auto-detach the graph
+            sync_dist: if ``True``, reduces the metric across GPUs/TPUs. Use with care as this may lead to a significant
+                communication overhead.
+            sync_dist_group: the ddp group to sync across.
+            add_dataloader_idx: if ``True``, appends the index of the current dataloader to
+                the name (when using multiple). If ``False``, user needs to give unique names for
+                each dataloader to not mix values.
+            batch_size: Current batch size. This will be directly inferred from the loaded batch,
+                but some data structures might need to explicitly provide it.
+            rank_zero_only: Tells Lightning if you are calling ``self.log`` from every process (default) or only from
+                rank 0. If ``True``, you won't be able to use this metric as a monitor in callbacks
+                (e.g., early stopping). Warning: Improper use can lead to deadlocks! See
+                :ref:`Advanced Logging <visualize/logging_advanced:rank_zero_only>` for more details.
+
+        """
+
+        _, fn_kwargs = self._make_prefix_and_kwargs_dict(
+            _LogContextKwargs(
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                add_dataloader_idx=add_dataloader_idx,
+                batch_size=batch_size,
+                rank_zero_only=rank_zero_only,
+            )
+        )
+        # NOTE: Prefix will be handled by the individual log calls.
+        return super().log_dict(dictionary, **fn_kwargs)
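To show how the new mixin is intended to be driven, here is a hedged usage sketch. The module below is hypothetical; in practice `nshtrainer.model.LightningModuleBase` already includes `LoggerLightningModuleMixin` (see the `model/base.py` hunk above), and the loss computation is a placeholder:

```python
import torch
from lightning.pytorch import LightningModule

from nshtrainer.model.mixins.logger import LoggerLightningModuleMixin


class MyModule(LoggerLightningModuleMixin, LightningModule):
    def training_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # placeholder for a real loss computation

        # Every self.log call inside the context inherits the prefix and kwargs,
        # so this records the metric under the key "train/loss".
        with self.log_context(prefix="train/", on_step=True, prog_bar=True):
            self.log("loss", loss)
        return loss
```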
nshtrainer-1.0.0b47/src/nshtrainer/metrics/_config.py
REMOVED
@@ -1,42 +0,0 @@
-from __future__ import annotations
-
-import builtins
-from typing import Any, Literal
-
-import nshconfig as C
-
-
-class MetricConfig(C.Config):
-    name: str
-    """The name of the primary metric."""
-
-    mode: Literal["min", "max"]
-    """
-    The mode of the primary metric:
-    - "min" for metrics that should be minimized (e.g., loss)
-    - "max" for metrics that should be maximized (e.g., accuracy)
-    """
-
-    @property
-    def validation_monitor(self) -> str:
-        return f"val/{self.name}"
-
-    def __post_init__(self):
-        for split in ("train", "val", "test", "predict"):
-            if self.name.startswith(f"{split}/"):
-                raise ValueError(
-                    f"Primary metric name should not start with '{split}/'. "
-                    f"Just use '{self.name[len(split) + 1:]}' instead. "
-                    "The split name is automatically added depending on the context."
-                )
-
-    @classmethod
-    def loss(cls, mode: Literal["min", "max"] = "min"):
-        return cls(name="loss", mode=mode)
-
-    @property
-    def best(self):
-        return builtins.min if self.mode == "min" else builtins.max
-
-    def is_better(self, a: Any, b: Any):
-        return self.best(a, b) == a