nshtrainer 0.44.0__py3-none-any.whl → 1.0.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +51 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/_base.py +2 -1
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b9.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.0.dist-info/RECORD +0 -162
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/WHEEL +0 -0
nshtrainer/trainer/trainer.py
CHANGED
@@ -2,90 +2,120 @@ from __future__ import annotations
 
 import logging
 import os
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, cast
+from typing import IO, TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
+from lightning.fabric.utilities.cloud_io import _load as pl_load
+from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.core.saving import (
+    _default_map_location,
+    load_hparams_from_tags_csv,
+    load_hparams_from_yaml,
+)
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities.migration import pl_legacy_patch
+from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
-from typing_extensions import Unpack, assert_never, override
+from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..util._environment_info import EnvironmentConfig
 from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import (
     AcceleratorConfigProtocol,
     LightningTrainerKwargs,
     StrategyConfigProtocol,
+    TrainerConfig,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
 from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector
 
-if TYPE_CHECKING:
-    from ..model.config import BaseConfig
-
 log = logging.getLogger(__name__)
 
 
 class Trainer(LightningTrainer):
+    CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"
+
+    @property
+    def hparams(self) -> TrainerConfig:
+        """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For
+        the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
+
+        Returns:
+            Mutable hyperparameters dictionary
+
+        """
+        return self._hparams
+
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
+
+    @classmethod
+    def hparams_cls(cls):
+        return TrainerConfig
+
     @classmethod
-    def _pre_init(cls,
-        if (precision :=
+    def _pre_init(cls, hparams: TrainerConfig):
+        if (precision := hparams.set_float32_matmul_precision) is not None:
             torch.set_float32_matmul_precision(precision)
 
     @classmethod
     def _update_kwargs(
         cls,
-
+        hparams: TrainerConfig,
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
-            "deterministic":
-            "fast_dev_run":
-            "max_epochs":
-            "min_epochs":
-            "max_steps":
-            "min_steps":
-            "max_time":
-            "limit_train_batches":
-            "limit_val_batches":
-            "limit_test_batches":
-            "limit_predict_batches":
-            "overfit_batches":
-            "val_check_interval":
-            "num_sanity_val_steps":
-            "log_every_n_steps":
-            "inference_mode":
+            "deterministic": hparams.reproducibility.deterministic,
+            "fast_dev_run": hparams.fast_dev_run,
+            "max_epochs": hparams.max_epochs,
+            "min_epochs": hparams.min_epochs,
+            "max_steps": hparams.max_steps,
+            "min_steps": hparams.min_steps,
+            "max_time": hparams.max_time,
+            "limit_train_batches": hparams.limit_train_batches,
+            "limit_val_batches": hparams.limit_val_batches,
+            "limit_test_batches": hparams.limit_test_batches,
+            "limit_predict_batches": hparams.limit_predict_batches,
+            "overfit_batches": hparams.overfit_batches,
+            "val_check_interval": hparams.val_check_interval,
+            "num_sanity_val_steps": hparams.num_sanity_val_steps,
+            "log_every_n_steps": hparams.log_every_n_steps,
+            "inference_mode": hparams.inference_mode,
             "callbacks": [],
             "plugins": [],
             "logger": [],
             # Moved to `lightning_kwargs`:
-            # "enable_checkpointing":
-            # "accelerator":
-            # "strategy":
-            # "num_nodes":
-            # "precision":
-            # "logger":
-            # "log_every_n_steps":
-            # "enable_progress_bar":
-            # "enable_model_summary":
-            # "accumulate_grad_batches":
-            # "benchmark":
-            # "use_distributed_sampler":
-            # "detect_anomaly":
-            # "barebones":
-            # "plugins":
-            # "sync_batchnorm":
-            # "reload_dataloaders_every_n_epochs":
+            # "enable_checkpointing": hparams.enable_checkpointing,
+            # "accelerator": hparams.accelerator,
+            # "strategy": hparams.strategy,
+            # "num_nodes": hparams.num_nodes,
+            # "precision": hparams.precision,
+            # "logger": hparams.logging.enabled,
+            # "log_every_n_steps": hparams.log_every_n_steps,
+            # "enable_progress_bar": hparams.enable_progress_bar,
+            # "enable_model_summary": hparams.enable_model_summary,
+            # "accumulate_grad_batches": hparams.accumulate_grad_batches,
+            # "benchmark": hparams.benchmark,
+            # "use_distributed_sampler": hparams.use_distributed_sampler,
+            # "detect_anomaly": hparams.detect_anomaly,
+            # "barebones": hparams.barebones,
+            # "plugins": hparams.plugins,
+            # "sync_batchnorm": hparams.sync_batchnorm,
+            # "reload_dataloaders_every_n_epochs": hparams.reload_dataloaders_every_n_epochs,
         }
 
         def _update_key(key: str, new_value: Any):
@@ -115,20 +145,22 @@ class Trainer(LightningTrainer):
                 _update_key(key, value)
 
        # Set `default_root_dir` if `auto_set_default_root_dir` is enabled.
-        if
+        if hparams.auto_set_default_root_dir:
            if kwargs.get("default_root_dir"):
                raise ValueError(
-                    "You have set `
+                    "You have set `hparams.default_root_dir`. "
                    "But we are trying to set it automatically. "
-                    "Please use `
-                    "If you want to set it manually, please set `
+                    "Please use `hparams.directory.base` rather than `hparams.default_root_dir`. "
+                    "If you want to set it manually, please set `hparams.auto_set_default_root_dir=False`."
                )
 
            _update_kwargs(
-                default_root_dir=
+                default_root_dir=hparams.directory.resolve_run_root_directory(
+                    hparams.id
+                )
            )
 
-        if (devices_input :=
+        if (devices_input := hparams.devices) is not None:
            match devices_input:
                case "all":
                    devices = -1
@@ -141,22 +173,20 @@ class Trainer(LightningTrainer):
 
        _update_kwargs(devices=devices)
 
-        if (
-            use_distributed_sampler := config.trainer.use_distributed_sampler
-        ) is not None:
+        if (use_distributed_sampler := hparams.use_distributed_sampler) is not None:
            _update_kwargs(use_distributed_sampler=use_distributed_sampler)
 
-        if (accelerator :=
+        if (accelerator := hparams.accelerator) is not None:
            if isinstance(accelerator, AcceleratorConfigProtocol):
                accelerator = accelerator.create_accelerator()
            _update_kwargs(accelerator=accelerator)
 
-        if (strategy :=
+        if (strategy := hparams.strategy) is not None:
            if isinstance(strategy, StrategyConfigProtocol):
                strategy = strategy.create_strategy()
            _update_kwargs(strategy=strategy)
 
-        if (precision :=
+        if (precision := hparams.precision) is not None:
            resolved_precision: _PRECISION_INPUT
            match precision:
                case "64-true" | "32-true" | "bf16-mixed":
@@ -184,11 +214,11 @@ class Trainer(LightningTrainer):
 
            _update_kwargs(precision=resolved_precision)
 
-        if (detect_anomaly :=
+        if (detect_anomaly := hparams.detect_anomaly) is not None:
            _update_kwargs(detect_anomaly=detect_anomaly)
 
        if (
-            grad_clip_config :=
+            grad_clip_config := hparams.optimizer.gradient_clipping
        ) is not None and grad_clip_config.enabled:
            # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
            # kwargs["gradient_clip_val"] = grad_clip_config.value
@@ -197,9 +227,9 @@ class Trainer(LightningTrainer):
                gradient_clip_val=grad_clip_config.value,
            )
 
-        if profiler_config :=
-            if (profiler := profiler_config.create_profiler(
-                log.warning(f"Profiler
+        if profiler_config := hparams.profiler:
+            if (profiler := profiler_config.create_profiler(hparams)) is None:
+                log.warning(f"Profiler hparams {profiler_config=} returned None.")
            # Make sure that the profiler is an instance of `Profiler`.
            elif not isinstance(profiler, Profiler):
                raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
@@ -208,23 +238,29 @@ class Trainer(LightningTrainer):
            else:
                _update_kwargs(profiler=profiler)
 
-        if callbacks := resolve_all_callbacks(
+        if callbacks := resolve_all_callbacks(hparams):
            _update_kwargs(callbacks=callbacks)
 
-        if plugin_configs :=
+        if plugin_configs := hparams.plugins:
            _update_kwargs(
                plugins=[
                    plugin_config.create_plugin() for plugin_config in plugin_configs
                ]
            )
 
-        if not
-            log.critical(f"Disabling logger because {
+        if not hparams.logging.enabled:
+            log.critical(f"Disabling logger because {hparams.logging.enabled=}.")
            kwargs["logger"] = False
        else:
-            _update_kwargs(
+            _update_kwargs(
+                logger=[
+                    logger
+                    for logger in hparams.logging.create_loggers(hparams)
+                    if logger is not None
+                ]
+            )
 
-        if
+        if hparams.auto_determine_num_nodes:
            # When num_nodes is auto, we need to detect the number of nodes.
            if SLURMEnvironment.detect():
                if (num_nodes := os.environ.get("SLURM_NNODES")) is not None:
@@ -243,12 +279,12 @@ class Trainer(LightningTrainer):
                    _update_kwargs(num_nodes=num_nodes)
            else:
                log.info(
-                    "
+                    "hparams.auto_determine_num_nodes ignored because no SLURM or LSF detected."
                )
 
        # Update the kwargs with the additional trainer kwargs
-        _update_kwargs(**cast(Any,
-        _update_kwargs(**
+        _update_kwargs(**cast(Any, hparams.additional_lightning_kwargs))
+        _update_kwargs(**hparams.lightning_kwargs)
        _update_kwargs(**kwargs_ctor)
 
        return kwargs
@@ -259,15 +295,29 @@ class Trainer(LightningTrainer):
    @override
    def __init__(
        self,
-
+        hparams: TrainerConfig | Mapping[str, Any],
        /,
        **kwargs: Unpack[LightningTrainerKwargs],
    ):
-
+        # Validate the hparams.
+        hparams_cls = Trainer.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise ValueError(
+                f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
+                f"Got {type(hparams)=} instead."
+            )
+        hparams = hparams.model_deep_validate()
+
+        self._pre_init(hparams)
 
-        kwargs = self._update_kwargs(
+        kwargs = self._update_kwargs(hparams, kwargs)
        log.critical(f"LightningTrainer.__init__ with {kwargs=}.")
 
+        self._hparams = hparams
+        self.debug = self.hparams.debug
+
        super().__init__(**kwargs)
 
        # Add our own start time callback to measure the start time.
@@ -285,7 +335,7 @@ class Trainer(LightningTrainer):
        log.critical(f"LightningTrainer log directory: {self.log_dir}.")
 
        # Set the checkpoint
-        if (ckpt_path :=
+        if (ckpt_path := hparams.ckpt_path) is not None:
            self.ckpt_path = str(Path(ckpt_path).resolve().absolute())
 
    def __runtime_tracker(self):
@@ -372,7 +422,16 @@ class Trainer(LightningTrainer):
        We patch the `Trainer._run` method to throw if gradient clipping is enabled
        and `model.automatic_optimization` is False.
        """
+        # Save the current environment information
+        datamodule = getattr(self, "datamodule", None)
+        self.hparams.environment = EnvironmentConfig.from_current_environment(
+            self.hparams, model, datamodule
+        )
 
+        # If gradient clipping is enabled, then we need to make sure that
+        # `model.automatic_optimization` is enabled. Otherwise, gradient clipping
+        # is not actually going to do anything, as we expect the user to manually
+        # call `optimizer.step()` and `optimizer.zero_grad()`.
        if not model.automatic_optimization and (
            self.gradient_clip_val is not None
            or self.gradient_clip_algorithm is not None
@@ -401,12 +460,10 @@ class Trainer(LightningTrainer):
 
        # Save the checkpoint metadata
        metadata_path = None
-
-        root_config = cast("BaseConfig", lm.hparams)
-        if root_config.trainer.save_checkpoint_metadata and self.is_global_zero:
+        if self.hparams.save_checkpoint_metadata and self.is_global_zero:
            # Generate the metadata and write to disk
            if (
-                metadata_path := _write_checkpoint_metadata(self,
+                metadata_path := _write_checkpoint_metadata(self, filepath)
            ) is not None:
                written_files.append(metadata_path)
 
@@ -414,3 +471,64 @@ class Trainer(LightningTrainer):
        from .. import _callback
 
        _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
+
+    @classmethod
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: _PATH | IO,
+        map_location: _MAP_LOCATION_TYPE = None,
+        hparams_file: _PATH | None = None,
+        **kwargs: Any,
+    ):
+        loaded = _load_from_checkpoint(
+            checkpoint_path,
+            map_location=map_location,
+            hparams_file=hparams_file,
+            **kwargs,
+        )
+        return loaded
+
+
+def _load_from_checkpoint(
+    checkpoint_path: _PATH | IO,
+    map_location: _MAP_LOCATION_TYPE = None,
+    hparams_file: _PATH | None = None,
+    **kwargs: Any,
+):
+    map_location = map_location or _default_map_location
+    with pl_legacy_patch():
+        checkpoint = pl_load(checkpoint_path, map_location=map_location)
+
+    # convert legacy checkpoints to the new format
+    checkpoint = _pl_migrate_checkpoint(
+        checkpoint,
+        checkpoint_path=(
+            checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None
+        ),
+    )
+
+    if hparams_file is not None:
+        extension = str(hparams_file).split(".")[-1]
+        if extension.lower() == "csv":
+            hparams = load_hparams_from_tags_csv(hparams_file)
+        elif extension.lower() in ("yml", "yaml"):
+            hparams = load_hparams_from_yaml(hparams_file)
+        else:
+            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
+
+        # overwrite hparams by the given file
+        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
+
+    # for past checkpoint need to add the new key
+    checkpoint.setdefault(Trainer.CHECKPOINT_HYPER_PARAMS_KEY, {})
+    # override the hparams with values that were passed in
+    checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
+
+    # load the hparams
+    hparams = Trainer.hparams_cls().model_validate(
+        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY]
+    )
+
+    # create the trainer
+    trainer = Trainer(hparams)
+    return trainer
nshtrainer/util/_environment_info.py
CHANGED
@@ -15,14 +15,14 @@ from typing import TYPE_CHECKING, Any, cast
 import nshconfig as C
 import psutil
 import torch
+from lightning.pytorch import LightningDataModule, LightningModule
 from packaging import version
 from typing_extensions import Self
 
 from .slurm import parse_slurm_node_list
 
 if TYPE_CHECKING:
-    from ..
-    from ..model.config import BaseConfig
+    from ..trainer._config import TrainerConfig
 
 
 log = logging.getLogger(__name__)
@@ -708,6 +708,9 @@ class EnvironmentConfig(C.Config):
    model: EnvironmentClassInformationConfig | None = None
    """The Lightning module class information."""
 
+    datamodule: EnvironmentClassInformationConfig | None = None
+    """The Lightning data module class information."""
+
    linux: EnvironmentLinuxEnvironmentConfig | None = None
    """The Linux environment information."""
 
@@ -768,8 +771,9 @@ class EnvironmentConfig(C.Config):
    @classmethod
    def from_current_environment(
        cls,
-
-        model:
+        trainer_config: TrainerConfig,
+        model: LightningModule,
+        datamodule: LightningDataModule | None = None,
    ):
        draft = cls.draft()
        draft.cwd = Path(os.getcwd())
@@ -777,23 +781,27 @@ class EnvironmentConfig(C.Config):
        draft.python_path = [Path(path) for path in sys.path]
        draft.python_version = sys.version
        draft.python_packages = EnvironmentPackageConfig.from_current_environment()
-        draft.config = EnvironmentClassInformationConfig.from_instance(
+        draft.config = EnvironmentClassInformationConfig.from_instance(trainer_config)
        draft.model = EnvironmentClassInformationConfig.from_instance(model)
+        if datamodule is not None:
+            draft.datamodule = EnvironmentClassInformationConfig.from_instance(
+                datamodule
+            )
        draft.linux = EnvironmentLinuxEnvironmentConfig.from_current_environment()
        draft.hardware = EnvironmentHardwareConfig.from_current_environment()
        draft.slurm = EnvironmentSLURMInformationConfig.from_current_environment()
        draft.lsf = EnvironmentLSFInformationConfig.from_current_environment()
-        draft.base_dir =
-
+        draft.base_dir = trainer_config.directory.resolve_run_root_directory(
+            trainer_config.id
        )
-        draft.log_dir =
-
+        draft.log_dir = trainer_config.directory.resolve_subdirectory(
+            trainer_config.id, "log"
        )
-        draft.checkpoint_dir =
-
+        draft.checkpoint_dir = trainer_config.directory.resolve_subdirectory(
+            trainer_config.id, "checkpoint"
        )
-        draft.stdio_dir =
-
+        draft.stdio_dir = trainer_config.directory.resolve_subdirectory(
+            trainer_config.id, "stdio"
        )
        draft.seed = (
            int(seed_str) if (seed_str := os.environ.get("PL_GLOBAL_SEED")) else None
nshtrainer/util/config/dtype.py
CHANGED
@@ -9,7 +9,7 @@ from typing_extensions import assert_never
 from ..bf16 import is_bf16_supported_no_emulation
 
 if TYPE_CHECKING:
-    from ...
+    from ...trainer._config import TrainerConfig
 
 DTypeName: TypeAlias = Literal[
    "float32",
@@ -59,8 +59,8 @@ class DTypeConfig(C.Config):
    """The name of the dtype."""
 
    @classmethod
-    def
-        if (precision :=
+    def from_trainer_config(cls, trainer_config: TrainerConfig):
+        if (precision := trainer_config.precision) is None:
            precision = "32-true"
 
        match precision:
@@ -79,7 +79,7 @@ class DTypeConfig(C.Config):
            case "64-true":
                return cls(name="float64")
            case _:
-                assert_never(
+                assert_never(trainer_config.precision)
 
    @property
    def torch_dtype(self):
nshtrainer/util/typing_utils.py
CHANGED
{nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.44.0
+Version: 1.0.0b9
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
 Requires-Dist: nshconfig
 Requires-Dist: nshrunner
-Requires-Dist: nshutils
+Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
 Requires-Dist: packaging
 Requires-Dist: psutil