nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/loggers/wandb.py
CHANGED
@@ -15,23 +15,24 @@ from ..callbacks.wandb_watch import WandbWatchCallbackConfig
 from ._base import BaseLoggerConfig
 
 if TYPE_CHECKING:
-    from ..
+    from ..trainer._config import TrainerConfig
+
 
 log = logging.getLogger(__name__)
 
 
 def _project_name(
-
+    trainer_config: TrainerConfig,
     default_project: str = "lightning_logs",
 ):
     # If the config has a project name, use that.
-    if project :=
+    if project := trainer_config.project:
         return project
 
     # Otherwise, we should use the name of the module that the config is defined in,
     # if we can find it.
     # If this isn't in a module, use the default project name.
-    if not (module :=
+    if not (module := trainer_config.__module__):
         return default_project
 
     # If the module is a package, use the package name.
@@ -129,7 +130,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
             assert_never(self.log_model)
 
     @override
-    def create_logger(self,
+    def create_logger(self, trainer_config):
         if not self.enabled:
             return None
 
@@ -171,31 +172,31 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
 
         from lightning.pytorch.loggers.wandb import WandbLogger
 
-        save_dir =
-
+        save_dir = trainer_config.directory._resolve_log_directory_for_logger(
+            trainer_config.id,
             self,
         )
         return WandbLogger(
             save_dir=save_dir,
-            project=self.project or _project_name(
-            name=
-            version=
+            project=self.project or _project_name(trainer_config),
+            name=trainer_config.full_name,
+            version=trainer_config.id,
             log_model=self._lightning_log_model,
             notes=(
-                "\n".join(f"- {note}" for note in
-                if
+                "\n".join(f"- {note}" for note in trainer_config.notes)
+                if trainer_config.notes
                 else None
             ),
-            tags=
+            tags=trainer_config.tags,
             offline=self.offline,
         )
 
     @override
-    def create_callbacks(self,
+    def create_callbacks(self, trainer_config):
         yield FinishWandbOnTeardownCallback()
 
         if self.watch:
-            yield from self.watch.create_callbacks(
+            yield from self.watch.create_callbacks(trainer_config)
 
         if self.log_code:
-            yield from self.log_code.create_callbacks(
+            yield from self.log_code.create_callbacks(trainer_config)
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
CHANGED
@@ -1,17 +1,14 @@
 from __future__ import annotations
 
-from typing import
+from typing import Literal
 
 from lightning.pytorch.utilities.types import LRSchedulerConfigType
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from typing_extensions import override
 
-from ..
+from ..metrics._config import MetricConfig
 from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
 
-if TYPE_CHECKING:
-    from ..model.base import BaseConfig
-
 
 class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Reduce learning rate when a metric has stopped improving."""
@@ -48,9 +45,14 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
         self, optimizer, lightning_module
     ) -> LRSchedulerConfigType:
         if (metric := self.metric) is None:
-
+            from ..trainer import Trainer
+
+            assert isinstance(
+                trainer := lightning_module.trainer, Trainer
+            ), "The trainer must be a `nshtrainer.Trainer` instance."
+
             assert (
-                metric :=
+                metric := trainer.hparams.primary_metric
             ) is not None, "Primary metric must be provided if metric is not specified."
 
         lr_scheduler = ReduceLROnPlateau(
nshtrainer/model/__init__.py
CHANGED
@@ -1,7 +1,3 @@
 from __future__ import annotations
 
 from .base import LightningModuleBase as LightningModuleBase
-from .config import BaseConfig as BaseConfig
-from .config import DirectoryConfig as DirectoryConfig
-from .config import MetricConfig as MetricConfig
-from .config import TrainerConfig as TrainerConfig
nshtrainer/model/base.py
CHANGED
@@ -1,28 +1,25 @@
 from __future__ import annotations
 
-import inspect
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import
-from typing import
+from collections.abc import Mapping
+from typing import Any, Generic, Literal, cast
 
+import nshconfig as C
 import torch
 import torch.distributed
-from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
-from
-from typing_extensions import Self, TypeVar, override
+from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
-from ..util._environment_info import EnvironmentConfig
-from .config import BaseConfig
 from .mixins.callback import CallbackModuleMixin
+from .mixins.debug import _DebugModuleMixin, _trainer
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
 
-THparams = TypeVar("THparams", bound=
+THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
 
 
 T = TypeVar("T", infer_variance=True)
@@ -53,7 +50,8 @@ VALID_REDUCE_OPS = (
 )
 
 
-class LightningModuleBase(
+class LightningModuleBase(
+    _DebugModuleMixin,
     _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
     CallbackModuleMixin,
@@ -61,21 +59,36 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
     ABC,
     Generic[THparams],
 ):
-    # region
-    @torch.jit.unused
-    @property
-    def config(self) -> THparams:
-        return self.hparams
-
+    # region Debug
     @property
     def debug(self) -> bool:
         if torch.jit.is_scripting():
             return False
-        return self.config.debug
 
-
+        if (trainer := self._trainer) is None:
+            return False
 
-
+        from ..trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            return False
+
+        return trainer.debug
+
+    @debug.setter
+    def debug(self, value: bool):
+        if torch.jit.is_scripting():
+            return
+
+        if (trainer := self._trainer) is None:
+            return
+
+        from ..trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            return
+
+        trainer.debug = value
 
     @torch.jit.unused
     def breakpoint(self, rank_zero_only: bool = True):
@@ -146,7 +159,7 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         return object_list
 
     def barrier(self, name: str | None = None):
-        self.trainer.strategy.barrier(name=name)
+        return self.trainer.strategy.barrier(name=name)
 
     def reduce(
         self,
@@ -170,7 +183,7 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
     @override
     def __repr__(self):
         parts: list[str] = []
-        parts.append(f"
+        parts.append(f"hparams={repr(self.hparams)}")
        parts.append(f"device={self.device}")
         if self.debug:
             parts.append("debug=True")
@@ -178,85 +191,46 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         parts_str = ", ".join(parts)
         return f"{self.__class__.__name__}({parts_str})"
 
-    @
-
-
-
-        return
+    @property
+    @override
+    def hparams(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(THparams, super().hparams)
 
-
-
+    @property
+    @override
+    def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
+        hparams = cast(THparams, super().hparams_initial)
+        hparams_dict = {"model": hparams.model_dump(mode="json")}
+        if (trainer := self._trainer) is not None:
+            from ..trainer import Trainer
 
-
-
-        _ = parameters.pop("self", None)
-        if len(parameters) != 1:
-            raise TypeError(
-                f"__init__ must take a single argument, got {len(parameters)}: {init_fn}"
-            )
+            if isinstance(trainer, Trainer):
+                hparams_dict["trainer"] = trainer.hparams.model_dump(mode="json")
 
-
-            raise TypeError(
-                f"__init__'s argument must be named 'hparams', got {parameters}"
-            )
+        return cast(Never, hparams_dict)
 
-
-
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
 
     @classmethod
     @abstractmethod
-    def
-
-    @classmethod
-    def load_checkpoint(
-        cls,
-        checkpoint_path: _PATH | IO,
-        hparams: THparams | MutableMapping[str, Any] | None = None,
-        map_location: _MAP_LOCATION_TYPE = None,
-        strict: bool = True,
-    ) -> Self:
-        if strict:
-            cls._validate_class_for_ckpt_loading()
-
-        kwargs: dict[str, Any] = {}
-        if hparams is not None:
-            kwargs["hparams"] = hparams
-
-        return super().load_from_checkpoint(
-            checkpoint_path,
-            map_location=map_location,
-            hparams_file=None,
-            strict=strict,
-            **kwargs,
-        )
-
-    def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
-        """
-        Override this method to update the hparams dictionary before it is used to create the hparams object.
-        Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
-        """
-        return hparams
-
-    def pre_init_update_hparams(self, hparams: THparams):
-        """
-        Override this method to update the hparams object before it is used to create the hparams_initial object.
-        """
-        return hparams
+    def hparams_cls(cls) -> type[THparams]: ...
 
     @override
-    def __init__(self, hparams: THparams |
-        if not isinstance(hparams, BaseConfig):
-            if not isinstance(hparams, MutableMapping):
-                raise TypeError(
-                    f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
-                )
-
-            hparams = self.pre_init_update_hparams_dict(hparams)
-            hparams = self.config_cls().model_validate(hparams)
-        hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
-        hparams = self.pre_init_update_hparams(hparams)
-
+    def __init__(self, hparams: THparams | Mapping[str, Any]):
         super().__init__()
+
+        # Validate and save hyperparameters
+        hparams_cls = self.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise TypeError(
+                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
+            )
+        hparams = hparams.model_deep_validate()
         self.save_hyperparameters(hparams)
 
     def zero_loss(self):
@@ -267,260 +241,3 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
         loss = cast(torch.Tensor, loss)
         return loss
-
-    if TYPE_CHECKING:
-
-        @override
-        def training_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> Any:
-            r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
-            logger.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                - :class:`~torch.Tensor` - The loss tensor
-                - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
-                  automatic optimization.
-                - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
-                  multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
-                  the loss is not required.
-
-            In this step you'd normally do the forward pass and calculate the loss for a batch.
-            You can also do fancier things like multiple forward passes or something model specific.
-
-            Example::
-
-                def training_step(self, batch, batch_idx):
-                    x, y, z = batch
-                    out = self.encoder(x)
-                    loss = self.loss(out, x)
-                    return loss
-
-            To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
-
-            .. code-block:: python
-
-                def __init__(self):
-                    super().__init__()
-                    self.automatic_optimization = False
-
-
-                # Multiple optimizers (e.g.: GANs)
-                def training_step(self, batch, batch_idx):
-                    opt1, opt2 = self.optimizers()
-
-                    # do training_step with encoder
-                    ...
-                    opt1.step()
-                    # do training_step with decoder
-                    ...
-                    opt2.step()
-
-            Note:
-                When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
-                normalized by ``accumulate_grad_batches`` internally.
-
-            """
-            raise NotImplementedError
-
-        @override
-        def validation_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> STEP_OUTPUT:
-            r"""Operates on a single batch of data from the validation set. In this step you'd might generate examples or
-            calculate anything of interest like accuracy.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                - :class:`~torch.Tensor` - The loss tensor
-                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
-                - ``None`` - Skip to the next batch.
-
-            .. code-block:: python
-
-                # if you have one val dataloader:
-                def validation_step(self, batch, batch_idx): ...
-
-
-                # if you have multiple val dataloaders:
-                def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
-
-            Examples::
-
-                # CASE 1: A single validation dataset
-                def validation_step(self, batch, batch_idx):
-                    x, y = batch
-
-                    # implement your own
-                    out = self(x)
-                    loss = self.loss(out, y)
-
-                    # log 6 example images
-                    # or generated text... or whatever
-                    sample_imgs = x[:6]
-                    grid = torchvision.utils.make_grid(sample_imgs)
-                    self.logger.experiment.add_image('example_images', grid, 0)
-
-                    # calculate acc
-                    labels_hat = torch.argmax(out, dim=1)
-                    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
-
-                    # log the outputs!
-                    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
-
-            If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
-            setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
-
-            .. code-block:: python
-
-                # CASE 2: multiple validation dataloaders
-                def validation_step(self, batch, batch_idx, dataloader_idx=0):
-                    # dataloader_idx tells you which dataset this is.
-                    ...
-
-            Note:
-                If you don't need to validate you don't need to implement this method.
-
-            Note:
-                When the :meth:`validation_step` is called, the model has been put in eval mode
-                and PyTorch gradients have been disabled. At the end of validation,
-                the model goes back to training mode and gradients are enabled.
-
-            """
-            raise NotImplementedError
-
-        @override
-        def test_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> STEP_OUTPUT:
-            r"""Operates on a single batch of data from the test set. In this step you'd normally generate examples or
-            calculate anything of interest such as accuracy.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                - :class:`~torch.Tensor` - The loss tensor
-                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
-                - ``None`` - Skip to the next batch.
-
-            .. code-block:: python
-
-                # if you have one test dataloader:
-                def test_step(self, batch, batch_idx): ...
-
-
-                # if you have multiple test dataloaders:
-                def test_step(self, batch, batch_idx, dataloader_idx=0): ...
-
-            Examples::
-
-                # CASE 1: A single test dataset
-                def test_step(self, batch, batch_idx):
-                    x, y = batch
-
-                    # implement your own
-                    out = self(x)
-                    loss = self.loss(out, y)
-
-                    # log 6 example images
-                    # or generated text... or whatever
-                    sample_imgs = x[:6]
-                    grid = torchvision.utils.make_grid(sample_imgs)
-                    self.logger.experiment.add_image('example_images', grid, 0)
-
-                    # calculate acc
-                    labels_hat = torch.argmax(out, dim=1)
-                    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
-
-                    # log the outputs!
-                    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
-
-            If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. We recommend
-            setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
-
-            .. code-block:: python
-
-                # CASE 2: multiple test dataloaders
-                def test_step(self, batch, batch_idx, dataloader_idx=0):
-                    # dataloader_idx tells you which dataset this is.
-                    ...
-
-            Note:
-                If you don't need to test you don't need to implement this method.
-
-            Note:
-                When the :meth:`test_step` is called, the model has been put in eval mode and
-                PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
-                to training mode and gradients are enabled.
-
-            """
-            raise NotImplementedError
-
-        @override
-        def predict_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> STEP_OUTPUT:
-            """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
-            :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.
-
-            The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
-            to scale inference on multi-devices.
-
-            To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
-            callback to write the predictions to disk or database after each batch or on epoch end.
-
-            The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
-            based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
-            or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                Predicted output (optional).
-
-            Example ::
-
-                class MyModel(LightningModule):
-
-                    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-                        return self(batch)
-
-                dm = ...
-                model = MyModel()
-                trainer = Trainer(accelerator="gpu", devices=2)
-                predictions = trainer.predict(model, dm)
-
-            """
-            prediction = self(batch)
-            return {
-                "prediction": prediction,
-                "batch": batch,
-                "batch_idx": batch_idx,
-            }
nshtrainer/model/mixins/callback.py
CHANGED
@@ -2,16 +2,18 @@ from __future__ import annotations
 
 import logging
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast
+from typing import Any, TypeAlias, cast
 
 from lightning.pytorch import Callback, LightningModule
 from typing_extensions import override
 
+from ..._callback import NTCallbackBase
 from ...util.typing_utils import mixin_base_type
 
 log = logging.getLogger(__name__)
 
-
+_Callback = Callback | NTCallbackBase
+CallbackFn: TypeAlias = Callable[[], _Callback | Iterable[_Callback] | None]
 
 
 class CallbackRegistrarModuleMixin:
@@ -23,7 +25,7 @@ class CallbackRegistrarModuleMixin:
 
     def register_callback(
         self,
-        callback:
+        callback: _Callback | Iterable[_Callback] | CallbackFn | None = None,
     ):
         if not callable(callback):
             callback_ = cast(CallbackFn, lambda: callback)
@@ -34,8 +36,26 @@ class CallbackRegistrarModuleMixin:
 
 
 class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin,
+    CallbackRegistrarModuleMixin,
+    mixin_base_type(LightningModule),
 ):
+    @property
+    def _nshtrainer_callbacks(self) -> list[CallbackFn]:
+        if not hasattr(self, "_private_nshtrainer_callbacks_list"):
+            self._private_nshtrainer_callbacks_list = []
+        return self._private_nshtrainer_callbacks_list
+
+    def register_callback(
+        self,
+        callback: _Callback | Iterable[_Callback] | CallbackFn | None = None,
+    ):
+        if not callable(callback):
+            callback_ = cast(CallbackFn, lambda: callback)
+        else:
+            callback_ = callback
+
+        self._nshtrainer_callbacks.append(callback_)
+
     def _gather_all_callbacks(self):
         modules: list[Any] = []
         if isinstance(self, CallbackRegistrarModuleMixin):
@@ -52,7 +72,6 @@ class CallbackModuleMixin(
         for module in modules:
            yield from module._nshtrainer_callbacks
 
-    @final
     @override
     def configure_callbacks(self):
        callbacks = super().configure_callbacks()