nshtrainer 1.3.6__tar.gz → 1.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/PKG-INFO +1 -1
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/pyproject.toml +1 -1
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/_checkpoint/metadata.py +4 -1
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/_hf_hub.py +3 -0
- nshtrainer-1.4.0/src/nshtrainer/callbacks/checkpoint/_base.py +320 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/lr_monitor.py +9 -1
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_config.py +8 -2
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/trainer.py +10 -2
- nshtrainer-1.3.6/src/nshtrainer/callbacks/checkpoint/_base.py +0 -187
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/README.md +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/.nshconfig.generated.json +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/distributed_prediction_writer.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/metric_validation.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/wandb_upload_code.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/.gitattributes +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/_checkpoint/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/_checkpoint/metadata/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/_hf_hub/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/actsave/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/base/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/debug_flag/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/directory_setup/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/distributed_prediction_writer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/early_stopping/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/ema/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/finite_checks/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/gradient_skipping/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/log_epoch/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/lr_monitor/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/metric_validation/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/norm_logging/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/print_table/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/shared_parameters/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/timer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/callbacks/wandb_watch/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/actsave/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/base/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/csv/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/tensorboard/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/loggers/wandb/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/base/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/metrics/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/metrics/_config/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/mlp/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/nonlinearity/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/nn/rng/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/optimizer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/_base/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/advanced/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/pytorch/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/profiler/simple/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/_config/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/accelerator/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/base/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/environment/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/io/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/layer_sync/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/plugin/precision/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/strategy/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/trainer/trainer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/_environment_info/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/config/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/config/dtype/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/configs/util/config/duration/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/data/datamodule.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/loggers/actsave.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/loggers/base.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/base.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/model/mixins/callback.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/model/mixins/debug.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/nn/rng.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_distributed_prediction_result.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_log_hparams.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/accelerator.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/base.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/environment.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/io.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/layer_sync.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/plugin/precision.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/trainer/strategy.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/bf16.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/code_upload.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/config/__init__.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/config/dtype.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-1.3.6 → nshtrainer-1.4.0}/src/nshtrainer/util/typing_utils.py +0 -0
src/nshtrainer/_checkpoint/metadata.py

@@ -85,6 +85,7 @@ def _generate_checkpoint_metadata(
     trainer: Trainer,
     checkpoint_path: Path,
     metadata_path: Path,
+    compute_checksum: bool = True,
 ):
     checkpoint_timestamp = datetime.datetime.now()
     start_timestamp = trainer.start_time()
@@ -105,7 +106,9 @@ def _generate_checkpoint_metadata(
         # moving the checkpoint directory
         checkpoint_path=checkpoint_path.relative_to(metadata_path.parent),
         checkpoint_filename=checkpoint_path.name,
-        checkpoint_checksum=compute_file_checksum(checkpoint_path),
+        checkpoint_checksum=compute_file_checksum(checkpoint_path)
+        if compute_checksum
+        else "",
         run_id=trainer.hparams.id,
         name=trainer.hparams.full_name,
         project=trainer.hparams.project,
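The new `compute_checksum` flag lets callers build checkpoint metadata without hashing the checkpoint file; the rewritten `_base.py` below passes `compute_checksum=False` when it only needs a placeholder entry for top-k ranking. A minimal standalone sketch of the same conditional, assuming a SHA-256 hash (the actual `compute_file_checksum` helper is not shown in this diff, and `checkpoint_checksum` here is a made-up name):

```python
import hashlib
from pathlib import Path


def checkpoint_checksum(path: Path, compute_checksum: bool = True) -> str:
    # Skip the potentially expensive file hash when the caller only needs
    # placeholder metadata, mirroring the `if compute_checksum else ""` branch.
    if not compute_checksum:
        return ""
    return hashlib.sha256(path.read_bytes()).hexdigest()
```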
src/nshtrainer/_hf_hub.py

@@ -91,6 +91,9 @@ class HuggingFaceHubConfig(CallbackConfigBase):

     @override
     def create_callbacks(self, trainer_config):
+        if not self:
+            return
+
         # Attempt to login. If it fails, we'll log a warning or error based on the configuration.
         try:
             api = _api(self.token)
nshtrainer-1.4.0/src/nshtrainer/callbacks/checkpoint/_base.py (new file)

@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import logging
+import string
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+
+import numpy as np
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from ..._checkpoint.metadata import CheckpointMetadata, _generate_checkpoint_metadata
+from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
+from ..base import CallbackConfigBase
+
+if TYPE_CHECKING:
+    from ...trainer._config import TrainerConfig
+
+
+log = logging.getLogger(__name__)
+
+
+class _FormatDict(dict):
+    """A dictionary that returns an empty string for missing keys when formatting."""
+
+    def __missing__(self, key):
+        log.debug(
+            f"Missing format key '{key}' in checkpoint filename, using empty string"
+        )
+        return ""
+
+
+def _get_checkpoint_metadata(dirpath: Path) -> list[CheckpointMetadata]:
+    """Get all checkpoint metadata from a directory."""
+    return [
+        CheckpointMetadata.from_file(p)
+        for p in dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
+        if p.is_file() and not p.is_symlink()
+    ]
+
+
+def _sort_checkpoint_metadata(
+    metas: list[CheckpointMetadata],
+    key_fn: Callable[[CheckpointMetadata], Any],
+    reverse: bool = False,
+) -> list[CheckpointMetadata]:
+    """Sort checkpoint metadata by the given key function."""
+    return sorted(metas, key=key_fn, reverse=reverse)
+
+
+def _remove_checkpoints(
+    trainer: Trainer,
+    dirpath: Path,
+    metas_to_remove: list[CheckpointMetadata],
+) -> None:
+    """Remove checkpoint files and their metadata."""
+    for meta in metas_to_remove:
+        ckpt_path = dirpath / meta.checkpoint_filename
+        if not ckpt_path.exists():
+            log.warning(
+                f"Checkpoint file not found: {ckpt_path}\n"
+                "Skipping removal of the checkpoint metadata."
+            )
+            continue
+
+        remove_checkpoint(trainer, ckpt_path, metadata=True)
+        log.debug(f"Removed checkpoint: {ckpt_path}")
+
+
+def _update_symlink(
+    dirpath: Path,
+    symlink_path: Path | None,
+    sort_key_fn: Callable[[CheckpointMetadata], Any],
+    sort_reverse: bool,
+) -> None:
+    """Update symlink to point to the best checkpoint."""
+    if symlink_path is None:
+        return
+
+    # Get all checkpoint metadata after any removals
+    remaining_metas = _get_checkpoint_metadata(dirpath)
+
+    if remaining_metas:
+        # Sort by the key function
+        remaining_metas = _sort_checkpoint_metadata(
+            remaining_metas, sort_key_fn, sort_reverse
+        )
+
+        # Link to the best checkpoint
+        best_meta = remaining_metas[0]
+        best_filepath = dirpath / best_meta.checkpoint_filename
+        link_checkpoint(best_filepath, symlink_path, metadata=True)
+        log.debug(f"Updated symlink {symlink_path.name} -> {best_filepath.name}")
+    else:
+        log.warning(f"No checkpoints found in {dirpath} to create symlink.")
+
+
+class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str | None = None
+    """Checkpoint filename. This must not include the extension.
+    If None, the default filename will be used."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    save_symlink: bool = True
+    """Whether to create a symlink to the saved checkpoint."""
+
+    topk: int | Literal["all"] = 1
+    """The number of checkpoints to keep."""
+
+    @abstractmethod
+    def create_checkpoint(
+        self,
+        trainer_config: TrainerConfig,
+        dirpath: Path,
+    ) -> "CheckpointBase | None": ...
+
+    @override
+    def create_callbacks(self, trainer_config):
+        dirpath = Path(
+            self.dirpath
+            or trainer_config.directory.resolve_subdirectory(
+                trainer_config.id, "checkpoint"
+            )
+        )
+
+        if (callback := self.create_checkpoint(trainer_config, dirpath)) is not None:
+            yield callback
+
+
+TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
+
+
+class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
+    def __init__(self, config: TConfig, dirpath: Path):
+        super().__init__()
+
+        self.config = config
+        self.dirpath = dirpath / self.name()
+        self.dirpath.mkdir(parents=True, exist_ok=True)
+        self.symlink_dirpath = dirpath
+
+    @abstractmethod
+    def default_filename(self) -> str: ...
+
+    @abstractmethod
+    def name(self) -> str: ...
+
+    def extension(self) -> str:
+        return ".ckpt"
+
+    @abstractmethod
+    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
+
+    @abstractmethod
+    def topk_sort_reverse(self) -> bool: ...
+
+    def symlink_path(self):
+        if not self.config.save_symlink:
+            return None
+
+        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
+
+    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
+        if (filename := self.config.filename) is None:
+            filename = self.default_filename()
+
+        # Extract all field names from the format string
+        field_names = [
+            fname for _, fname, _, _ in string.Formatter().parse(filename) if fname
+        ]
+
+        # Filter current_metrics to only include keys that are in the format string
+        format_dict = {k: v for k, v in current_metrics.items() if k in field_names}
+
+        try:
+            formatted_filename = filename.format(**format_dict)
+        except KeyError as e:
+            log.warning(
+                f"Missing key {e} in {filename=} with {format_dict=}. Using default values."
+            )
+            # Provide a simple fallback for missing keys
+            formatted_filename = string.Formatter().vformat(
+                filename, (), _FormatDict(format_dict)
+            )
+
+        return self.dirpath / f"{formatted_filename}{self.extension()}"
+
+    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
+        current_metrics: dict[str, Any] = {
+            "epoch": trainer.current_epoch,
+            "step": trainer.global_step,
+        }
+
+        for name, value in trainer.callback_metrics.items():
+            match value:
+                case torch.Tensor() if value.numel() == 1:
+                    value = value.detach().cpu().item()
+                case np.ndarray() if value.size == 1:
+                    value = value.item()
+                case _:
+                    pass
+
+            current_metrics[name] = value
+
+        log.debug(
+            f"Current metrics: {current_metrics}, {trainer.callback_metrics=}, {trainer.logged_metrics=}"
+        )
+        return current_metrics
+
+    def save_checkpoints(self, trainer: Trainer):
+        log.debug(
+            f"{type(self).__name__}.save_checkpoints() called at {trainer.current_epoch=}, {trainer.global_step=}"
+        )
+        # Also print out the current stack trace for debugging
+        if log.isEnabledFor(logging.DEBUG):
+            import traceback
+
+            stack = traceback.extract_stack()
+            log.debug(f"Stack trace: {''.join(traceback.format_list(stack))}")
+
+        if self._should_skip_saving_checkpoint(trainer):
+            return
+
+        from ...trainer import Trainer as NTTrainer
+
+        if not isinstance(trainer, NTTrainer):
+            raise TypeError(
+                f"Trainer must be an instance of {NTTrainer.__name__}, "
+                f"but got {type(trainer).__name__}"
+            )
+
+        current_metrics = self.current_metrics(trainer)
+        filepath = self.resolve_checkpoint_path(current_metrics)
+
+        # Get all existing checkpoint metadata
+        existing_metas = _get_checkpoint_metadata(self.dirpath)
+
+        # Determine which checkpoints to remove
+        to_remove: list[CheckpointMetadata] = []
+        should_save = True
+
+        # Check if we should save this checkpoint
+        if (topk := self.config.topk) != "all" and len(existing_metas) >= topk:
+            # Generate hypothetical metadata for the current checkpoint
+            hypothetical_meta = _generate_checkpoint_metadata(
+                trainer=trainer,
+                checkpoint_path=filepath,
+                metadata_path=filepath.with_suffix(CheckpointMetadata.PATH_SUFFIX),
+                compute_checksum=False,
+            )
+
+            # Add the hypothetical metadata to the list and sort
+            metas = _sort_checkpoint_metadata(
+                [*existing_metas, hypothetical_meta],
+                self.topk_sort_key,
+                self.topk_sort_reverse(),
+            )
+
+            # If the hypothetical metadata is not in the top-k, skip saving
+            if hypothetical_meta not in metas[:topk]:
+                log.debug(
+                    f"Skipping checkpoint save: would not make top {topk} "
+                    f"based on {self.topk_sort_key.__name__}"
+                )
+                should_save = False
+            else:
+                # Determine which existing checkpoints to remove
+                to_remove = metas[topk:]
+                assert hypothetical_meta not in to_remove, (
+                    "Hypothetical metadata should not be in the to_remove list."
+                )
+                log.debug(
+                    f"Removing checkpoints: {[meta.checkpoint_filename for meta in to_remove]} "
+                    f"and saving the new checkpoint: {hypothetical_meta.checkpoint_filename}"
+                )
+
+        # Only save if it would make it into the top-k
+        if should_save:
+            # Save the new checkpoint
+            trainer.save_checkpoint(
+                filepath,
+                weights_only=self.config.save_weights_only,
+            )
+
+            if trainer.is_global_zero:
+                # Remove old checkpoints that should be deleted
+                if to_remove:
+                    _remove_checkpoints(trainer, self.dirpath, to_remove)
+
+                # Update the symlink to point to the best checkpoint
+                _update_symlink(
+                    self.dirpath,
+                    self.symlink_path(),
+                    self.topk_sort_key,
+                    self.topk_sort_reverse(),
+                )
+
+        # Barrier to ensure all processes have completed checkpoint operations
+        trainer.strategy.barrier()
+
+    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
+        from lightning.pytorch.trainer.states import TrainerFn
+
+        return (
+            bool(
+                getattr(trainer, "fast_dev_run", False)
+            )  # disable checkpointing with fast_dev_run
+            or trainer.state.fn
+            != TrainerFn.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
+        )
src/nshtrainer/callbacks/lr_monitor.py

@@ -1,12 +1,15 @@
 from __future__ import annotations

+import logging
 from typing import Literal

 from lightning.pytorch.callbacks import LearningRateMonitor
-from typing_extensions import final
+from typing_extensions import final, override

 from .base import CallbackConfigBase, callback_registry

+log = logging.getLogger(__name__)
+

 @final
 @callback_registry.register
@@ -28,7 +31,12 @@ class LearningRateMonitorConfig(CallbackConfigBase):
     Option to also log the weight decay values of the optimizer. Defaults to False.
     """

+    @override
     def create_callbacks(self, trainer_config):
+        if not list(trainer_config.enabled_loggers()):
+            log.warning("No loggers enabled. LearningRateMonitor will not be used.")
+            return
+
         yield LearningRateMonitor(
             logging_interval=self.logging_interval,
             log_momentum=self.log_momentum,
src/nshtrainer/trainer/_config.py

@@ -717,8 +717,9 @@ class TrainerConfig(C.Config):

     auto_set_default_root_dir: bool = True
     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
-    save_checkpoint_metadata:
-    """
+    save_checkpoint_metadata: Literal[True] = True
+    """Will save additional metadata whenever a checkpoint is saved.
+    This is a core feature of nshtrainer and cannot be disabled."""
     auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
     """If enabled, will automatically set the debug flag to True if:
     - The trainer is running in fast_dev_run mode.
@@ -1308,6 +1309,11 @@ class TrainerConfig(C.Config):
         if self.barebones and self.shared_parameters:
             raise ValueError("shared_parameters is not supported under barebones mode")

+        if not self.save_checkpoint_metadata:
+            raise ValueError(
+                "save_checkpoint_metadata must be True. This is a core feature of nshtrainer and cannot be disabled."
+            )
+
     def _nshtrainer_set_id_if_missing(self):
         """
         Set the ID for the configuration object if it is missing.
src/nshtrainer/trainer/trainer.py

@@ -45,6 +45,9 @@ patch_log_hparams_function()


 class Trainer(LightningTrainer):
+    profiler: Profiler
+    """Profiler used for profiling the training process."""
+
     CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"

     @property
@@ -469,6 +472,11 @@ class Trainer(LightningTrainer):
         weights_only: bool = False,
         storage_options: Any | None = None,
     ):
+        assert self.hparams.save_checkpoint_metadata, (
+            "Checkpoint metadata is not enabled. "
+            "Please set `hparams.save_checkpoint_metadata=True`."
+        )
+
         filepath = Path(filepath)

         if self.model is None:
@@ -476,7 +484,7 @@ class Trainer(LightningTrainer):
                 "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
                 " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
             )
-        with self.profiler.profile("save_checkpoint"):
+        with self.profiler.profile("save_checkpoint"):
             checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
             # Update the checkpoint for the trainer hyperparameters
             checkpoint[self.CHECKPOINT_HYPER_PARAMS_KEY] = self.hparams.model_dump(
@@ -489,7 +497,7 @@ class Trainer(LightningTrainer):

         # Save the checkpoint metadata
         metadata_path = None
-        if self.
+        if self.is_global_zero:
             # Generate the metadata and write to disk
             metadata_path = write_checkpoint_metadata(self, filepath)

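In the updated `save_checkpoint`, only the global-zero rank writes the metadata side-car, and the surrounding code relies on a strategy barrier so other ranks do not race ahead of it. A toy thread-based stand-in for that rank-zero-plus-barrier pattern (the names and the use of threads are illustrative, not Lightning's API):

```python
import threading

WORLD_SIZE = 2
barrier = threading.Barrier(WORLD_SIZE)
written: list[str] = []


def rank(i: int) -> None:
    if i == 0:  # "is_global_zero": a single writer for the metadata side-car
        written.append("metadata written by rank 0")
    barrier.wait()  # all ranks wait until the write has happened
    assert written  # every rank can now rely on the metadata existing


threads = [threading.Thread(target=rank, args=(i,)) for i in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(written)
```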
nshtrainer-1.3.6/src/nshtrainer/callbacks/checkpoint/_base.py (removed)

@@ -1,187 +0,0 @@
-from __future__ import annotations
-
-import logging
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, Literal
-
-import numpy as np
-import torch
-from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import Checkpoint
-from typing_extensions import TypeVar, override
-
-from ..._checkpoint.metadata import CheckpointMetadata
-from ..._checkpoint.saver import link_checkpoint, remove_checkpoint
-from ..base import CallbackConfigBase
-
-if TYPE_CHECKING:
-    from ...trainer._config import TrainerConfig
-
-
-log = logging.getLogger(__name__)
-
-
-class BaseCheckpointCallbackConfig(CallbackConfigBase, ABC):
-    dirpath: str | Path | None = None
-    """Directory path to save the checkpoint file."""
-
-    filename: str | None = None
-    """Checkpoint filename. This must not include the extension.
-    If None, the default filename will be used."""
-
-    save_weights_only: bool = False
-    """Whether to save only the model's weights or the entire model object."""
-
-    save_symlink: bool = True
-    """Whether to create a symlink to the saved checkpoint."""
-
-    topk: int | Literal["all"] = 1
-    """The number of checkpoints to keep."""
-
-    @abstractmethod
-    def create_checkpoint(
-        self,
-        trainer_config: TrainerConfig,
-        dirpath: Path,
-    ) -> "CheckpointBase | None": ...
-
-    @override
-    def create_callbacks(self, trainer_config):
-        dirpath = Path(
-            self.dirpath
-            or trainer_config.directory.resolve_subdirectory(
-                trainer_config.id, "checkpoint"
-            )
-        )
-
-        if (callback := self.create_checkpoint(trainer_config, dirpath)) is not None:
-            yield callback
-
-
-TConfig = TypeVar("TConfig", bound=BaseCheckpointCallbackConfig, infer_variance=True)
-
-
-class CheckpointBase(Checkpoint, ABC, Generic[TConfig]):
-    def __init__(self, config: TConfig, dirpath: Path):
-        super().__init__()
-
-        self.config = config
-        self.dirpath = dirpath / self.name()
-        self.dirpath.mkdir(parents=True, exist_ok=True)
-        self.symlink_dirpath = dirpath
-
-    @abstractmethod
-    def default_filename(self) -> str: ...
-
-    @abstractmethod
-    def name(self) -> str: ...
-
-    def extension(self) -> str:
-        return ".ckpt"
-
-    @abstractmethod
-    def topk_sort_key(self, metadata: CheckpointMetadata) -> Any: ...
-
-    @abstractmethod
-    def topk_sort_reverse(self) -> bool: ...
-
-    def symlink_path(self):
-        if not self.config.save_symlink:
-            return None
-
-        return self.symlink_dirpath / f"{self.name()}{self.extension()}"
-
-    def resolve_checkpoint_path(self, current_metrics: dict[str, Any]) -> Path:
-        if (filename := self.config.filename) is None:
-            filename = self.default_filename()
-        filename = filename.format(**current_metrics)
-        return self.dirpath / f"{filename}{self.extension()}"
-
-    def remove_old_checkpoints(self, trainer: Trainer):
-        if (topk := self.config.topk) == "all":
-            return
-
-        # Get all the checkpoint metadata
-        metas = [
-            CheckpointMetadata.from_file(p)
-            for p in self.dirpath.glob(f"*{CheckpointMetadata.PATH_SUFFIX}")
-            if p.is_file() and not p.is_symlink()
-        ]
-
-        # Sort by the topk sort key
-        metas = sorted(metas, key=self.topk_sort_key, reverse=self.topk_sort_reverse())
-
-        # Now, the metas are sorted from the best to the worst,
-        # so we can remove the worst checkpoints
-        for meta in metas[topk:]:
-            if not (old_ckpt_path := self.dirpath / meta.checkpoint_filename).exists():
-                log.warning(
-                    f"Checkpoint file not found: {old_ckpt_path}\n"
-                    "Skipping removal of the checkpoint metadata."
-                )
-                continue
-
-            remove_checkpoint(trainer, old_ckpt_path, metadata=True)
-            log.debug(f"Removed old checkpoint: {old_ckpt_path}")
-
-    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
-        current_metrics: dict[str, Any] = {
-            "epoch": trainer.current_epoch,
-            "step": trainer.global_step,
-        }
-
-        for name, value in trainer.callback_metrics.items():
-            match value:
-                case torch.Tensor() if value.numel() == 1:
-                    value = value.detach().cpu().item()
-                case np.ndarray() if value.size == 1:
-                    value = value.item()
-                case _:
-                    pass
-
-            current_metrics[name] = value
-
-        return current_metrics
-
-    def save_checkpoints(self, trainer: Trainer):
-        if self._should_skip_saving_checkpoint(trainer):
-            return
-
-        from ...trainer import Trainer as NTTrainer
-
-        if not isinstance(trainer, NTTrainer):
-            raise TypeError(
-                f"Trainer must be an instance of {NTTrainer.__name__}, "
-                f"but got {type(trainer).__name__}"
-            )
-
-        # Save the new checkpoint
-        filepath = self.resolve_checkpoint_path(self.current_metrics(trainer))
-        trainer.save_checkpoint(filepath, self.config.save_weights_only)
-
-        if trainer.hparams.save_checkpoint_metadata and trainer.is_global_zero:
-            # Remove old checkpoints
-            self.remove_old_checkpoints(trainer)
-
-            # Create the latest symlink
-            if (symlink_filename := self.symlink_path()) is not None:
-                symlink_path = self.dirpath / symlink_filename
-                link_checkpoint(filepath, symlink_path, metadata=True)
-                log.debug(f"Created latest symlink: {symlink_path}")
-
-        # Barrier to ensure all processes have saved the checkpoint,
-        # deleted the old checkpoints, and created the symlink before continuing
-        trainer.strategy.barrier()
-
-    def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
-        from lightning.pytorch.trainer.states import TrainerFn
-
-        return (
-            bool(
-                getattr(trainer, "fast_dev_run", False)
-            )  # disable checkpointing with fast_dev_run
-            or trainer.state.fn
-            != TrainerFn.FITTING  # don't save anything during non-fit
-            or trainer.sanity_checking  # don't save anything during sanity check
-        )
|