nshtrainer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/model/base.py
@@ -0,0 +1,641 @@
+import getpass
+import inspect
+import os
+import platform
+import sys
+from abc import ABC, abstractmethod
+from collections.abc import Callable, MutableMapping
+from datetime import timedelta
+from logging import getLogger
+from pathlib import Path
+from typing import IO, TYPE_CHECKING, Any, Generic, cast
+
+import torch
+from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
+from lightning.pytorch import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from typing_extensions import Self, TypeVar, deprecated, override
+
+from .config import (
+    BaseConfig,
+    EnvironmentClassInformationConfig,
+    EnvironmentLinuxEnvironmentConfig,
+    EnvironmentLSFInformationConfig,
+    EnvironmentSLURMInformationConfig,
+)
+from .modules.callback import CallbackModuleMixin, CallbackRegistrarModuleMixin
+from .modules.debug import DebugModuleMixin
+from .modules.distributed import DistributedMixin
+from .modules.logger import LoggerLightningModuleMixin, LoggerModuleMixin
+from .modules.profiler import ProfilerMixin
+from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
+from .modules.shared_parameters import SharedParametersModuleMixin
+
+log = getLogger(__name__)
+
+THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
+
+
+class Base(DebugModuleMixin, Generic[THparams]):
+    @deprecated("Use `ll.nn.MLP` instead.")
+    def mlp(self, *args, **kwargs):
+        from ..nn.mlp import MLP
+
+        return MLP(*args, **kwargs)
+
+    @torch.jit.unused
+    @property
+    def config(self) -> THparams:
+        return self.hparams
+
+    @torch.jit.unused
+    @property
+    def C(self) -> THparams:
+        return self.hparams
+
+    @property
+    def debug(self) -> bool:
+        if torch.jit.is_scripting():
+            return False
+        return self.config.debug
+
+    @property
+    def dev(self) -> bool:
+        if torch.jit.is_scripting():
+            return False
+        return self.config.debug
+
+    @override
+    def __init__(self, hparams: THparams):
+        super().__init__()
+
+        if not hasattr(self, "hparams"):
+            self.hparams = hparams
+
+
+class DebugFlagCallback(Callback):
+    """
+    Sets the debug flag to true in the following circumstances:
+    - fast_dev_run is enabled
+    - sanity check is running
+    """
+
+    @override
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+        if not getattr(trainer, "fast_dev_run", False):
+            return
+
+        hparams = cast(BaseConfig, pl_module.hparams)
+        if not hparams.debug:
+            log.critical("Fast dev run detected, setting debug flag to True.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast(BaseConfig, pl_module.hparams)
+        self._debug = hparams.debug
+        if not self._debug:
+            log.critical("Enabling debug flag during sanity check routine.")
+            hparams.debug = True
+
+    @override
+    def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
+        hparams = cast(BaseConfig, pl_module.hparams)
+        if not self._debug:
+            log.critical("Sanity check routine complete, disabling debug flag.")
+            hparams.debug = self._debug
+
+
+def _cls_info(cls: type):
+    name = cls.__name__
+    module = cls.__module__
+    full_name = f"{cls.__module__}.{cls.__qualname__}"
+
+    file_path = inspect.getfile(cls)
+    source_file_path = inspect.getsourcefile(cls)
+    return EnvironmentClassInformationConfig(
+        name=name,
+        module=module,
+        full_name=full_name,
+        file_path=Path(file_path),
+        source_file_path=Path(source_file_path) if source_file_path else None,
+    )
+
+
+T = TypeVar("T")
+
+
+def _psutil():
+    import psutil
+
+    return psutil
+
+
+def _try_get(fn: Callable[[], T]) -> T | None:
+    try:
+        return fn()
+    except Exception as e:
+        log.warning(f"Failed to get value: {e}")
+        return None
+
+
+class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
+    ProfilerMixin,
+    RLPSanityCheckModuleMixin,
+    LoggerLightningModuleMixin,
+    SharedParametersModuleMixin,
+    DistributedMixin,
+    CallbackModuleMixin,
+    Base[THparams],
+    LightningModule,
+    ABC,
+    Generic[THparams],
+):
+    # Our own custom __repr__ method.
+    # Torch's __repr__ method is too verbose and doesn't provide any useful information.
+    @override
+    def __repr__(self):
+        parts: list[str] = []
+        parts.append(f"config={self.hparams.concise_repr()}")
+        parts.append(f"device={self.device}")
+        if self.debug:
+            parts.append("debug=True")
+
+        parts_str = ", ".join(parts)
+        return f"{self.__class__.__name__}({parts_str})"
+
+    @classmethod
+    def _validate_class_for_ckpt_loading(cls):
+        # Make sure that the `__init__` method takes a single argument, `hparams`.
+        if (init_fn := getattr(cls, "__init__", None)) is None:
+            return
+
+        if not inspect.isfunction(init_fn):
+            raise TypeError(f"__init__ must be a function: {init_fn}")
+
+        parameters = dict(inspect.signature(init_fn).parameters)
+        # Remove the "self" parameter.
+        _ = parameters.pop("self", None)
+        if len(parameters) != 1:
+            raise TypeError(
+                f"__init__ must take a single argument, got {len(parameters)}: {init_fn}"
+            )
+
+        if "hparams" not in parameters:
+            raise TypeError(
+                f"__init__'s argument must be named 'hparams', got {parameters}"
+            )
+
+    hparams: THparams  # pyright: ignore[reportIncompatibleMethodOverride]
+    hparams_initial: THparams  # pyright: ignore[reportIncompatibleMethodOverride]
+
+    @classmethod
+    @abstractmethod
+    def config_cls(cls) -> type[THparams]: ...
+
+    @classmethod
+    def load_checkpoint(
+        cls,
+        checkpoint_path: _PATH | IO,
+        hparams: THparams | MutableMapping[str, Any] | None = None,
+        map_location: _MAP_LOCATION_TYPE = None,
+        strict: bool = True,
+    ) -> Self:
+        if strict:
+            cls._validate_class_for_ckpt_loading()
+
+        kwargs: dict[str, Any] = {}
+        if hparams is not None:
+            kwargs["hparams"] = hparams
+
+        return super().load_from_checkpoint(
+            checkpoint_path,
+            map_location=map_location,
+            hparams_file=None,
+            strict=strict,
+            **kwargs,
+        )
+
+    @classmethod
+    def _update_environment(cls, hparams: THparams):
+        hparams.environment.cwd = Path(os.getcwd())
+        hparams.environment.python_executable = Path(sys.executable)
+        hparams.environment.python_path = [Path(path) for path in sys.path]
+        hparams.environment.python_version = sys.version
+        hparams.environment.config = _cls_info(cls.config_cls())
+        hparams.environment.model = _cls_info(cls)
+        hparams.environment.slurm = (
+            EnvironmentSLURMInformationConfig.from_current_environment()
+        )
+        hparams.environment.lsf = (
+            EnvironmentLSFInformationConfig.from_current_environment()
+        )
+        hparams.environment.base_dir = hparams.directory.resolve_run_root_directory(
+            hparams.id
+        )
+        hparams.environment.log_dir = hparams.directory.resolve_subdirectory(
+            hparams.id, "log"
+        )
+        hparams.environment.checkpoint_dir = hparams.directory.resolve_subdirectory(
+            hparams.id, "checkpoint"
+        )
+        hparams.environment.stdio_dir = hparams.directory.resolve_subdirectory(
+            hparams.id, "stdio"
+        )
+        hparams.environment.seed = (
+            int(seed_str) if (seed_str := os.environ.get("PL_GLOBAL_SEED")) else None
+        )
+        hparams.environment.seed_workers = (
+            bool(int(seed_everything))
+            if (seed_everything := os.environ.get("PL_SEED_WORKERS"))
+            else None
+        )
+        hparams.environment.linux = EnvironmentLinuxEnvironmentConfig(
+            user=_try_get(lambda: getpass.getuser()),
+            hostname=_try_get(lambda: platform.node()),
+            system=_try_get(lambda: platform.system()),
+            release=_try_get(lambda: platform.release()),
+            version=_try_get(lambda: platform.version()),
+            machine=_try_get(lambda: platform.machine()),
+            processor=_try_get(lambda: platform.processor()),
+            cpu_count=_try_get(lambda: os.cpu_count()),
+            memory=_try_get(lambda: _psutil().virtual_memory().total),
+            uptime=_try_get(lambda: timedelta(seconds=_psutil().boot_time())),
+            boot_time=_try_get(lambda: _psutil().boot_time()),
+            load_avg=_try_get(lambda: os.getloadavg()),
+        )
+
+    def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
+        """
+        Override this method to update the hparams dictionary before it is used to create the hparams object.
+        Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
+        """
+        return hparams
+
+    def pre_init_update_hparams(self, hparams: THparams):
+        """
+        Override this method to update the hparams object before it is used to create the hparams_initial object.
+        """
+        return hparams
+
+    @override
+    def __init__(self, hparams: THparams | MutableMapping[str, Any]):
+        if not isinstance(hparams, BaseConfig):
+            if not isinstance(hparams, MutableMapping):
+                raise TypeError(
+                    f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
+                )
+
+            hparams = self.pre_init_update_hparams_dict(hparams)
+            hparams = self.config_cls().model_validate(hparams)
+        self._update_environment(hparams)
+        hparams = self.pre_init_update_hparams(hparams)
+        super().__init__(hparams)
+
+        self.save_hyperparameters(hparams)
+
+        self.register_callback(lambda: DebugFlagCallback())
+
+    def zero_loss(self):
+        """
+        Returns a loss tensor with the value 0.
+        It multiplies each weight by 0 and returns the sum, so we don't run into issues with unused parameters in DDP.
+        """
+        loss = sum((0.0 * v).sum() for v in self.parameters() if v.requires_grad)
+        loss = cast(torch.Tensor, loss)
+        return loss
+
+    @property
+    def datamodule(self):
+        datamodule = getattr(self.trainer, "datamodule", None)
+        if datamodule is None:
+            return None
+
+        if not isinstance(datamodule, LightningDataModuleBase):
+            raise TypeError(
+                f"datamodule must be a LightningDataModuleBase: {type(datamodule)}"
+            )
+
+        datamodule = cast(LightningDataModuleBase[THparams], datamodule)
+        return datamodule
+
+    if TYPE_CHECKING:
+
+        @override
+        def training_step(  # pyright: ignore[reportIncompatibleMethodOverride]
+            self,
+            batch: Any,
+            batch_idx: int,
+        ) -> Any:
+            r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
+            logger.
+
+            Args:
+                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
+                batch_idx: The index of this batch.
+                dataloader_idx: The index of the dataloader that produced this batch.
+                    (only if multiple dataloaders used)
+
+            Return:
+                - :class:`~torch.Tensor` - The loss tensor
+                - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of
+                  automatic optimization.
+                - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for
+                  multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning
+                  the loss is not required.
+
+            In this step you'd normally do the forward pass and calculate the loss for a batch.
+            You can also do fancier things like multiple forward passes or something model specific.
+
+            Example::
+
+                def training_step(self, batch, batch_idx):
+                    x, y, z = batch
+                    out = self.encoder(x)
+                    loss = self.loss(out, x)
+                    return loss
+
+            To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
+
+            .. code-block:: python
+
+                def __init__(self):
+                    super().__init__()
+                    self.automatic_optimization = False
+
+
+                # Multiple optimizers (e.g.: GANs)
+                def training_step(self, batch, batch_idx):
+                    opt1, opt2 = self.optimizers()
+
+                    # do training_step with encoder
+                    ...
+                    opt1.step()
+                    # do training_step with decoder
+                    ...
+                    opt2.step()
+
+            Note:
+                When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
+                normalized by ``accumulate_grad_batches`` internally.
+
+            """
+            raise NotImplementedError
+
+        @override
+        def validation_step(  # pyright: ignore[reportIncompatibleMethodOverride]
+            self,
+            batch: Any,
+            batch_idx: int,
+        ) -> STEP_OUTPUT:
+            r"""Operates on a single batch of data from the validation set. In this step you might generate examples or
+            calculate anything of interest like accuracy.
+
+            Args:
+                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
+                batch_idx: The index of this batch.
+                dataloader_idx: The index of the dataloader that produced this batch.
+                    (only if multiple dataloaders used)
+
+            Return:
+                - :class:`~torch.Tensor` - The loss tensor
+                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
+                - ``None`` - Skip to the next batch.
+
+            .. code-block:: python
+
+                # if you have one val dataloader:
+                def validation_step(self, batch, batch_idx): ...
+
+
+                # if you have multiple val dataloaders:
+                def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
+
+            Examples::
+
+                # CASE 1: A single validation dataset
+                def validation_step(self, batch, batch_idx):
+                    x, y = batch
+
+                    # implement your own
+                    out = self(x)
+                    loss = self.loss(out, y)
+
+                    # log 6 example images
+                    # or generated text... or whatever
+                    sample_imgs = x[:6]
+                    grid = torchvision.utils.make_grid(sample_imgs)
+                    self.logger.experiment.add_image('example_images', grid, 0)
+
+                    # calculate acc
+                    labels_hat = torch.argmax(out, dim=1)
+                    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+
+                    # log the outputs!
+                    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
+
+            If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument. We recommend
+            setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
+
+            .. code-block:: python
+
+                # CASE 2: multiple validation dataloaders
+                def validation_step(self, batch, batch_idx, dataloader_idx=0):
+                    # dataloader_idx tells you which dataset this is.
+                    ...
+
+            Note:
+                If you don't need to validate you don't need to implement this method.
+
+            Note:
+                When the :meth:`validation_step` is called, the model has been put in eval mode
+                and PyTorch gradients have been disabled. At the end of validation,
+                the model goes back to training mode and gradients are enabled.
+
+            """
+            raise NotImplementedError
+
+        @override
+        def test_step(  # pyright: ignore[reportIncompatibleMethodOverride]
+            self,
+            batch: Any,
+            batch_idx: int,
+        ) -> STEP_OUTPUT:
+            r"""Operates on a single batch of data from the test set. In this step you'd normally generate examples or
+            calculate anything of interest such as accuracy.
+
+            Args:
+                batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
+                batch_idx: The index of this batch.
+                dataloader_idx: The index of the dataloader that produced this batch.
+                    (only if multiple dataloaders used)
+
+            Return:
+                - :class:`~torch.Tensor` - The loss tensor
+                - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
+                - ``None`` - Skip to the next batch.
+
+            .. code-block:: python
+
+                # if you have one test dataloader:
+                def test_step(self, batch, batch_idx): ...
+
+
+                # if you have multiple test dataloaders:
+                def test_step(self, batch, batch_idx, dataloader_idx=0): ...
+
+            Examples::
+
+                # CASE 1: A single test dataset
+                def test_step(self, batch, batch_idx):
+                    x, y = batch
+
+                    # implement your own
+                    out = self(x)
+                    loss = self.loss(out, y)
+
+                    # log 6 example images
+                    # or generated text... or whatever
+                    sample_imgs = x[:6]
+                    grid = torchvision.utils.make_grid(sample_imgs)
+                    self.logger.experiment.add_image('example_images', grid, 0)
+
+                    # calculate acc
+                    labels_hat = torch.argmax(out, dim=1)
+                    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+
+                    # log the outputs!
+                    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
+
+            If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument. We recommend
+            setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
+
+            .. code-block:: python
+
+                # CASE 2: multiple test dataloaders
+                def test_step(self, batch, batch_idx, dataloader_idx=0):
+                    # dataloader_idx tells you which dataset this is.
+                    ...
+
+            Note:
+                If you don't need to test you don't need to implement this method.
+
+            Note:
+                When the :meth:`test_step` is called, the model has been put in eval mode and
+                PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
+                to training mode and gradients are enabled.
+
+            """
+            raise NotImplementedError
+
+    @override
+    def predict_step(  # pyright: ignore[reportIncompatibleMethodOverride]
+        self,
+        batch: Any,
+        batch_idx: int,
+    ) -> STEP_OUTPUT:
+        """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
+        :meth:`~lightning.pytorch.core.LightningModule.forward`. Override to add any processing logic.
+
+        The :meth:`~lightning.pytorch.core.LightningModule.predict_step` is used
+        to scale inference on multi-devices.
+
+        To prevent an OOM error, it is possible to use :class:`~lightning.pytorch.callbacks.BasePredictionWriter`
+        callback to write the predictions to disk or database after each batch or on epoch end.
+
+        The :class:`~lightning.pytorch.callbacks.BasePredictionWriter` should be used while using a spawn
+        based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
+        or training on 8 TPU cores with ``Trainer(accelerator="tpu", devices=8)`` as predictions won't be returned.
+
+        Args:
+            batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
+            batch_idx: The index of this batch.
+            dataloader_idx: The index of the dataloader that produced this batch.
+                (only if multiple dataloaders used)
+
+        Return:
+            Predicted output (optional).
+
+        Example ::
+
+            class MyModel(LightningModule):
+
+                def predict_step(self, batch, batch_idx, dataloader_idx=0):
+                    return self(batch)
+
+            dm = ...
+            model = MyModel()
+            trainer = Trainer(accelerator="gpu", devices=2)
+            predictions = trainer.predict(model, dm)
+
+        """
+        prediction = self(batch)
+        return {
+            "prediction": prediction,
+            "batch": batch,
+            "batch_idx": batch_idx,
+        }
+
+
+class LightningDataModuleBase(
+    LoggerModuleMixin,
+    CallbackRegistrarModuleMixin,
+    Base[THparams],
+    LightningDataModule,
+    ABC,
+    Generic[THparams],
+):
+    hparams: THparams  # pyright: ignore[reportIncompatibleMethodOverride]
+    hparams_initial: THparams  # pyright: ignore[reportIncompatibleMethodOverride]
+
+    def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
+        """
+        Override this method to update the hparams dictionary before it is used to create the hparams object.
+        Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
+        """
+        return hparams
+
+    def pre_init_update_hparams(self, hparams: THparams):
+        """
+        Override this method to update the hparams object before it is used to create the hparams_initial object.
+        """
+        return hparams
+
+    @classmethod
+    def _update_environment(cls, hparams: THparams):
+        hparams.environment.data = _cls_info(cls)
+
+    @override
+    def __init__(self, hparams: THparams):
+        if not isinstance(hparams, BaseConfig):
+            if not isinstance(hparams, MutableMapping):
+                raise TypeError(
+                    f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
+                )
+
+            hparams = self.pre_init_update_hparams_dict(hparams)
+            hparams = self.config_cls().from_dict(hparams)
+        self._update_environment(hparams)
+        hparams = self.pre_init_update_hparams(hparams)
+        super().__init__(hparams)
+
+        self.save_hyperparameters(hparams)
+
+    @property
+    def lightning_module(self):
+        if not self.trainer:
+            raise ValueError("Trainer has not been set.")
+
+        module = self.trainer.lightning_module
+        if not isinstance(module, LightningModuleBase):
+            raise ValueError(
+                f"Trainer's lightning_module is not a LightningModuleBase: {type(module)}"
+            )
+
+        module = cast(LightningModuleBase[THparams], module)
+        return module
+
+    @property
+    def device(self):
+        return self.lightning_module.device
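
Taken together, base.py defines the subclassing contract: a LightningModuleBase subclass must implement config_cls() and accept a single argument named hparams in __init__ (load_checkpoint enforces this when strict=True), and self.config / self.C give access to the validated hyperparameters. The sketch below illustrates that contract under stated assumptions: MyConfig, MyModel, and the hidden_dim field are hypothetical names not present in nshtrainer, and the BaseConfig subclassing style (pydantic-like field declaration) is inferred from the model_validate call in base.py.

    # Hypothetical usage sketch, not part of the nshtrainer package.
    import torch.nn.functional as F
    from torch import nn

    from nshtrainer.model.base import LightningModuleBase
    from nshtrainer.model.config import BaseConfig


    class MyConfig(BaseConfig):
        # Assumed field; real BaseConfig subclasses would add their own options here.
        hidden_dim: int = 128


    class MyModel(LightningModuleBase[MyConfig]):
        @classmethod
        def config_cls(cls):
            # Required by the abstract classmethod on LightningModuleBase.
            return MyConfig

        def __init__(self, hparams):
            # Exactly one argument named `hparams`, as checked by
            # _validate_class_for_ckpt_loading when strict=True.
            super().__init__(hparams)
            self.layer = nn.Linear(self.config.hidden_dim, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.mse_loss(self.layer(x), y)


    # Checkpoints are restored through the validated classmethod:
    # model = MyModel.load_checkpoint("path/to/checkpoint.ckpt", strict=True)
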