nshtrainer 1.0.0b11__py3-none-any.whl → 1.0.0b13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/callbacks/lr_monitor.py +31 -0
- nshtrainer/configs/__init__.py +5 -13
- nshtrainer/configs/callbacks/__init__.py +8 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +31 -0
- nshtrainer/configs/trainer/__init__.py +19 -15
- nshtrainer/configs/trainer/_config/__init__.py +19 -15
- nshtrainer/data/datamodule.py +68 -1
- nshtrainer/model/base.py +100 -2
- nshtrainer/trainer/_config.py +95 -147
- nshtrainer/trainer/trainer.py +48 -76
- {nshtrainer-1.0.0b11.dist-info → nshtrainer-1.0.0b13.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b11.dist-info → nshtrainer-1.0.0b13.dist-info}/RECORD +13 -12
- nshtrainer/scripts/find_packages.py +0 -52
- {nshtrainer-1.0.0b11.dist-info → nshtrainer-1.0.0b13.dist-info}/WHEEL +0 -0

nshtrainer/callbacks/lr_monitor.py ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from lightning.pytorch.callbacks import LearningRateMonitor
+
+from .base import CallbackConfigBase
+
+
+class LearningRateMonitorConfig(CallbackConfigBase):
+    logging_interval: Literal["step", "epoch"] | None = None
+    """
+    Set to 'epoch' or 'step' to log 'lr' of all optimizers at the same interval, set to None to log at individual interval according to the 'interval' key of each scheduler. Defaults to None.
+    """
+
+    log_momentum: bool = False
+    """
+    Option to also log the momentum values of the optimizer, if the optimizer has the 'momentum' or 'betas' attribute. Defaults to False.
+    """
+
+    log_weight_decay: bool = False
+    """
+    Option to also log the weight decay values of the optimizer. Defaults to False.
+    """
+
+    def create_callbacks(self, trainer_config):
+        yield LearningRateMonitor(
+            logging_interval=self.logging_interval,
+            log_momentum=self.log_momentum,
+            log_weight_decay=self.log_weight_decay,
+        )
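
For orientation, a minimal sketch of the new callback config used in isolation. Passing trainer_config=None is for illustration only; in practice the trainer builds this callback from the new TrainerConfig.lr_monitor field shown in the nshtrainer/trainer/_config.py diff below, and this particular implementation never inspects its trainer_config argument.

    from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig

    # Log the learning rate every step and also track optimizer momentum.
    config = LearningRateMonitorConfig(logging_interval="step", log_momentum=True)

    # create_callbacks yields a lightning.pytorch LearningRateMonitor instance.
    callbacks = list(config.create_callbacks(trainer_config=None))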

nshtrainer/configs/__init__.py CHANGED
@@ -132,10 +132,8 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
     from nshtrainer.trainer._config import (
-
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
     )
     from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
     from nshtrainer.util._environment_info import (
@@ -325,6 +323,10 @@ else:
             ).LastCheckpointStrategyConfig
         if name == "LeakyReLUNonlinearityConfig":
             return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LinearWarmupCosineDecayLRSchedulerConfig":
             return importlib.import_module(
                 "nshtrainer.lr_scheduler"
@@ -333,8 +335,6 @@ else:
             return importlib.import_module(
                 "nshtrainer.callbacks"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MLPConfig":
             return importlib.import_module("nshtrainer.nn").MLPConfig
         if name == "MetricConfig":
@@ -349,10 +349,6 @@ else:
            return importlib.import_module(
                "nshtrainer.callbacks"
            ).OnExceptionCheckpointCallbackConfig
-        if name == "OptimizationConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).OptimizationConfig
         if name == "OptimizerConfigBase":
             return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
         if name == "PReLUConfig":
@@ -373,10 +369,6 @@ else:
            return importlib.import_module(
                "nshtrainer.lr_scheduler"
            ).ReduceLROnPlateauConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"

nshtrainer/configs/callbacks/__init__.py CHANGED
@@ -62,6 +62,9 @@ if TYPE_CHECKING:
         CheckpointMetadata as CheckpointMetadata,
     )
     from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
 else:
 
     def __getattr__(name):
@@ -115,6 +118,10 @@ else:
            return importlib.import_module(
                "nshtrainer.callbacks"
            ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.callbacks"
@@ -167,6 +174,7 @@ from . import ema as ema
 from . import finite_checks as finite_checks
 from . import gradient_skipping as gradient_skipping
 from . import log_epoch as log_epoch
+from . import lr_monitor as lr_monitor
 from . import norm_logging as norm_logging
 from . import print_table as print_table
 from . import rlp_sanity_checks as rlp_sanity_checks

nshtrainer/configs/callbacks/lr_monitor/__init__.py ADDED
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfigBase
+    from nshtrainer.callbacks.lr_monitor import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "CallbackConfigBase":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).CallbackConfigBase
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.callbacks.lr_monitor"
+            ).LearningRateMonitorConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports
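
This generated module follows the PEP 562 module-level __getattr__ pattern used by the other nshtrainer.configs packages: the real class is only imported on first attribute access. A rough sketch of the behaviour, assuming nothing else in the import chain pulls in the callback module eagerly:

    import nshtrainer.configs.callbacks.lr_monitor as lr_monitor_configs

    # Attribute access goes through __getattr__, which lazily imports
    # nshtrainer.callbacks.lr_monitor and returns the class defined there.
    ConfigCls = lr_monitor_configs.LearningRateMonitorConfig
    print(ConfigCls.__module__)  # nshtrainer.callbacks.lr_monitor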

nshtrainer/configs/trainer/__init__.py CHANGED
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from nshtrainer.trainer import TrainerConfig as TrainerConfig
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -39,20 +40,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
            return importlib.import_module(
                "nshtrainer.trainer._config"
            ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
            return importlib.import_module(
                "nshtrainer.trainer._config"
            ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"

nshtrainer/configs/trainer/_config/__init__.py CHANGED
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
+    from nshtrainer.trainer._config import BaseLoggerConfig as BaseLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -38,20 +39,21 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LearningRateMonitorConfig as LearningRateMonitorConfig,
+    )
     from nshtrainer.trainer._config import (
         LogEpochCallbackConfig as LogEpochCallbackConfig,
     )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
-    from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
     from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
-
+        NormLoggingCallbackConfig as NormLoggingCallbackConfig,
     )
-    from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
-    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
-
+        OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
     from nshtrainer.trainer._config import (
         RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
     )
@@ -75,6 +77,10 @@ else:
            return importlib.import_module(
                "nshtrainer.trainer._config"
            ).ActSaveLoggerConfig
+        if name == "BaseLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).BaseLoggerConfig
         if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
@@ -119,30 +125,28 @@ else:
            return importlib.import_module(
                "nshtrainer.trainer._config"
            ).LastCheckpointCallbackConfig
+        if name == "LearningRateMonitorConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).LearningRateMonitorConfig
         if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).LogEpochCallbackConfig
-        if name == "LoggingConfig":
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
         if name == "MetricConfig":
             return importlib.import_module("nshtrainer.trainer._config").MetricConfig
-        if name == "
+        if name == "NormLoggingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).NormLoggingCallbackConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).OnExceptionCheckpointCallbackConfig
         if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).RLPSanityChecksCallbackConfig
-        if name == "ReproducibilityConfig":
-            return importlib.import_module(
-                "nshtrainer.trainer._config"
-            ).ReproducibilityConfig
         if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"

nshtrainer/data/datamodule.py CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
+from pathlib import Path
 from typing import Any, Generic, cast
 
 import nshconfig as C
+import torch
 from lightning.pytorch import LightningDataModule
 from typing_extensions import Never, TypeVar, deprecated, override
 
@@ -55,3 +57,68 @@ class LightningDataModuleBase(
         )
         hparams = hparams.model_deep_validate()
         self.save_hyperparameters(hparams)
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load datamodule from checkpoint
+        datamodule = cls(hparams)
+        if datamodule.__class__.__qualname__ in ckpt:
+            datamodule.load_state_dict(ckpt[datamodule.__class__.__qualname__])
+
+        return datamodule
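
A sketch of the new datamodule restore path. MyDataModule, the checkpoint path, and the batch_size field are placeholders for illustration; from_checkpoint validates the stored hyperparameters via hparams_cls().model_validate and then restores the datamodule state dict if one was saved under the class qualname.

    # Hypothetical LightningDataModuleBase subclass and checkpoint path.
    dm = MyDataModule.from_checkpoint(
        "runs/last.ckpt",
        map_location="cpu",
        # Optionally patch the raw hparams dict before it is validated.
        update_hparams_dict=lambda d: {**d, "batch_size": 32},
    )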

nshtrainer/model/base.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
+from pathlib import Path
 from typing import Any, Generic, Literal, cast
 
 import nshconfig as C
@@ -10,11 +11,13 @@ import torch
 import torch.distributed
 from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from .mixins.callback import CallbackModuleMixin
-from .mixins.debug import _DebugModuleMixin
+from .mixins.debug import _DebugModuleMixin
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
@@ -241,3 +244,98 @@ class LightningModuleBase(
         loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
         loss = cast(torch.Tensor, loss)
         return loss
+
+    @override
+    @classmethod
+    def load_from_checkpoint(cls, *args, **kwargs) -> Never:
+        raise ValueError("This method is not supported. Use `from_checkpoint` instead.")
+
+    @classmethod
+    def hparams_from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
+
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
+
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
+
+        return hparams
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        map_location: torch.serialization.MAP_LOCATION = None,
+        *,
+        update_hparams: Callable[[THparams], THparams] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        # Load checkpoint
+        if isinstance(ckpt_or_path, Mapping):
+            ckpt = ckpt_or_path
+        else:
+            ckpt = torch.load(ckpt_or_path, map_location=map_location)
+
+        # Load hyperparameters from checkpoint
+        hparams = cls.hparams_from_checkpoint(
+            ckpt,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+
+        # Load model from checkpoint
+        model = cls(hparams)
+
+        # Load model state from checkpoint
+        if (
+            model._strict_loading is not None
+            and strict is not None
+            and strict != model.strict_loading
+        ):
+            raise ValueError(
+                f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
+                f" `{cls.__name__}.strict_loading={model.strict_loading!r}. Please set the same value for both of them."
+            )
+        strict = model.strict_loading if strict is None else strict
+
+        if is_overridden("configure_model", model):
+            model.configure_model()
+
+        # give model a chance to load something
+        model.on_load_checkpoint(ckpt)
+
+        # load the state_dict on the model automatically
+
+        keys = model.load_state_dict(ckpt["state_dict"], strict=strict)
+
+        if not strict:
+            if keys.missing_keys:
+                rank_zero_warn(
+                    f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
+                )
+            if keys.unexpected_keys:
+                rank_zero_warn(
+                    f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
+                )
+
+        return model

nshtrainer/trainer/_config.py CHANGED
@@ -9,7 +9,6 @@ from collections.abc import Iterable, Sequence
 from datetime import timedelta
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Annotated,
     Any,
     ClassVar,
@@ -41,11 +40,13 @@ from ..callbacks import (
     CallbackConfig,
     EarlyStoppingCallbackConfig,
     LastCheckpointCallbackConfig,
+    NormLoggingCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
+from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
 from ..callbacks.shared_parameters import SharedParametersCallbackConfig
 from ..loggers import (
@@ -54,6 +55,7 @@ from ..loggers import (
     TensorboardLoggerConfig,
     WandbLoggerConfig,
 )
+from ..loggers._base import BaseLoggerConfig
 from ..loggers.actsave import ActSaveLoggerConfig
 from ..metrics._config import MetricConfig
 from ..profiler import ProfilerConfig
@@ -62,103 +64,6 @@ from ..util._environment_info import EnvironmentConfig
 log = logging.getLogger(__name__)
 
 
-class LoggingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable experiment tracking."""
-
-    loggers: Sequence[LoggerConfig] = [
-        WandbLoggerConfig(),
-        CSVLoggerConfig(),
-        TensorboardLoggerConfig(),
-    ]
-    """Loggers to use for experiment tracking."""
-
-    log_lr: bool | Literal["step", "epoch"] = True
-    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
-    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
-    """If enabled, will log the fractional epoch number to the logger."""
-
-    actsave_logger: ActSaveLoggerConfig | None = None
-    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
-
-    @property
-    def wandb(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, WandbLoggerConfig)
-            ),
-            None,
-        )
-
-    @property
-    def csv(self):
-        return next(
-            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
-            None,
-        )
-
-    @property
-    def tensorboard(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, TensorboardLoggerConfig)
-            ),
-            None,
-        )
-
-    def create_loggers(self, trainer_config: TrainerConfig):
-        """
-        Constructs and returns a list of loggers based on the provided root configuration.
-
-        Args:
-            trainer_config (TrainerConfig): The root configuration object.
-
-        Returns:
-            list[Logger]: A list of constructed loggers.
-        """
-        if not self.enabled:
-            return
-
-        for logger_config in sorted(
-            self.loggers,
-            key=lambda x: x.priority,
-            reverse=True,
-        ):
-            if not logger_config.enabled:
-                continue
-            if (logger := logger_config.create_logger(trainer_config)) is None:
-                continue
-            yield logger
-
-        # If the actsave_metrics is enabled, add the ActSave logger
-        if self.actsave_logger:
-            yield self.actsave_logger.create_logger(trainer_config)
-
-    @override
-    def create_callbacks(self, trainer_config):
-        if self.log_lr:
-            from lightning.pytorch.callbacks import LearningRateMonitor
-
-            logging_interval: str | None = None
-            if isinstance(self.log_lr, str):
-                logging_interval = self.log_lr
-
-            yield LearningRateMonitor(logging_interval=logging_interval)
-
-        if self.log_epoch:
-            yield from self.log_epoch.create_callbacks(trainer_config)
-
-        for logger in self.loggers:
-            if not logger or not isinstance(logger, CallbackConfigBase):
-                continue
-
-            yield from logger.create_callbacks(trainer_config)
-
-
 class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
@@ -168,32 +73,6 @@ class GradientClippingConfig(C.Config):
     """Norm type to use for gradient clipping."""
 
 
-class OptimizationConfig(CallbackConfigBase):
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    gradient_clipping: GradientClippingConfig | None = None
-    """Gradient clipping configuration, or None to disable gradient clipping."""
-
-    @override
-    def create_callbacks(self, trainer_config):
-        from ..callbacks.norm_logging import NormLoggingCallbackConfig
-
-        yield from NormLoggingCallbackConfig(
-            log_grad_norm=self.log_grad_norm,
-            log_grad_norm_per_param=self.log_grad_norm_per_param,
-            log_param_norm=self.log_param_norm,
-            log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(trainer_config)
-
-
 TPlugin = TypeVar(
     "TPlugin",
     Precision,
@@ -253,15 +132,6 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class ReproducibilityConfig(C.Config):
-    deterministic: bool | Literal["warn"] | None = None
-    """
-    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-
 CheckpointCallbackConfig: TypeAlias = Annotated[
     BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
@@ -635,14 +505,34 @@ class TrainerConfig(C.Config):
     hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
     """Hugging Face Hub configuration options."""
 
-
-
+    loggers: Sequence[LoggerConfig] = [
+        WandbLoggerConfig(),
+        CSVLoggerConfig(),
+        TensorboardLoggerConfig(),
+    ]
+    """Loggers to use for experiment tracking."""
 
-
-    """
+    actsave_logger: ActSaveLoggerConfig | None = None
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
 
-
-    """
+    lr_monitor: LearningRateMonitorConfig | None = LearningRateMonitorConfig()
+    """Learning rate monitoring configuration options."""
+
+    log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
+    """If enabled, will log the fractional epoch number to the logger."""
+
+    gradient_clipping: GradientClippingConfig | None = None
+    """Gradient clipping configuration, or None to disable gradient clipping."""
+
+    log_norms: NormLoggingCallbackConfig | None = None
+    """Norm logging configuration options."""
+
+    deterministic: bool | Literal["warn"] | None = None
+    """
+    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
 
     reduce_lr_on_plateau_sanity_checking: RLPSanityChecksCallbackConfig | None = (
         RLPSanityChecksCallbackConfig()
@@ -857,27 +747,87 @@ class TrainerConfig(C.Config):
     set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
+    @property
+    def wandb_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, WandbLoggerConfig)
+            ),
+            None,
+        )
+
+    @property
+    def csv_logger(self):
+        return next(
+            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
+        )
+
+    @property
+    def tensorboard_logger(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, TensorboardLoggerConfig)
+            ),
+            None,
+        )
+
     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
         yield self.early_stopping
         yield self.checkpoint_saving
-        yield self.
-        yield
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.loggers
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
         yield self.hf_hub
         yield self.shared_parameters
         yield self.reduce_lr_on_plateau_sanity_checking
         yield self.auto_set_debug_flag
         yield from self.callbacks
 
+    def _nshtrainer_all_logger_configs(self) -> Iterable[BaseLoggerConfig | None]:
+        yield from self.loggers
+        yield self.actsave_logger
+
     # region Helper Methods
+    def fast_dev_run_(self, value: int | bool = True, /):
+        """
+        Enables fast_dev_run mode for the trainer.
+        This will run the training loop for a specified number of batches,
+        if an integer is provided, or for a single batch if True is provided.
+        """
+        self.fast_dev_run = value
+        return self
+
     def with_fast_dev_run(self, value: int | bool = True, /):
         """
        Enables fast_dev_run mode for the trainer.
        This will run the training loop for a specified number of batches,
        if an integer is provided, or for a single batch if True is provided.
        """
-
-
-
+        return copy.deepcopy(self).fast_dev_run_(value)
+
+    def project_root_(self, project_root: str | Path | os.PathLike):
+        """
+        Set the project root directory for the trainer.
+
+        Args:
+            project_root (Path): The base directory to use.
+
+        Returns:
+            self: The current instance of the class.
+        """
+        self.directory.project_root = Path(project_root)
+        return self
 
     def with_project_root(self, project_root: str | Path | os.PathLike):
         """
@@ -889,9 +839,7 @@ class TrainerConfig(C.Config):
         Returns:
             self: The current instance of the class.
         """
-
-        config.directory.project_root = Path(project_root)
-        return config
+        return copy.deepcopy(self).project_root_(project_root)
 
     def reset_run(
         self,
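
Taken together, these hunks fold the old LoggingConfig, OptimizationConfig, and ReproducibilityConfig groupings directly into TrainerConfig. A rough sketch of the new surface, assuming every TrainerConfig field not shown keeps its default:

    from nshtrainer.callbacks import NormLoggingCallbackConfig
    from nshtrainer.callbacks.lr_monitor import LearningRateMonitorConfig
    from nshtrainer.trainer import TrainerConfig

    # Fields that used to live on LoggingConfig / OptimizationConfig /
    # ReproducibilityConfig are now set directly on TrainerConfig.
    config = TrainerConfig(
        lr_monitor=LearningRateMonitorConfig(logging_interval="epoch"),
        log_norms=NormLoggingCallbackConfig(log_grad_norm=True),
        deterministic="warn",
    )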

nshtrainer/trainer/trainer.py CHANGED
@@ -2,28 +2,19 @@ from __future__ import annotations
 
 import logging
 import os
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
-from typing import
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.fabric.utilities.cloud_io import _load as pl_load
-from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.core.saving import (
-    _default_map_location,
-    load_hparams_from_tags_csv,
-    load_hparams_from_yaml,
-)
 from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.migration import pl_legacy_patch
-from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
 from typing_extensions import Never, Unpack, assert_never, deprecated, override
 
@@ -79,7 +70,7 @@ class Trainer(LightningTrainer):
         kwargs_ctor: LightningTrainerKwargs,
     ):
         kwargs: LightningTrainerKwargs = {
-            "deterministic": hparams.
+            "deterministic": hparams.deterministic,
             "fast_dev_run": hparams.fast_dev_run,
             "max_epochs": hparams.max_epochs,
             "min_epochs": hparams.min_epochs,
@@ -218,7 +209,7 @@ class Trainer(LightningTrainer):
         _update_kwargs(detect_anomaly=detect_anomaly)
 
         if (
-            grad_clip_config := hparams.
+            grad_clip_config := hparams.gradient_clipping
         ) is not None and grad_clip_config.enabled:
             # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
             # kwargs["gradient_clip_val"] = grad_clip_config.value
@@ -248,17 +239,14 @@ class Trainer(LightningTrainer):
             ]
         )
 
-
-
-
-
-
-                logger
-
-
-                if logger is not None
-            ]
-        )
+        _update_kwargs(
+            logger=[
+                logger
+                for logger_config in hparams._nshtrainer_all_logger_configs()
+                if logger_config is not None
+                and (logger := logger_config.create_logger(hparams)) is not None
+            ]
+        )
 
         if hparams.auto_determine_num_nodes:
             # When num_nodes is auto, we need to detect the number of nodes.
@@ -473,62 +461,46 @@ class Trainer(LightningTrainer):
             _callback._call_on_checkpoint_saved(self, filepath, metadata_path)
 
     @classmethod
-    def
+    def hparams_from_checkpoint(
         cls,
-
-
-
-
+        ckpt_or_path: dict[str, Any] | str | Path,
+        /,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
-
-
-            map_location=map_location,
-            hparams_file=hparams_file,
-            **kwargs,
-        )
-        return loaded
-
-
-def _load_from_checkpoint(
-    checkpoint_path: _PATH | IO,
-    map_location: _MAP_LOCATION_TYPE = None,
-    hparams_file: _PATH | None = None,
-    **kwargs: Any,
-):
-    map_location = map_location or _default_map_location
-    with pl_legacy_patch():
-        checkpoint = pl_load(checkpoint_path, map_location=map_location)
-
-    # convert legacy checkpoints to the new format
-    checkpoint = _pl_migrate_checkpoint(
-        checkpoint,
-        checkpoint_path=(
-            checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None
-        ),
-    )
-
-    if hparams_file is not None:
-        extension = str(hparams_file).split(".")[-1]
-        if extension.lower() == "csv":
-            hparams = load_hparams_from_tags_csv(hparams_file)
-        elif extension.lower() in ("yml", "yaml"):
-            hparams = load_hparams_from_yaml(hparams_file)
+        if isinstance(ckpt_or_path, dict):
+            ckpt = ckpt_or_path
         else:
-
+            ckpt = torch.load(ckpt_or_path, map_location="cpu")
 
-
-
+        if (hparams := ckpt.get(cls.CHECKPOINT_HYPER_PARAMS_KEY)) is None:
+            raise ValueError(
+                f"The checkpoint does not contain hyperparameters. It must contain the key '{cls.CHECKPOINT_HYPER_PARAMS_KEY}'."
+            )
+        if update_hparams_dict is not None:
+            hparams = update_hparams_dict(hparams)
 
-
-
-
-        checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
+        hparams = cls.hparams_cls().model_validate(hparams, strict=strict)
+        if update_hparams is not None:
+            hparams = update_hparams(hparams)
 
-
-        hparams = Trainer.hparams_cls().model_validate(
-            checkpoint[Trainer.CHECKPOINT_HYPER_PARAMS_KEY]
-        )
+        return hparams
 
-
-
-
+    @classmethod
+    def from_checkpoint(
+        cls,
+        path: str | Path,
+        strict: bool | None = None,
+        *,
+        update_hparams: Callable[[TrainerConfig], TrainerConfig] | None = None,
+        update_hparams_dict: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+    ):
+        hparams = cls.hparams_from_checkpoint(
+            path,
+            strict=strict,
+            update_hparams=update_hparams,
+            update_hparams_dict=update_hparams_dict,
+        )
+        return cls(hparams)

{nshtrainer-1.0.0b11.dist-info → nshtrainer-1.0.0b13.dist-info}/RECORD CHANGED
@@ -22,6 +22,7 @@ nshtrainer/callbacks/finite_checks.py,sha256=iCiKQ5i9RckkzcPeCHzC3hkg3AlW3ESuWtF
 nshtrainer/callbacks/gradient_skipping.py,sha256=k5qNaNeileZ_5YFad4ssfLplMxMKeKFhPcY8-QVmLek,3464
 nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
 nshtrainer/callbacks/log_epoch.py,sha256=Wr-Ksxsynsqu_zyB_zoiPLjnWv-ksC3xPekY6iyN-P8,1396
+nshtrainer/callbacks/lr_monitor.py,sha256=IyFZoXaxJoTBSkdLu1iEZ1qI8_UFNJwafR_xTVPZXXU,1050
 nshtrainer/callbacks/norm_logging.py,sha256=C44Mvt73gqQEpCFd0j3qYg6NY7sL2jm3X1qJVY_XLfI,6329
 nshtrainer/callbacks/print_table.py,sha256=WIgfzVSfAfS3_8kUuX-nWJOGWBEmtNlejypuoJQViPY,2884
 nshtrainer/callbacks/rlp_sanity_checks.py,sha256=kWl2dYOXn2L8k6ub_012jNkqOxtyea1yr1qWRNG6UW4,9990
@@ -29,13 +30,13 @@ nshtrainer/callbacks/shared_parameters.py,sha256=33eRzifNj6reKbvmGuam1hUofo3sD4J
 nshtrainer/callbacks/timer.py,sha256=BB-M7tV4QNYOwY_Su6j9P7IILxVRae_upmDq4qsxiao,4670
 nshtrainer/callbacks/wandb_upload_code.py,sha256=PTqNE1QB5U8NR5zhbiQZrmQuugX2UV7B12UdMpo9aV0,2353
 nshtrainer/callbacks/wandb_watch.py,sha256=tTTcFzxd2Ia9xu8tCogQ5CLJZBq1ne5JlpGVE75vKYs,2976
-nshtrainer/configs/__init__.py,sha256=
+nshtrainer/configs/__init__.py,sha256=Vyf_gn7u3s9ET4Yszf6SILtqvpIGiJ4X5RJfmW-FK6I,22293
 nshtrainer/configs/_checkpoint/__init__.py,sha256=vuiBbd4VzCo7lRyhyTUArEQeWwJkewvNPKDxBJiUHoY,2719
 nshtrainer/configs/_checkpoint/loader/__init__.py,sha256=hdLpypoEkES1MTaTHAdGFJnSoZzgx_8NzAKbK143SyI,2399
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=X9KxpcoHQbJp6-MTGvp4pct-MYHaHcl82s9yqZ5KiSk,867
 nshtrainer/configs/_directory/__init__.py,sha256=mTUoSz-DSsvI2M98cqu2Z2x215oM0sLyljh_5rVexvQ,1029
 nshtrainer/configs/_hf_hub/__init__.py,sha256=3HGCGhRb7NhOuLeskGqbYNuS9c81oOUbX6ibyF3XiCY,1063
-nshtrainer/configs/callbacks/__init__.py,sha256
+nshtrainer/configs/callbacks/__init__.py,sha256=xgCa98EmqU8cHxlJa-64Cc4c_0fS0Cz2iVac4edL_yc,7657
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=AkVWS9vCcDJFpPUpyc7i9cjaFZU2kKxDyFDqakMZA-E,809
 nshtrainer/configs/callbacks/base/__init__.py,sha256=OdtHDMkYC_ioCEAkg7bSQi3o7e2t5WHPcFjavXdfdTA,602
 nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=DE-JTtY4NdaP3mgWduearFYMvy1tswWRBWMde06RzQc,2700
@@ -50,6 +51,7 @@ nshtrainer/configs/callbacks/ema/__init__.py,sha256=KlPGdJWjYTKLdpl-VnN4BYY2sA_L
 nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=V6Owp05XdIk3EO67AMVGdwbT4-D86QRuvqWM2gu5Xpw,949
 nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=RhIfJFq-x_sWYrWVGaVEBeT8uUFYjFgt0Ug8pPgpJSg,981
 nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=4jePzjE3bVxaI7hQrcWW5SrKT5MrFyplJZwK8bQHbGI,900
+nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=iC8U0oWC75JzPUMRoGWkC8WkMuLbF9-zuN_yQlByycY,916
 nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=8M9hcGpEuQadfgTR4-YL4TWeyxZjg0s84x420B03-aE,941
 nshtrainer/configs/callbacks/print_table/__init__.py,sha256=Ni47iS2mIzwGu8XuHfUY5BJKawUO_2TyJMZ62QBpEW0,961
 nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=OgPywk8Z9y_dnq_liH2PPWuQSpUlQ_Q2-q99HDN9Leg,977
@@ -78,8 +80,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=smDYCplrI5B38XJcNZ462ZeTo9l
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=VCCvbzhEeOcdZ0Unvk_anAcmbQuGogTQhK_bXs5RG9U,892
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=hJ90ym5ElI-BY_XS3VSLjcgQWfV0Pp1MdzTU6Qi8MFg,884
 nshtrainer/configs/profiler/simple/__init__.py,sha256=18V64kKYrJeSCrPmY3wYnshEISaf7xmrfw2Ny-6P3uE,859
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=QLCDVxVg1Ig-wgUW5r8I1FdPdbYz9-gse17s3R69Fw0,8019
+nshtrainer/configs/trainer/_config/__init__.py,sha256=YEcliai8jLOoB53lxT5BZIR3NzfoLb4x3VGbMakFVo4,7909
 nshtrainer/configs/trainer/checkpoint_connector/__init__.py,sha256=pSu79zOFFWvqjI3SkHWl13H8ZNJFTc6a5J1r2KnfUKM,667
 nshtrainer/configs/trainer/trainer/__init__.py,sha256=P-Y2DOZZcJtEdPjGEKCxq5R3JSzKhUUoidkSvO_cfKI,797
 nshtrainer/configs/util/__init__.py,sha256=ZcmEqg2OWKKcPBqzDG1SnuaAMgR4_c0jog-Xg6QTUzc,4555
@@ -89,7 +91,7 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=NCXMVO-EUz3JvPmlDci72O9Z
 nshtrainer/configs/util/config/duration/__init__.py,sha256=8llT1MCKQpsdNldN5h5Wo0GjUuRn28Sxw2FTXTNKBpM,1060
 nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
 nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
-nshtrainer/data/datamodule.py,sha256=
+nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
 nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
 nshtrainer/loggers/__init__.py,sha256=11X6D_lF0cHAkxkYsVZY3Q3r30Fq0FUi9heeb5RD870,570
 nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
@@ -104,7 +106,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5N
 nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
 nshtrainer/metrics/_config.py,sha256=kR9OJLZXIaZpG8UpHG3EQwOL36HB8yPk6afYjoIM0XM,1324
 nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
-nshtrainer/model/base.py,sha256=
+nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
 nshtrainer/model/mixins/callback.py,sha256=Ea_legORzs0N078j0N9RJivDVeWH5KtXDpdJS75IwIo,3098
 nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
 nshtrainer/model/mixins/logger.py,sha256=LQDJJbiv30PlWX6rTT_EhjNBNfUFfcvGz5sX4MnOCzI,5330
@@ -119,13 +121,12 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
 nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
-nshtrainer/scripts/find_packages.py,sha256=-oNnSNPp3pujCVgManW_LFlJcnvhrtHgvUJ4W88-6-o,1460
 nshtrainer/trainer/__init__.py,sha256=MmoydVS6aYeav7zgDAUHxAQrV_PMQsbnZTCuPnLH9Wk,128
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=2S6Qhwn724n_jgGhWVI64Wi_pHKjU1ggoY4sxq-_SlA,32309
 nshtrainer/trainer/_runtime_callback.py,sha256=T3epaj1YeIN0R8CS2cg5HNJIB21TyaD_PVNNOPJ6nJs,4200
 nshtrainer/trainer/checkpoint_connector.py,sha256=pC1tTDcq0p6sAsoTmAbwINW49IfqupMMtnE9-AKdTUw,2824
 nshtrainer/trainer/signal_connector.py,sha256=YMJf6vTnW0JcnBkuYikm9x_9XscaokrCEzCn4THOGao,10776
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=kIXh_25jDJSGcwEyLjbvqWN0P5B35VBJLXOwXqUGqF4,19759
 nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
 nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
@@ -138,6 +139,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.0.
-nshtrainer-1.0.
-nshtrainer-1.0.
+nshtrainer-1.0.0b13.dist-info/METADATA,sha256=9PQNipTw68KmSV_7Kt4fK_KtlYKSaKBcvvkBZrwWFtY,937
+nshtrainer-1.0.0b13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-1.0.0b13.dist-info/RECORD,,

nshtrainer/scripts/find_packages.py DELETED
@@ -1,52 +0,0 @@
-from __future__ import annotations
-
-import argparse
-import ast
-import glob
-import sys
-from pathlib import Path
-
-
-def get_imports(file_path: Path):
-    with open(file_path, "r") as file:
-        try:
-            tree = ast.parse(file.read())
-        except SyntaxError:
-            print(f"Syntax error in file: {file_path}", file=sys.stderr)
-            return set()
-
-    imports = set()
-    for node in ast.walk(tree):
-        if isinstance(node, ast.Import):
-            for alias in node.names:
-                imports.add(alias.name.split(".")[0])
-        elif isinstance(node, ast.ImportFrom):
-            if node.level == 0 and node.module:  # Absolute import
-                imports.add(node.module.split(".")[0])
-    return imports
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Find unique Python packages used in files."
-    )
-    parser.add_argument("glob_pattern", help="Glob pattern to match files")
-    parser.add_argument(
-        "--exclude-std", action="store_true", help="Exclude Python standard libraries"
-    )
-    args = parser.parse_args()
-
-    all_imports = set()
-    for file_path in glob.glob(args.glob_pattern, recursive=True):
-        all_imports.update(get_imports(Path(file_path)))
-
-    if args.exclude_std:
-        std_libs = set(sys.stdlib_module_names)
-        all_imports = all_imports - std_libs
-
-    for package in sorted(all_imports):
-        print(package)
-
-
-if __name__ == "__main__":
-    main()
File without changes
|