nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
nshtrainer/loggers/wandb.py CHANGED
@@ -15,23 +15,24 @@ from ..callbacks.wandb_watch import WandbWatchCallbackConfig
 from ._base import BaseLoggerConfig
 
 if TYPE_CHECKING:
-    from ..model.config import BaseConfig
+    from ..trainer._config import TrainerConfig
+
 
 log = logging.getLogger(__name__)
 
 
 def _project_name(
-    root_config: "BaseConfig",
+    trainer_config: TrainerConfig,
     default_project: str = "lightning_logs",
 ):
     # If the config has a project name, use that.
-    if project := root_config.project:
+    if project := trainer_config.project:
         return project
 
     # Otherwise, we should use the name of the module that the config is defined in,
     # if we can find it.
     # If this isn't in a module, use the default project name.
-    if not (module := root_config.__module__):
+    if not (module := trainer_config.__module__):
         return default_project
 
     # If the module is a package, use the package name.
@@ -129,7 +130,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
         assert_never(self.log_model)
 
     @override
-    def create_logger(self, root_config):
+    def create_logger(self, trainer_config):
         if not self.enabled:
             return None
 
@@ -171,31 +172,31 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
 
         from lightning.pytorch.loggers.wandb import WandbLogger
 
-        save_dir = root_config.directory._resolve_log_directory_for_logger(
-            root_config.id,
+        save_dir = trainer_config.directory._resolve_log_directory_for_logger(
+            trainer_config.id,
             self,
         )
         return WandbLogger(
             save_dir=save_dir,
-            project=self.project or _project_name(root_config),
-            name=root_config.run_name,
-            version=root_config.id,
+            project=self.project or _project_name(trainer_config),
+            name=trainer_config.full_name,
+            version=trainer_config.id,
             log_model=self._lightning_log_model,
             notes=(
-                "\n".join(f"- {note}" for note in root_config.notes)
-                if root_config.notes
+                "\n".join(f"- {note}" for note in trainer_config.notes)
+                if trainer_config.notes
                 else None
             ),
-            tags=root_config.tags,
+            tags=trainer_config.tags,
             offline=self.offline,
         )
 
     @override
-    def create_callbacks(self, root_config):
+    def create_callbacks(self, trainer_config):
         yield FinishWandbOnTeardownCallback()
 
         if self.watch:
-            yield from self.watch.create_callbacks(root_config)
+            yield from self.watch.create_callbacks(trainer_config)
 
         if self.log_code:
-            yield from self.log_code.create_callbacks(root_config)
+            yield from self.log_code.create_callbacks(trainer_config)
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py CHANGED
@@ -1,17 +1,14 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal, cast
+from typing import Literal
 
 from lightning.pytorch.utilities.types import LRSchedulerConfigType
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from typing_extensions import override
 
-from ..model.config import MetricConfig
+from ..metrics._config import MetricConfig
 from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
 
-if TYPE_CHECKING:
-    from ..model.base import BaseConfig
-
 
 class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     """Reduce learning rate when a metric has stopped improving."""
@@ -48,9 +45,14 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
         self, optimizer, lightning_module
     ) -> LRSchedulerConfigType:
         if (metric := self.metric) is None:
-            lm_config = cast("BaseConfig", lightning_module.config)
+            from ..trainer import Trainer
+
+            assert isinstance(
+                trainer := lightning_module.trainer, Trainer
+            ), "The trainer must be a `nshtrainer.Trainer` instance."
+
             assert (
-                metric := lm_config.primary_metric
+                metric := trainer.hparams.primary_metric
             ) is not None, "Primary metric must be provided if metric is not specified."
 
         lr_scheduler = ReduceLROnPlateau(
nshtrainer/model/__init__.py CHANGED
@@ -1,7 +1,3 @@
 from __future__ import annotations
 
 from .base import LightningModuleBase as LightningModuleBase
-from .config import BaseConfig as BaseConfig
-from .config import DirectoryConfig as DirectoryConfig
-from .config import MetricConfig as MetricConfig
-from .config import TrainerConfig as TrainerConfig
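The config classes dropped from nshtrainer.model's re-exports now live elsewhere in the package. Judging from the relative imports in the other hunks of this diff, downstream imports would move roughly as sketched below; this is an inference only, the intended public import paths for 1.0 may differ, and no new home for BaseConfig or DirectoryConfig is visible in these hunks.

# Hypothetical migration sketch, inferred from this diff's internal imports.
# Old (0.44.x):
#   from nshtrainer.model import MetricConfig, TrainerConfig
# New (1.0.0b10), using the modules the package itself now imports from:
from nshtrainer.metrics._config import MetricConfig  # replaces nshtrainer.model.config
from nshtrainer.trainer._config import TrainerConfig  # replaces nshtrainer.model.config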
nshtrainer/model/base.py CHANGED
@@ -1,28 +1,25 @@
 from __future__ import annotations
 
-import inspect
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import MutableMapping
-from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
+from collections.abc import Mapping
+from typing import Any, Generic, Literal, cast
 
+import nshconfig as C
 import torch
 import torch.distributed
-from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule
 from lightning.pytorch.profilers import PassThroughProfiler, Profiler
-from lightning.pytorch.utilities.types import STEP_OUTPUT
-from typing_extensions import Self, TypeVar, override
+from typing_extensions import Never, TypeVar, deprecated, override
 
 from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
-from ..util._environment_info import EnvironmentConfig
-from .config import BaseConfig
 from .mixins.callback import CallbackModuleMixin
+from .mixins.debug import _DebugModuleMixin, _trainer
 from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
 
-THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
+THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
 
 
 T = TypeVar("T", infer_variance=True)
@@ -53,7 +50,8 @@ VALID_REDUCE_OPS = (
 )
 
 
-class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
+class LightningModuleBase(
+    _DebugModuleMixin,
     _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
     CallbackModuleMixin,
@@ -61,21 +59,36 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
     ABC,
     Generic[THparams],
 ):
-    # region Config
-    @torch.jit.unused
-    @property
-    def config(self) -> THparams:
-        return self.hparams
-
+    # region Debug
     @property
     def debug(self) -> bool:
         if torch.jit.is_scripting():
             return False
-        return self.config.debug
 
-    # endregion
+        if (trainer := self._trainer) is None:
+            return False
 
-    # region Debug
+        from ..trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            return False
+
+        return trainer.debug
+
+    @debug.setter
+    def debug(self, value: bool):
+        if torch.jit.is_scripting():
+            return
+
+        if (trainer := self._trainer) is None:
+            return
+
+        from ..trainer import Trainer
+
+        if not isinstance(trainer, Trainer):
+            return
+
+        trainer.debug = value
 
     @torch.jit.unused
     def breakpoint(self, rank_zero_only: bool = True):
@@ -146,7 +159,7 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         return object_list
 
     def barrier(self, name: str | None = None):
-        self.trainer.strategy.barrier(name=name)
+        return self.trainer.strategy.barrier(name=name)
 
     def reduce(
         self,
@@ -170,7 +183,7 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
     @override
     def __repr__(self):
         parts: list[str] = []
-        parts.append(f"config={self.hparams.concise_repr()}")
+        parts.append(f"hparams={repr(self.hparams)}")
         parts.append(f"device={self.device}")
         if self.debug:
             parts.append("debug=True")
@@ -178,85 +191,46 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         parts_str = ", ".join(parts)
         return f"{self.__class__.__name__}({parts_str})"
 
-    @classmethod
-    def _validate_class_for_ckpt_loading(cls):
-        # Make sure that the `__init__` method takes a single argument, `hparams`.
-        if (init_fn := getattr(cls, "__init__", None)) is None:
-            return
+    @property
+    @override
+    def hparams(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(THparams, super().hparams)
 
-        if not inspect.isfunction(init_fn):
-            raise TypeError(f"__init__ must be a function: {init_fn}")
+    @property
+    @override
+    def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
+        hparams = cast(THparams, super().hparams_initial)
+        hparams_dict = {"model": hparams.model_dump(mode="json")}
+        if (trainer := self._trainer) is not None:
+            from ..trainer import Trainer
 
-        parameters = dict(inspect.signature(init_fn).parameters)
-        # Remove the "self" parameter.
-        _ = parameters.pop("self", None)
-        if len(parameters) != 1:
-            raise TypeError(
-                f"__init__ must take a single argument, got {len(parameters)}: {init_fn}"
-            )
+            if isinstance(trainer, Trainer):
+                hparams_dict["trainer"] = trainer.hparams.model_dump(mode="json")
 
-        if "hparams" not in parameters:
-            raise TypeError(
-                f"__init__'s argument must be named 'hparams', got {parameters}"
-            )
+        return cast(Never, hparams_dict)
 
-    hparams: THparams # pyright: ignore[reportIncompatibleMethodOverride]
-    hparams_initial: THparams # pyright: ignore[reportIncompatibleMethodOverride]
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
 
     @classmethod
     @abstractmethod
-    def config_cls(cls) -> type[THparams]: ...
-
-    @classmethod
-    def load_checkpoint(
-        cls,
-        checkpoint_path: _PATH | IO,
-        hparams: THparams | MutableMapping[str, Any] | None = None,
-        map_location: _MAP_LOCATION_TYPE = None,
-        strict: bool = True,
-    ) -> Self:
-        if strict:
-            cls._validate_class_for_ckpt_loading()
-
-        kwargs: dict[str, Any] = {}
-        if hparams is not None:
-            kwargs["hparams"] = hparams
-
-        return super().load_from_checkpoint(
-            checkpoint_path,
-            map_location=map_location,
-            hparams_file=None,
-            strict=strict,
-            **kwargs,
-        )
-
-    def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
-        """
-        Override this method to update the hparams dictionary before it is used to create the hparams object.
-        Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
-        """
-        return hparams
-
-    def pre_init_update_hparams(self, hparams: THparams):
-        """
-        Override this method to update the hparams object before it is used to create the hparams_initial object.
-        """
-        return hparams
+    def hparams_cls(cls) -> type[THparams]: ...
 
     @override
-    def __init__(self, hparams: THparams | MutableMapping[str, Any]):
-        if not isinstance(hparams, BaseConfig):
-            if not isinstance(hparams, MutableMapping):
-                raise TypeError(
-                    f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
-                )
-
-            hparams = self.pre_init_update_hparams_dict(hparams)
-            hparams = self.config_cls().model_validate(hparams)
-        hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
-        hparams = self.pre_init_update_hparams(hparams)
-
+    def __init__(self, hparams: THparams | Mapping[str, Any]):
         super().__init__()
+
+        # Validate and save hyperparameters
+        hparams_cls = self.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise TypeError(
+                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
+            )
+        hparams = hparams.model_deep_validate()
         self.save_hyperparameters(hparams)
 
     def zero_loss(self):
@@ -267,260 +241,3 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
         loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
         loss = cast(torch.Tensor, loss)
         return loss
-
-    if TYPE_CHECKING:
-
-        @override
-        def training_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> Any:
-            r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
-            logger.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                - :class:`~torch.Tensor` - The loss tensor
-                - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
-                  automatic optimization.
-                - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
-                  multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
-                  the loss is not required.
-
-            In this step you'd normally do the forward pass and calculate the loss for a batch.
-            You can also do fancier things like multiple forward passes or something model specific.
-
-            Example::
-
-                def training_step(self, batch, batch_idx):
-                    x, y, z = batch
-                    out = self.encoder(x)
-                    loss = self.loss(out, x)
-                    return loss
-
-            To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
-
-            .. code-block:: python
-
-                def __init__(self):
-                    super().__init__()
-                    self.automatic_optimization = False
-
-
-                # Multiple optimizers (e.g.: GANs)
-                def training_step(self, batch, batch_idx):
-                    opt1, opt2 = self.optimizers()
-
-                    # do training_step with encoder
-                    ...
-                    opt1.step()
-                    # do training_step with decoder
-                    ...
-                    opt2.step()
-
-            Note:
-                When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
-                normalized by ``accumulate_grad_batches`` internally.
-
-            """
-            raise NotImplementedError
-
-        @override
-        def validation_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> STEP_OUTPUT:
-            r"""Operates on a single batch of data from the validation set. In this step you'd might generate examples or
-            calculate anything of interest like accuracy.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                - :class:`~torch.Tensor` - The loss tensor
-                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
-                - ``None`` - Skip to the next batch.
-
-            .. code-block:: python
-
-                # if you have one val dataloader:
-                def validation_step(self, batch, batch_idx): ...
-
-
-                # if you have multiple val dataloaders:
-                def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
-
-            Examples::
-
-                # CASE 1: A single validation dataset
-                def validation_step(self, batch, batch_idx):
-                    x, y = batch
-
-                    # implement your own
-                    out = self(x)
-                    loss = self.loss(out, y)
-
-                    # log 6 example images
-                    # or generated text... or whatever
-                    sample_imgs = x[:6]
-                    grid = torchvision.utils.make_grid(sample_imgs)
-                    self.logger.experiment.add_image('example_images', grid, 0)
-
-                    # calculate acc
-                    labels_hat = torch.argmax(out, dim=1)
-                    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
-
-                    # log the outputs!
-                    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
-
-            If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
-            setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
-
-            .. code-block:: python
-
-                # CASE 2: multiple validation dataloaders
-                def validation_step(self, batch, batch_idx, dataloader_idx=0):
-                    # dataloader_idx tells you which dataset this is.
-                    ...
-
-            Note:
-                If you don't need to validate you don't need to implement this method.
-
-            Note:
-                When the :meth:`validation_step` is called, the model has been put in eval mode
-                and PyTorch gradients have been disabled. At the end of validation,
-                the model goes back to training mode and gradients are enabled.
-
-            """
-            raise NotImplementedError
-
-        @override
-        def test_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> STEP_OUTPUT:
-            r"""Operates on a single batch of data from the test set. In this step you'd normally generate examples or
-            calculate anything of interest such as accuracy.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                - :class:`~torch.Tensor` - The loss tensor
-                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
-                - ``None`` - Skip to the next batch.
-
-            .. code-block:: python
-
-                # if you have one test dataloader:
-                def test_step(self, batch, batch_idx): ...
-
-
-                # if you have multiple test dataloaders:
-                def test_step(self, batch, batch_idx, dataloader_idx=0): ...
-
-            Examples::
-
-                # CASE 1: A single test dataset
-                def test_step(self, batch, batch_idx):
-                    x, y = batch
-
-                    # implement your own
-                    out = self(x)
-                    loss = self.loss(out, y)
-
-                    # log 6 example images
-                    # or generated text... or whatever
-                    sample_imgs = x[:6]
-                    grid = torchvision.utils.make_grid(sample_imgs)
-                    self.logger.experiment.add_image('example_images', grid, 0)
-
-                    # calculate acc
-                    labels_hat = torch.argmax(out, dim=1)
-                    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
-
-                    # log the outputs!
-                    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
-
-            If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. We recommend
-            setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
-
-            .. code-block:: python
-
-                # CASE 2: multiple test dataloaders
-                def test_step(self, batch, batch_idx, dataloader_idx=0):
-                    # dataloader_idx tells you which dataset this is.
-                    ...
-
-            Note:
-                If you don't need to test you don't need to implement this method.
-
-            Note:
-                When the :meth:`test_step` is called, the model has been put in eval mode and
-                PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
-                to training mode and gradients are enabled.
-
-            """
-            raise NotImplementedError
-
-        @override
-        def predict_step( # pyright: ignore[reportIncompatibleMethodOverride]
-            self,
-            batch: Any,
-            batch_idx: int,
-        ) -> STEP_OUTPUT:
-            """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
-            :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.
-
-            The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
-            to scale inference on multi-devices.
-
-            To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
-            callback to write the predictions to disk or database after each batch or on epoch end.
-
-            The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
-            based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
-            or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.
-
-            Args:
-                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
-                batch_idx: The index of this batch.
-                dataloader_idx: The index of the dataloader that produced this batch.
-                    (only if multiple dataloaders used)
-
-            Return:
-                Predicted output (optional).
-
-            Example ::
-
-                class MyModel(LightningModule):
-
-                    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-                        return self(batch)
-
-                dm = ...
-                model = MyModel()
-                trainer = Trainer(accelerator="gpu", devices=2)
-                predictions = trainer.predict(model, dm)
-
-            """
-            prediction = self(batch)
-            return {
-                "prediction": prediction,
-                "batch": batch,
-                "batch_idx": batch_idx,
-            }
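Taken together, the model/base.py hunks replace the old BaseConfig-typed config/config_cls API with plain nshconfig hyperparameters: hparams_cls() names the config class, __init__ accepts either a config instance or a Mapping and validates it, and config survives only as a deprecated alias for hparams. A minimal sketch of a subclass against the new API follows; every name below (MyConfig, MyModule, hidden_dim, the loss) is illustrative rather than part of nshtrainer, and the sketch assumes nshconfig.Config fields are declared pydantic-style, as the model_validate/model_dump calls above suggest.

import nshconfig as C
import torch.nn as nn
import torch.nn.functional as F

from nshtrainer.model import LightningModuleBase


class MyConfig(C.Config):
    # Hypothetical hyperparameter for illustration.
    hidden_dim: int = 128


class MyModule(LightningModuleBase[MyConfig]):
    @classmethod
    def hparams_cls(cls):
        # Replaces the removed `config_cls` classmethod.
        return MyConfig

    def __init__(self, hparams: MyConfig | dict):
        # The base __init__ validates a Mapping via MyConfig.model_validate()
        # and calls save_hyperparameters(); `self.hparams` is typed as MyConfig.
        super().__init__(hparams)
        self.head = nn.Linear(self.hparams.hidden_dim, 1)

    def training_step(self, batch, batch_idx):
        # The TYPE_CHECKING-only step stubs were removed, so steps are
        # overridden directly as in stock Lightning.
        x, y = batch
        return F.mse_loss(self.head(x), y)

The checkpoint helpers removed in the same hunk (load_checkpoint, pre_init_update_hparams_dict, pre_init_update_hparams) have no visible replacement here; presumably Lightning's standard load_from_checkpoint path applies.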
nshtrainer/model/mixins/callback.py CHANGED
@@ -2,16 +2,18 @@ from __future__ import annotations
 
 import logging
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, TypeAlias, cast, final
+from typing import Any, TypeAlias, cast
 
 from lightning.pytorch import Callback, LightningModule
 from typing_extensions import override
 
+from ..._callback import NTCallbackBase
 from ...util.typing_utils import mixin_base_type
 
 log = logging.getLogger(__name__)
 
-CallbackFn: TypeAlias = Callable[[], Callback | Iterable[Callback] | None]
+_Callback = Callback | NTCallbackBase
+CallbackFn: TypeAlias = Callable[[], _Callback | Iterable[_Callback] | None]
 
 
 class CallbackRegistrarModuleMixin:
@@ -23,7 +25,7 @@ class CallbackRegistrarModuleMixin:
 
     def register_callback(
        self,
-        callback: Callback | Iterable[Callback] | CallbackFn | None = None,
+        callback: _Callback | Iterable[_Callback] | CallbackFn | None = None,
    ):
        if not callable(callback):
            callback_ = cast(CallbackFn, lambda: callback)
@@ -34,8 +36,26 @@ class CallbackRegistrarModuleMixin:
 
 
 class CallbackModuleMixin(
-    CallbackRegistrarModuleMixin, mixin_base_type(LightningModule)
+    CallbackRegistrarModuleMixin,
+    mixin_base_type(LightningModule),
 ):
+    @property
+    def _nshtrainer_callbacks(self) -> list[CallbackFn]:
+        if not hasattr(self, "_private_nshtrainer_callbacks_list"):
+            self._private_nshtrainer_callbacks_list = []
+        return self._private_nshtrainer_callbacks_list
+
+    def register_callback(
+        self,
+        callback: _Callback | Iterable[_Callback] | CallbackFn | None = None,
+    ):
+        if not callable(callback):
+            callback_ = cast(CallbackFn, lambda: callback)
+        else:
+            callback_ = callback
+
+        self._nshtrainer_callbacks.append(callback_)
+
     def _gather_all_callbacks(self):
         modules: list[Any] = []
         if isinstance(self, CallbackRegistrarModuleMixin):
@@ -52,7 +72,6 @@ class CallbackModuleMixin(
         for module in modules:
             yield from module._nshtrainer_callbacks
 
-    @final
     @override
     def configure_callbacks(self):
         callbacks = super().configure_callbacks()
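The callback mixin above now stores registered callbacks in an explicit _nshtrainer_callbacks list, accepts NTCallbackBase instances alongside stock Lightning callbacks, and drops @final from configure_callbacks. A rough usage sketch, in the same hypothetical style as the earlier example (ExampleConfig/ExampleModule and the monitored metric name are illustrative; EarlyStopping is Lightning's stock callback):

import nshconfig as C
from lightning.pytorch.callbacks import EarlyStopping

from nshtrainer.model import LightningModuleBase


class ExampleConfig(C.Config):
    patience: int = 5  # hypothetical field


class ExampleModule(LightningModuleBase[ExampleConfig]):
    @classmethod
    def hparams_cls(cls):
        return ExampleConfig

    def __init__(self, hparams: ExampleConfig | dict):
        super().__init__(hparams)
        # A callback instance, an iterable of callbacks, or a zero-argument
        # factory may be registered; _gather_all_callbacks() yields them when
        # Lightning calls configure_callbacks().
        self.register_callback(
            lambda: EarlyStopping(monitor="val/loss", patience=self.hparams.patience)
        )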