nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
Files changed (34)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +12 -1
  4. nshtrainer/callbacks/debug_flag.py +72 -0
  5. nshtrainer/callbacks/directory_setup.py +85 -0
  6. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  7. nshtrainer/callbacks/shared_parameters.py +87 -0
  8. nshtrainer/config.py +67 -0
  9. nshtrainer/ll/__init__.py +5 -4
  10. nshtrainer/ll/model.py +7 -0
  11. nshtrainer/loggers/wandb.py +1 -1
  12. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  13. nshtrainer/model/__init__.py +0 -21
  14. nshtrainer/model/base.py +124 -67
  15. nshtrainer/model/config.py +7 -1025
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +787 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/callback.py +0 -206
  28. nshtrainer/model/modules/debug.py +0 -42
  29. nshtrainer/model/modules/distributed.py +0 -70
  30. nshtrainer/model/modules/profiler.py +0 -24
  31. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  32. nshtrainer/model/modules/shared_parameters.py +0 -72
  33. /nshtrainer/{config → util/config}/duration.py +0 -0
  34. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
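Beyond the new trainer/_config.py shown in the hunk below, the listing adds a profiler package and several callback modules, and moves nshtrainer/config under nshtrainer/util/config. A minimal import sketch for 0.32.0, assuming the package is installed and that these modules are importable at the paths shown in the listing (the paths, not the public API, are what the listing confirms):

# Hypothetical import updates implied by the file listing above (not taken from package docs).
import nshtrainer.profiler                     # new profiler package
import nshtrainer.util.config                  # previously nshtrainer.config
from nshtrainer.callbacks.debug_flag import DebugFlagCallbackConfig      # new callback config
from nshtrainer.callbacks.rlp_sanity_checks import RLPSanityChecksConfig
from nshtrainer.trainer._config import TrainerConfig                     # new trainer config module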
nshtrainer/trainer/_config.py (new file)
@@ -0,0 +1,787 @@
+ import logging
+ from collections.abc import Iterable, Sequence
+ from datetime import timedelta
+ from pathlib import Path
+ from typing import (
+     TYPE_CHECKING,
+     Annotated,
+     Any,
+     Literal,
+     Protocol,
+     TypeAlias,
+     runtime_checkable,
+ )
+
+ import nshconfig as C
+ from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+ from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
+ from lightning.pytorch.accelerators import Accelerator
+ from lightning.pytorch.callbacks.callback import Callback
+ from lightning.pytorch.loggers import Logger
+ from lightning.pytorch.plugins import _PLUGIN_INPUT
+ from lightning.pytorch.plugins.layer_sync import LayerSync
+ from lightning.pytorch.plugins.precision.precision import Precision
+ from lightning.pytorch.profilers import Profiler
+ from lightning.pytorch.strategies.strategy import Strategy
+ from typing_extensions import TypedDict, TypeVar, override
+
+ from .._checkpoint.loader import CheckpointLoadingConfig
+ from .._hf_hub import HuggingFaceHubConfig
+ from ..callbacks import (
+     BestCheckpointCallbackConfig,
+     CallbackConfig,
+     EarlyStoppingConfig,
+     LastCheckpointCallbackConfig,
+     OnExceptionCheckpointCallbackConfig,
+ )
+ from ..callbacks.base import CallbackConfigBase
+ from ..callbacks.debug_flag import DebugFlagCallbackConfig
+ from ..callbacks.rlp_sanity_checks import RLPSanityChecksConfig
+ from ..callbacks.shared_parameters import SharedParametersConfig
+ from ..loggers import (
+     CSVLoggerConfig,
+     LoggerConfig,
+     TensorboardLoggerConfig,
+     WandbLoggerConfig,
+ )
+ from ..profiler import ProfilerConfig
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class LoggingConfig(CallbackConfigBase):
+     enabled: bool = True
+     """Enable experiment tracking."""
+
+     loggers: Sequence[LoggerConfig] = [
+         WandbLoggerConfig(),
+         CSVLoggerConfig(),
+         TensorboardLoggerConfig(),
+     ]
+     """Loggers to use for experiment tracking."""
+
+     log_lr: bool | Literal["step", "epoch"] = True
+     """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
+     log_epoch: bool = True
+     """If enabled, will log the fractional epoch number to the logger."""
+
+     actsave_logged_metrics: bool = False
+     """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
+
+     @property
+     def wandb(self):
+         return next(
+             (
+                 logger
+                 for logger in self.loggers
+                 if isinstance(logger, WandbLoggerConfig)
+             ),
+             None,
+         )
+
+     @property
+     def csv(self):
+         return next(
+             (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+             None,
+         )
+
+     @property
+     def tensorboard(self):
+         return next(
+             (
+                 logger
+                 for logger in self.loggers
+                 if isinstance(logger, TensorboardLoggerConfig)
+             ),
+             None,
+         )
+
+     def create_loggers(self, root_config: "BaseConfig"):
+         """
+         Constructs and returns a list of loggers based on the provided root configuration.
+
+         Args:
+             root_config (BaseConfig): The root configuration object.
+
+         Returns:
+             list[Logger]: A list of constructed loggers.
+         """
+         if not self.enabled:
+             return
+
+         for logger_config in sorted(
+             self.loggers,
+             key=lambda x: x.priority,
+             reverse=True,
+         ):
+             if not logger_config.enabled:
+                 continue
+             if (logger := logger_config.create_logger(root_config)) is None:
+                 continue
+             yield logger
+
+     @override
+     def create_callbacks(self, root_config):
+         if self.log_lr:
+             from lightning.pytorch.callbacks import LearningRateMonitor
+
+             logging_interval: str | None = None
+             if isinstance(self.log_lr, str):
+                 logging_interval = self.log_lr
+
+             yield LearningRateMonitor(logging_interval=logging_interval)
+
+         if self.log_epoch:
+             from ..callbacks.log_epoch import LogEpochCallback
+
+             yield LogEpochCallback()
+
+         for logger in self.loggers:
+             if not logger or not isinstance(logger, CallbackConfigBase):
+                 continue
+
+             yield from logger.create_callbacks(root_config)
+
+
+ class GradientClippingConfig(C.Config):
+     enabled: bool = True
+     """Enable gradient clipping."""
+     value: int | float
+     """Value to use for gradient clipping."""
+     algorithm: Literal["value", "norm"] = "norm"
+     """Norm type to use for gradient clipping."""
+
+
+ class OptimizationConfig(CallbackConfigBase):
+     log_grad_norm: bool | str | float = False
+     """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+     log_grad_norm_per_param: bool | str | float = False
+     """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+     log_param_norm: bool | str | float = False
+     """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+     log_param_norm_per_param: bool | str | float = False
+     """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+     gradient_clipping: GradientClippingConfig | None = None
+     """Gradient clipping configuration, or None to disable gradient clipping."""
+
+     @override
+     def create_callbacks(self, root_config):
+         from ..callbacks.norm_logging import NormLoggingConfig
+
+         yield from NormLoggingConfig(
+             log_grad_norm=self.log_grad_norm,
+             log_grad_norm_per_param=self.log_grad_norm_per_param,
+             log_param_norm=self.log_param_norm,
+             log_param_norm_per_param=self.log_param_norm_per_param,
+         ).create_callbacks(root_config)
+
+
+ TPlugin = TypeVar(
+     "TPlugin",
+     Precision,
+     ClusterEnvironment,
+     CheckpointIO,
+     LayerSync,
+     infer_variance=True,
+ )
+
+
+ @runtime_checkable
+ class PluginConfigProtocol(Protocol[TPlugin]):
+     def create_plugin(self) -> TPlugin: ...
+
+
+ @runtime_checkable
+ class AcceleratorConfigProtocol(Protocol):
+     def create_accelerator(self) -> Accelerator: ...
+
+
+ @runtime_checkable
+ class StrategyConfigProtocol(Protocol):
+     def create_strategy(self) -> Strategy: ...
+
+
+ AcceleratorLiteral: TypeAlias = Literal[
+     "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
+ ]
+
+ StrategyLiteral: TypeAlias = Literal[
+     "auto",
+     "ddp",
+     "ddp_find_unused_parameters_false",
+     "ddp_find_unused_parameters_true",
+     "ddp_spawn",
+     "ddp_spawn_find_unused_parameters_false",
+     "ddp_spawn_find_unused_parameters_true",
+     "ddp_fork",
+     "ddp_fork_find_unused_parameters_false",
+     "ddp_fork_find_unused_parameters_true",
+     "ddp_notebook",
+     "dp",
+     "deepspeed",
+     "deepspeed_stage_1",
+     "deepspeed_stage_1_offload",
+     "deepspeed_stage_2",
+     "deepspeed_stage_2_offload",
+     "deepspeed_stage_3",
+     "deepspeed_stage_3_offload",
+     "deepspeed_stage_3_offload_nvme",
+     "fsdp",
+     "fsdp_cpu_offload",
+     "single_xla",
+     "xla_fsdp",
+     "xla",
+     "single_tpu",
+ ]
+
+
+ class ReproducibilityConfig(C.Config):
+     deterministic: bool | Literal["warn"] | None = None
+     """
+     If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+     Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+     that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+     """
+
+
+ CheckpointCallbackConfig: TypeAlias = Annotated[
+     BestCheckpointCallbackConfig
+     | LastCheckpointCallbackConfig
+     | OnExceptionCheckpointCallbackConfig,
+     C.Field(discriminator="name"),
+ ]
+
+
+ class CheckpointSavingConfig(CallbackConfigBase):
+     enabled: bool = True
+     """Enable checkpoint saving."""
+
+     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
+         BestCheckpointCallbackConfig(),
+         LastCheckpointCallbackConfig(),
+         OnExceptionCheckpointCallbackConfig(),
+     ]
+     """Checkpoint callback configurations."""
+
+     def disable_(self):
+         self.enabled = False
+         return self
+
+     def should_save_checkpoints(self, root_config: "BaseConfig"):
+         if not self.enabled:
+             return False
+
+         if root_config.trainer.fast_dev_run:
+             return False
+
+         return True
+
+     @override
+     def create_callbacks(self, root_config: "BaseConfig"):
+         if not self.should_save_checkpoints(root_config):
+             return
+
+         for callback_config in self.checkpoint_callbacks:
+             yield from callback_config.create_callbacks(root_config)
+
+
+ class LightningTrainerKwargs(TypedDict, total=False):
+     accelerator: str | Accelerator
+     """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
+     as well as custom accelerator instances."""
+
+     strategy: str | Strategy
+     """Supports different training strategies with aliases as well custom strategies.
+     Default: ``"auto"``.
+     """
+
+     devices: list[int] | str | int
+     """The devices to use. Can be set to a positive number (int or str), a sequence of device indices
+     (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
+     automatic selection based on the chosen accelerator. Default: ``"auto"``.
+     """
+
+     num_nodes: int
+     """Number of GPU nodes for distributed training.
+     Default: ``1``.
+     """
+
+     precision: _PRECISION_INPUT | None
+     """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
+     16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
+     Can be used on CPU, GPU, TPUs, HPUs or IPUs.
+     Default: ``'32-true'``.
+     """
+
+     logger: Logger | Iterable[Logger] | bool | None
+     """Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
+     the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
+     ``False`` will disable logging. If multiple loggers are provided, local files
+     (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
+     Default: ``True``.
+     """
+
+     callbacks: list[Callback] | Callback | None
+     """Add a callback or list of callbacks.
+     Default: ``None``.
+     """
+
+     fast_dev_run: int | bool
+     """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
+     of train, val and test to find any bugs (ie: a sort of unit test).
+     Default: ``False``.
+     """
+
+     max_epochs: int | None
+     """Stop training once this number of epochs is reached. Disabled by default (None).
+     If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+     To enable infinite training, set ``max_epochs = -1``.
+     """
+
+     min_epochs: int | None
+     """Force training for at least these many epochs. Disabled by default (None).
+     """
+
+     max_steps: int
+     """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+     and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+     ``max_epochs`` to ``-1``.
+     """
+
+     min_steps: int | None
+     """Force training for at least these number of steps. Disabled by default (``None``).
+     """
+
+     max_time: str | timedelta | dict[str, int] | None
+     """Stop training after this amount of time has passed. Disabled by default (``None``).
+     The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
+     :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
+     :class:`datetime.timedelta`.
+     """
+
+     limit_train_batches: int | float | None
+     """How much of training dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     limit_val_batches: int | float | None
+     """How much of validation dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     limit_test_batches: int | float | None
+     """How much of test dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     limit_predict_batches: int | float | None
+     """How much of prediction dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     overfit_batches: int | float
+     """Overfit a fraction of training/validation data (float) or a set number of batches (int).
+     Default: ``0.0``.
+     """
+
+     val_check_interval: int | float | None
+     """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
+     after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
+     batches. An ``int`` value can only be higher than the number of training batches when
+     ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
+     across epochs or during iteration-based training.
+     Default: ``1.0``.
+     """
+
+     check_val_every_n_epoch: int | None
+     """Perform a validation loop every after every `N` training epochs. If ``None``,
+     validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+     to be an integer value.
+     Default: ``1``.
+     """
+
+     num_sanity_val_steps: int | None
+     """Sanity check runs n validation batches before starting the training routine.
+     Set it to `-1` to run all batches in all validation dataloaders.
+     Default: ``2``.
+     """
+
+     log_every_n_steps: int | None
+     """How often to log within steps.
+     Default: ``50``.
+     """
+
+     enable_checkpointing: bool | None
+     """If ``True``, enable checkpointing.
+     It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
+     Default: ``True``.
+     """
+
+     enable_progress_bar: bool | None
+     """Whether to enable to progress bar by default.
+     Default: ``True``.
+     """
+
+     enable_model_summary: bool | None
+     """Whether to enable model summarization by default.
+     Default: ``True``.
+     """
+
+     accumulate_grad_batches: int
+     """Accumulates gradients over k batches before stepping the optimizer.
+     Default: 1.
+     """
+
+     gradient_clip_val: int | float | None
+     """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
+     gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
+     Default: ``None``.
+     """
+
+     gradient_clip_algorithm: str | None
+     """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
+     to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
+     be set to ``"norm"``.
+     """
+
+     deterministic: bool | Literal["warn"] | None
+     """If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+     Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+     that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+     """
+
+     benchmark: bool | None
+     """The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
+     The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
+     (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
+     is set to ``True``, this will default to ``False``. Override to manually set a different value.
+     Default: ``None``.
+     """
+
+     inference_mode: bool
+     """Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
+     evaluation (``validate``/``test``/``predict``).
+     """
+
+     use_distributed_sampler: bool
+     """Whether to wrap the DataLoader's sampler with
+     :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
+     strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
+     ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
+     ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
+     sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
+     we don't do this automatically.
+     """
+
+     profiler: Profiler | str | None
+     """To profile individual steps during training and assist in identifying bottlenecks.
+     Default: ``None``.
+     """
+
+     detect_anomaly: bool
+     """Enable anomaly detection for the autograd engine.
+     Default: ``False``.
+     """
+
+     barebones: bool
+     """Whether to run in "barebones mode", where all features that may impact raw speed are
+     disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
+     runs. The following features are deactivated:
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
+     :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
+     :meth:`~lightning.pytorch.core.LightningModule.log`,
+     :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
+     """
+
+     plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None
+     """Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
+     Default: ``None``.
+     """
+
+     sync_batchnorm: bool
+     """Synchronize batch norm layers between process groups/whole world.
+     Default: ``False``.
+     """
+
+     reload_dataloaders_every_n_epochs: int
+     """Set to a positive integer to reload dataloaders every n epochs.
+     Default: ``0``.
+     """
+
+     default_root_dir: Path | None
+     """Default path for logs and weights when no logger/ckpt_callback passed.
+     Default: ``os.getcwd()``.
+     Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
+     """
+
+
+ class SanityCheckingConfig(C.Config):
+     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
+     """
+     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+     - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
+     - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
+     Valid values are: "disable", "warn", "error".
+     """
+
+
+ class TrainerConfig(C.Config):
+     ckpt_path: Literal["none"] | str | Path | None = None
+     """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
+
+     checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
+     """Checkpoint loading configuration options.
+     `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
+     `"none"` will disable checkpoint loading.
+     """
+
+     checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
+     """Checkpoint saving configuration options."""
+
+     hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
+     """Hugging Face Hub configuration options."""
+
+     logging: LoggingConfig = LoggingConfig()
+     """Logging/experiment tracking (e.g., WandB) configuration options."""
+
+     optimizer: OptimizationConfig = OptimizationConfig()
+     """Optimization configuration options."""
+
+     reproducibility: ReproducibilityConfig = ReproducibilityConfig()
+     """Reproducibility configuration options."""
+
+     reduce_lr_on_plateau_sanity_checking: RLPSanityChecksConfig | None = (
+         RLPSanityChecksConfig()
+     )
+     """
+     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+     - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
+     - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
+     """
+
+     early_stopping: EarlyStoppingConfig | None = None
+     """Early stopping configuration options."""
+
+     profiler: ProfilerConfig | None = None
+     """
+     To profile individual steps during training and assist in identifying bottlenecks.
+     Default: ``None``.
+     """
+
+     callbacks: list[CallbackConfig] = []
+     """Callbacks to use during training."""
+
+     detect_anomaly: bool | None = None
+     """Enable anomaly detection for the autograd engine.
+     Default: ``False``.
+     """
+
+     plugins: list[PluginConfigProtocol] | None = None
+     """
+     Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
+     Default: ``None``.
+     """
+
+     auto_determine_num_nodes: bool = True
+     """
+     If enabled, will automatically determine the number of nodes for distributed training.
+
+     This will only work on:
+     - SLURM clusters
+     - LSF clusters
+     """
+
+     fast_dev_run: int | bool = False
+     """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
+     of train, val and test to find any bugs (ie: a sort of unit test).
+     Default: ``False``.
+     """
+
+     precision: (
+         Literal[
+             "64-true",
+             "32-true",
+             "fp16-mixed",
+             "bf16-mixed",
+             "16-mixed-auto",
+         ]
+         | None
+     ) = None
+     """
+     Training precision. Can be one of:
+     - "64-true": Double precision (64-bit).
+     - "32-true": Full precision (32-bit).
+     - "fp16-mixed": Float16 mixed precision.
+     - "bf16-mixed": BFloat16 mixed precision.
+     - "16-mixed-auto": Automatic 16-bit: Uses bfloat16 if available, otherwise float16.
+     """
+
+     max_epochs: int | None = None
+     """Stop training once this number of epochs is reached. Disabled by default (None).
+     If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+     To enable infinite training, set ``max_epochs = -1``.
+     """
+
+     min_epochs: int | None = None
+     """Force training for at least these many epochs. Disabled by default (None).
+     """
+
+     max_steps: int = -1
+     """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+     and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+     ``max_epochs`` to ``-1``.
+     """
+
+     min_steps: int | None = None
+     """Force training for at least these number of steps. Disabled by default (``None``).
+     """
+
+     max_time: str | timedelta | dict[str, int] | None = None
+     """Stop training after this amount of time has passed. Disabled by default (``None``).
+     The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
+     :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
+     :class:`datetime.timedelta`.
+     """
+
+     limit_train_batches: int | float | None = None
+     """How much of training dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     limit_val_batches: int | float | None = None
+     """How much of validation dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     limit_test_batches: int | float | None = None
+     """How much of test dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     limit_predict_batches: int | float | None = None
+     """How much of prediction dataset to check (float = fraction, int = num_batches).
+     Default: ``1.0``.
+     """
+
+     overfit_batches: int | float = 0.0
+     """Overfit a fraction of training/validation data (float) or a set number of batches (int).
+     Default: ``0.0``.
+     """
+
+     val_check_interval: int | float | None = None
+     """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
+     after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
+     batches. An ``int`` value can only be higher than the number of training batches when
+     ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
+     across epochs or during iteration-based training.
+     Default: ``1.0``.
+     """
+
+     check_val_every_n_epoch: int | None = 1
+     """Perform a validation loop every after every `N` training epochs. If ``None``,
+     validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+     to be an integer value.
+     Default: ``1``.
+     """
+
+     num_sanity_val_steps: int | None = None
+     """Sanity check runs n validation batches before starting the training routine.
+     Set it to `-1` to run all batches in all validation dataloaders.
+     Default: ``2``.
+     """
+
+     log_every_n_steps: int | None = None
+     """How often to log within steps.
+     Default: ``50``.
+     """
+
+     inference_mode: bool = True
+     """Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during evaluation (``validate``/``test``/``predict``).
+     Default: ``True``.
+     """
+
+     use_distributed_sampler: bool | None = None
+     """Whether to wrap the DataLoader's sampler with
+     :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
+     strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
+     ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
+     ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
+     sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
+     we don't do this automatically.
+     Default: ``True``.
+     """
+
+     accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
+     """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
+     as well as custom accelerator instances.
+     Default: ``"auto"``.
+     """
+
+     strategy: StrategyConfigProtocol | StrategyLiteral | None = None
+     """Supports different training strategies with aliases as well custom strategies.
+     Default: ``"auto"``.
+     """
+
+     devices: tuple[int, ...] | Sequence[int] | Literal["auto", "all"] | None = None
+     """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
+     automatic selection based on the chosen accelerator. Default: ``"auto"``.
+     """
+
+     shared_parameters: SharedParametersConfig | None = SharedParametersConfig()
+     """If enabled, the model supports scaling the gradients of shared parameters that
+     are registered in the self.shared_parameters list. This is useful for models that
+     share parameters across multiple modules (e.g., in a GPT model) and want to
+     downscale the gradients of these parameters to avoid overfitting.
+     """
+
+     auto_set_default_root_dir: bool = True
+     """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
+     save_checkpoint_metadata: bool = True
+     """If enabled, will save additional metadata whenever a checkpoint is saved."""
+     auto_set_debug_flag: DebugFlagCallbackConfig | None = DebugFlagCallbackConfig()
+     """If enabled, will automatically set the debug flag to True if:
+     - The trainer is running in fast_dev_run mode.
+     - The trainer is running a sanity check (which happens before starting the training routine).
+     """
+
+     lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
+     """
+     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
+
+     Please refer to the Lightning documentation for a list of valid keyword arguments.
+     """
+
+     additional_lightning_kwargs: dict[str, Any] = {}
+     """
+     Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
+
+     This is essentially a non-type-checked version of `lightning_kwargs`.
+     """
+
+     set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
+     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
+
+     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+         yield self.early_stopping
+         yield self.checkpoint_saving
+         yield self.logging
+         yield self.optimizer
+         yield self.hf_hub
+         yield self.shared_parameters
+         yield self.reduce_lr_on_plateau_sanity_checking
+         yield self.auto_set_debug_flag
+         yield from self.callbacks
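To show how the new TrainerConfig composes the option groups defined in this file, here is a minimal construction sketch. It assumes nshtrainer 0.32.0 is installed and that these config classes accept keyword initialization in the usual nshconfig (pydantic-style) way; the values are illustrative assumptions, not defaults taken from the package.

# Construction sketch for TrainerConfig (field names from the diff above; values are assumptions).
from nshtrainer.trainer._config import (
    GradientClippingConfig,
    OptimizationConfig,
    TrainerConfig,
)

config = TrainerConfig(
    precision="bf16-mixed",   # one of the Literal values declared on TrainerConfig.precision
    devices="all",            # use every available device
    max_epochs=10,
    optimizer=OptimizationConfig(
        log_grad_norm=True,   # registers the norm-logging callback
        gradient_clipping=GradientClippingConfig(value=1.0, algorithm="norm"),
    ),
    lightning_kwargs={"log_every_n_steps": 25},  # forwarded to the Lightning Trainer
)

# Unset groups fall back to the defaults declared above.
assert config.checkpoint_saving.enabled
assert config.max_steps == -1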