nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/trainer/_config.py
CHANGED
@@ -1,6 +1,10 @@
 from __future__ import annotations
 
+import copy
 import logging
+import os
+import string
+import time
 from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
@@ -8,6 +12,7 @@ from typing import (
     TYPE_CHECKING,
     Annotated,
     Any,
+    ClassVar,
     Literal,
     Protocol,
     TypeAlias,
@@ -15,6 +20,7 @@ from typing import (
 )
 
 import nshconfig as C
+import numpy as np
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
 from lightning.pytorch.accelerators import Accelerator
@@ -28,6 +34,7 @@ from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypedDict, TypeVar, override
 
 from .._checkpoint.loader import CheckpointLoadingConfig
+from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
@@ -47,10 +54,10 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers.actsave import ActSaveLoggerConfig
+from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
-
-if TYPE_CHECKING:
-    from ..model.config import BaseConfig
+from ..util._environment_info import EnvironmentConfig
 
 log = logging.getLogger(__name__)
 
@@ -71,7 +78,7 @@ class LoggingConfig(CallbackConfigBase):
     log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
     """If enabled, will log the fractional epoch number to the logger."""
 
-
+    actsave_logger: ActSaveLoggerConfig | None = None
     """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
     @property
@@ -103,12 +110,12 @@ class LoggingConfig(CallbackConfigBase):
             None,
         )
 
-    def create_loggers(self,
+    def create_loggers(self, trainer_config: TrainerConfig):
         """
        Constructs and returns a list of loggers based on the provided root configuration.
 
         Args:
-
+            trainer_config (TrainerConfig): The root configuration object.
 
         Returns:
             list[Logger]: A list of constructed loggers.
@@ -123,12 +130,16 @@ class LoggingConfig(CallbackConfigBase):
         ):
             if not logger_config.enabled:
                 continue
-            if (logger := logger_config.create_logger(
+            if (logger := logger_config.create_logger(trainer_config)) is None:
                 continue
             yield logger
 
+        # If the actsave_metrics is enabled, add the ActSave logger
+        if self.actsave_logger:
+            yield self.actsave_logger.create_logger(trainer_config)
+
     @override
-    def create_callbacks(self,
+    def create_callbacks(self, trainer_config):
         if self.log_lr:
             from lightning.pytorch.callbacks import LearningRateMonitor
 
@@ -139,13 +150,13 @@ class LoggingConfig(CallbackConfigBase):
             yield LearningRateMonitor(logging_interval=logging_interval)
 
         if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(
+            yield from self.log_epoch.create_callbacks(trainer_config)
 
         for logger in self.loggers:
             if not logger or not isinstance(logger, CallbackConfigBase):
                 continue
 
-            yield from logger.create_callbacks(
+            yield from logger.create_callbacks(trainer_config)
 
 
 class GradientClippingConfig(C.Config):
@@ -172,7 +183,7 @@ class OptimizationConfig(CallbackConfigBase):
     """Gradient clipping configuration, or None to disable gradient clipping."""
 
     @override
-    def create_callbacks(self,
+    def create_callbacks(self, trainer_config):
         from ..callbacks.norm_logging import NormLoggingCallbackConfig
 
         yield from NormLoggingCallbackConfig(
@@ -180,7 +191,7 @@ class OptimizationConfig(CallbackConfigBase):
             log_grad_norm_per_param=self.log_grad_norm_per_param,
             log_param_norm=self.log_param_norm,
             log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(
+        ).create_callbacks(trainer_config)
 
 
 TPlugin = TypeVar(
@@ -274,22 +285,22 @@ class CheckpointSavingConfig(CallbackConfigBase):
         self.enabled = False
         return self
 
-    def should_save_checkpoints(self,
+    def should_save_checkpoints(self, trainer_config: TrainerConfig):
         if not self.enabled:
             return False
 
-        if
+        if trainer_config.fast_dev_run:
             return False
 
         return True
 
     @override
-    def create_callbacks(self,
-        if not self.should_save_checkpoints(
+    def create_callbacks(self, trainer_config: TrainerConfig):
+        if not self.should_save_checkpoints(trainer_config):
             return
 
         for callback_config in self.checkpoint_callbacks:
-            yield from callback_config.create_callbacks(
+            yield from callback_config.create_callbacks(trainer_config)
 
 
 class LightningTrainerKwargs(TypedDict, total=False):
@@ -541,6 +552,74 @@ class SanityCheckingConfig(C.Config):
 
 
 class TrainerConfig(C.Config):
+    # region Active Run Configuration
+    id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
+    """ID of the run."""
+    name: list[str] = []
+    """Run name in parts. Full name is constructed by joining the parts with spaces."""
+    project: str | None = None
+    """Project name."""
+    tags: list[str] = []
+    """Tags for the run."""
+    notes: list[str] = []
+    """Human readable notes for the run."""
+
+    @property
+    def full_name(self):
+        return " ".join(self.name)
+
+    debug: bool = False
+    """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
+
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = (
+        EnvironmentConfig.empty()
+    )
+    """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
+
+    directory: DirectoryConfig = DirectoryConfig()
+    """Directory configuration options."""
+
+    _rng: ClassVar[np.random.Generator | None] = None
+
+    @classmethod
+    def generate_id(cls, *, length: int = 8) -> str:
+        """
+        Generate a random ID of specified length.
+
+        """
+        if (rng := cls._rng) is None:
+            rng = np.random.default_rng()
+
+        alphabet = list(string.ascii_lowercase + string.digits)
+
+        id = "".join(rng.choice(alphabet) for _ in range(length))
+        return id
+
+    @classmethod
+    def set_seed(cls, seed: int | None = None) -> None:
+        """
+        Set the seed for the random number generator.
+
+        Args:
+            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
+
+        Returns:
+            None
+        """
+        if seed is None:
+            seed = int(time.time() * 1000)
+        log.critical(f"Seeding {cls.__name__} with seed {seed}")
+        cls._rng = np.random.default_rng(seed)
+
+    # endregion
+
+    primary_metric: MetricConfig | None = None
+    """Primary metric configuration options. This is used in the following ways:
+    - To determine the best model checkpoint to save with the ModelCheckpoint callback.
+    - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
+    - For the ReduceLROnPlateau scheduler.
+    """
+
     ckpt_path: Literal["none"] | str | Path | None = None
     """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
 
@@ -788,3 +867,71 @@ class TrainerConfig(C.Config):
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
+
+    # region Helper Methods
+    def with_fast_dev_run(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        config = copy.deepcopy(self)
+        config.fast_dev_run = value
+        return config
+
+    def with_project_root(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        config = copy.deepcopy(self)
+        config.directory.project_root = Path(project_root)
+        return config
+
+    def reset_run(
+        self,
+        *,
+        id: bool = True,
+        basic: bool = True,
+        project_root: bool = True,
+        environment: bool = True,
+    ):
+        """
+        Reset the configuration object to its initial state.
+
+        Parameters:
+        - id (bool): If True, generate a new ID for the configuration object.
+        - basic (bool): If True, reset basic attributes like name, project, tags, and notes.
+        - project_root (bool): If True, reset the directory configuration to its initial state.
+        - environment (bool): If True, reset the environment configuration to its initial state.
+        - meta (bool): If True, reset the meta dictionary to an empty dictionary.
+
+        Returns:
+        - self: The updated configuration object.
+
+        """
+        config = copy.deepcopy(self)
+
+        if id:
+            config.id = config.generate_id()
+
+        if basic:
+            config.name = []
+            config.project = None
+            config.tags = []
+            config.notes = []
+
+        if project_root:
+            config.directory = DirectoryConfig()
+
+        if environment:
+            config.environment = EnvironmentConfig.empty()
+
+        return config
+
+    # endregion
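The run-metadata fields and helper methods added to `TrainerConfig` above take over what previously lived on the removed root config (`nshtrainer/model/config.py`). Below is a minimal usage sketch, assuming `TrainerConfig` can be constructed from keyword arguments alone (it is an `nshconfig.Config`, and every field shown in this diff has a default) and that the private module path from this diff is importable; the project, tag, and path values are placeholders.

```python
# Sketch only: the import path is taken from the file shown in this diff; a public
# re-export (e.g. `from nshtrainer import TrainerConfig`) may exist but is not shown here.
from nshtrainer.trainer._config import TrainerConfig

TrainerConfig.set_seed(42)  # optional: make generate_id() reproducible

config = TrainerConfig(project="my-project", name=["baseline", "fp32"], tags=["demo"])
print(config.id)         # auto-generated 8-character lowercase/digit ID
print(config.full_name)  # "baseline fp32"

# The helper methods return modified deep copies instead of mutating in place:
debug_run = config.with_fast_dev_run(2)                   # run only 2 batches
relocated = config.with_project_root("/tmp/nshtrainer")   # sets directory.project_root
fresh = config.reset_run()                                # new id; name/tags/notes,
                                                          # directory, and environment
                                                          # reset to defaults
assert fresh.id != config.id
```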
nshtrainer/trainer/checkpoint_connector.py
CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, cast
 
 from lightning.pytorch.trainer.connectors.checkpoint_connector import (
     _CheckpointConnector as _LightningCheckpointConnector,
@@ -12,8 +11,6 @@ from typing_extensions import override
 
 from .._checkpoint.loader import CheckpointLoadingConfig, _resolve_checkpoint
 
-if TYPE_CHECKING:
-    from ..model.config import BaseConfig
 log = logging.getLogger(__name__)
 
 
@@ -32,8 +29,7 @@ class _CheckpointConnector(_LightningCheckpointConnector):
             return None
 
         # Now, resolve the checkpoint loader config.
-
-        ckpt_loader_config = root_config.trainer.checkpoint_loading
+        ckpt_loader_config = trainer.hparams.checkpoint_loading
         match ckpt_loader_config:
             case "auto":
                 ckpt_loader_config = CheckpointLoadingConfig.auto(ckpt_path, state_fn)
@@ -44,9 +40,7 @@ class _CheckpointConnector(_LightningCheckpointConnector):
         log.debug(f"Checkpoint loader config: {ckpt_loader_config}")
 
         # Use the config to resolve the checkpoint.
-        if (
-            ckpt_path := _resolve_checkpoint(ckpt_loader_config, root_config, trainer)
-        ) is None:
+        if (ckpt_path := _resolve_checkpoint(ckpt_loader_config, trainer)) is None:
             log.info(
                 "No checkpoint found for the current trainer state. "
                 "Training will start from scratch."
@@ -69,3 +63,24 @@ class _CheckpointConnector(_LightningCheckpointConnector):
         return super()._parse_ckpt_path(
             state_fn, ckpt_path, model_provided, model_connected
         )
+
+    @override
+    def dump_checkpoint(self, weights_only: bool = False):
+        checkpoint = super().dump_checkpoint(weights_only)
+
+        # Save the trainer's config.
+        _add_trainer_config_to_checkpoint_(checkpoint, self.trainer)
+
+        return checkpoint
+
+
+def _add_trainer_config_to_checkpoint_(checkpoint: dict, trainer):
+    from .trainer import Trainer
+
+    # If this isn't an `nshtrainer` trainer (which I don't know why it wouldn't be),
+    # then we just return.
+    if isinstance(trainer, Trainer):
+        return None
+
+    # Save the trainer's config.
+    checkpoint[trainer.CHECKPOINT_HYPER_PARAMS_KEY] = dict(trainer.hparams)