nshtrainer 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +8 -0
  4. nshtrainer/callbacks/directory_setup.py +85 -0
  5. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  6. nshtrainer/callbacks/shared_parameters.py +87 -0
  7. nshtrainer/config.py +67 -0
  8. nshtrainer/ll/__init__.py +5 -4
  9. nshtrainer/ll/model.py +7 -0
  10. nshtrainer/loggers/wandb.py +1 -1
  11. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  12. nshtrainer/model/__init__.py +0 -21
  13. nshtrainer/model/base.py +139 -44
  14. nshtrainer/model/config.py +7 -1025
  15. nshtrainer/model/{modules → mixins}/callback.py +2 -2
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +778 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/debug.py +0 -42
  28. nshtrainer/model/modules/distributed.py +0 -70
  29. nshtrainer/model/modules/profiler.py +0 -24
  30. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  31. nshtrainer/model/modules/shared_parameters.py +0 -72
  32. /nshtrainer/{config → util/config}/duration.py +0 -0
  33. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
nshtrainer/trainer/_config.py
@@ -0,0 +1,778 @@
+import logging
+from collections.abc import Iterable, Sequence
+from datetime import timedelta
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    Literal,
+    Protocol,
+    TypeAlias,
+    runtime_checkable,
+)
+
+import nshconfig as C
+from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
+from lightning.pytorch.accelerators import Accelerator
+from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.plugins import _PLUGIN_INPUT
+from lightning.pytorch.plugins.layer_sync import LayerSync
+from lightning.pytorch.plugins.precision.precision import Precision
+from lightning.pytorch.profilers import Profiler
+from lightning.pytorch.strategies.strategy import Strategy
+from typing_extensions import TypedDict, TypeVar, override
+
+from .._checkpoint.loader import CheckpointLoadingConfig
+from .._hf_hub import HuggingFaceHubConfig
+from ..callbacks import (
+    BestCheckpointCallbackConfig,
+    CallbackConfig,
+    EarlyStoppingConfig,
+    LastCheckpointCallbackConfig,
+    OnExceptionCheckpointCallbackConfig,
+)
+from ..callbacks.base import CallbackConfigBase
+from ..callbacks.rlp_sanity_checks import RLPSanityChecksConfig
+from ..callbacks.shared_parameters import SharedParametersConfig
+from ..loggers import (
+    CSVLoggerConfig,
+    LoggerConfig,
+    TensorboardLoggerConfig,
+    WandbLoggerConfig,
+)
+from ..profiler import ProfilerConfig
+
+if TYPE_CHECKING:
+    from ..model.config import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class LoggingConfig(CallbackConfigBase):
+    enabled: bool = True
+    """Enable experiment tracking."""
+
+    loggers: Sequence[LoggerConfig] = [
+        WandbLoggerConfig(),
+        CSVLoggerConfig(),
+        TensorboardLoggerConfig(),
+    ]
+    """Loggers to use for experiment tracking."""
+
+    log_lr: bool | Literal["step", "epoch"] = True
+    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
+    log_epoch: bool = True
+    """If enabled, will log the fractional epoch number to the logger."""
+
+    actsave_logged_metrics: bool = False
+    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
+
+    @property
+    def wandb(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, WandbLoggerConfig)
+            ),
+            None,
+        )
+
+    @property
+    def csv(self):
+        return next(
+            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
+            None,
+        )
+
+    @property
+    def tensorboard(self):
+        return next(
+            (
+                logger
+                for logger in self.loggers
+                if isinstance(logger, TensorboardLoggerConfig)
+            ),
+            None,
+        )
+
+    def create_loggers(self, root_config: "BaseConfig"):
+        """
+        Constructs and yields the loggers based on the provided root configuration.
+
+        Args:
+            root_config (BaseConfig): The root configuration object.
+
+        Yields:
+            Logger: The constructed loggers.
+        """
+        if not self.enabled:
+            return
+
+        for logger_config in sorted(
+            self.loggers,
+            key=lambda x: x.priority,
+            reverse=True,
+        ):
+            if not logger_config.enabled:
+                continue
+            if (logger := logger_config.create_logger(root_config)) is None:
+                continue
+            yield logger
+
+    @override
+    def create_callbacks(self, root_config):
+        if self.log_lr:
+            from lightning.pytorch.callbacks import LearningRateMonitor
+
+            logging_interval: str | None = None
+            if isinstance(self.log_lr, str):
+                logging_interval = self.log_lr
+
+            yield LearningRateMonitor(logging_interval=logging_interval)
+
+        if self.log_epoch:
+            from ..callbacks.log_epoch import LogEpochCallback
+
+            yield LogEpochCallback()
+
+        for logger in self.loggers:
+            if not logger or not isinstance(logger, CallbackConfigBase):
+                continue
+
+            yield from logger.create_callbacks(root_config)
+
+
+class GradientClippingConfig(C.Config):
+    enabled: bool = True
+    """Enable gradient clipping."""
+    value: int | float
+    """Value to use for gradient clipping."""
+    algorithm: Literal["value", "norm"] = "norm"
+    """Norm type to use for gradient clipping."""
+
+
+class OptimizationConfig(CallbackConfigBase):
+    log_grad_norm: bool | str | float = False
+    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+    log_grad_norm_per_param: bool | str | float = False
+    """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+    log_param_norm: bool | str | float = False
+    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+    log_param_norm_per_param: bool | str | float = False
+    """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+    gradient_clipping: GradientClippingConfig | None = None
+    """Gradient clipping configuration, or None to disable gradient clipping."""
+
+    @override
+    def create_callbacks(self, root_config):
+        from ..callbacks.norm_logging import NormLoggingConfig
+
+        yield from NormLoggingConfig(
+            log_grad_norm=self.log_grad_norm,
+            log_grad_norm_per_param=self.log_grad_norm_per_param,
+            log_param_norm=self.log_param_norm,
+            log_param_norm_per_param=self.log_param_norm_per_param,
+        ).create_callbacks(root_config)
+
+
+TPlugin = TypeVar(
+    "TPlugin",
+    Precision,
+    ClusterEnvironment,
+    CheckpointIO,
+    LayerSync,
+    infer_variance=True,
+)
+
+
+@runtime_checkable
+class PluginConfigProtocol(Protocol[TPlugin]):
+    def create_plugin(self) -> TPlugin: ...
+
+
+@runtime_checkable
+class AcceleratorConfigProtocol(Protocol):
+    def create_accelerator(self) -> Accelerator: ...
+
+
+@runtime_checkable
+class StrategyConfigProtocol(Protocol):
+    def create_strategy(self) -> Strategy: ...
+
+
+AcceleratorLiteral: TypeAlias = Literal[
+    "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
+]
+
+StrategyLiteral: TypeAlias = Literal[
+    "auto",
+    "ddp",
+    "ddp_find_unused_parameters_false",
+    "ddp_find_unused_parameters_true",
+    "ddp_spawn",
+    "ddp_spawn_find_unused_parameters_false",
+    "ddp_spawn_find_unused_parameters_true",
+    "ddp_fork",
+    "ddp_fork_find_unused_parameters_false",
+    "ddp_fork_find_unused_parameters_true",
+    "ddp_notebook",
+    "dp",
+    "deepspeed",
+    "deepspeed_stage_1",
+    "deepspeed_stage_1_offload",
+    "deepspeed_stage_2",
+    "deepspeed_stage_2_offload",
+    "deepspeed_stage_3",
+    "deepspeed_stage_3_offload",
+    "deepspeed_stage_3_offload_nvme",
+    "fsdp",
+    "fsdp_cpu_offload",
+    "single_xla",
+    "xla_fsdp",
+    "xla",
+    "single_tpu",
+]
+
+
+class ReproducibilityConfig(C.Config):
+    deterministic: bool | Literal["warn"] | None = None
+    """
+    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
+
+
+CheckpointCallbackConfig: TypeAlias = Annotated[
+    BestCheckpointCallbackConfig
+    | LastCheckpointCallbackConfig
+    | OnExceptionCheckpointCallbackConfig,
+    C.Field(discriminator="name"),
+]
+
+
+class CheckpointSavingConfig(CallbackConfigBase):
+    enabled: bool = True
+    """Enable checkpoint saving."""
+
+    checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
+        BestCheckpointCallbackConfig(),
+        LastCheckpointCallbackConfig(),
+        OnExceptionCheckpointCallbackConfig(),
+    ]
+    """Checkpoint callback configurations."""
+
+    def disable_(self):
+        self.enabled = False
+        return self
+
+    def should_save_checkpoints(self, root_config: "BaseConfig"):
+        if not self.enabled:
+            return False
+
+        if root_config.trainer.fast_dev_run:
+            return False
+
+        return True
+
+    @override
+    def create_callbacks(self, root_config: "BaseConfig"):
+        if not self.should_save_checkpoints(root_config):
+            return
+
+        for callback_config in self.checkpoint_callbacks:
+            yield from callback_config.create_callbacks(root_config)
+
+
+class LightningTrainerKwargs(TypedDict, total=False):
+    accelerator: str | Accelerator
+    """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
+    as well as custom accelerator instances."""
+
+    strategy: str | Strategy
+    """Supports different training strategies with aliases, as well as custom strategies.
+    Default: ``"auto"``.
+    """
+
+    devices: list[int] | str | int
+    """The devices to use. Can be set to a positive number (int or str), a sequence of device indices
+    (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
+    automatic selection based on the chosen accelerator. Default: ``"auto"``.
+    """
+
+    num_nodes: int
+    """Number of GPU nodes for distributed training.
+    Default: ``1``.
+    """
+
+    precision: _PRECISION_INPUT | None
+    """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
+    16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
+    Can be used on CPU, GPU, TPUs, HPUs or IPUs.
+    Default: ``'32-true'``.
+    """
+
+    logger: Logger | Iterable[Logger] | bool | None
+    """Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
+    the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
+    ``False`` will disable logging. If multiple loggers are provided, local files
+    (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
+    Default: ``True``.
+    """
+
+    callbacks: list[Callback] | Callback | None
+    """Add a callback or list of callbacks.
+    Default: ``None``.
+    """
+
+    fast_dev_run: int | bool
+    """Runs ``n`` batch(es) of train, val and test if set to ``n`` (int), or 1 batch if set to ``True``,
+    to find any bugs (i.e., a sort of unit test).
+    Default: ``False``.
+    """
+
+    max_epochs: int | None
+    """Stop training once this number of epochs is reached. Disabled by default (None).
+    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+    To enable infinite training, set ``max_epochs = -1``.
+    """
+
+    min_epochs: int | None
+    """Force training for at least this many epochs. Disabled by default (None).
+    """
+
+    max_steps: int
+    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+    and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+    ``max_epochs`` to ``-1``.
+    """
+
+    min_steps: int | None
+    """Force training for at least this number of steps. Disabled by default (``None``).
+    """
+
+    max_time: str | timedelta | dict[str, int] | None
+    """Stop training after this amount of time has passed. Disabled by default (``None``).
+    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
+    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
+    :class:`datetime.timedelta`.
+    """
+
+    limit_train_batches: int | float | None
+    """How much of training dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    limit_val_batches: int | float | None
+    """How much of validation dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    limit_test_batches: int | float | None
+    """How much of test dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    limit_predict_batches: int | float | None
+    """How much of prediction dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    overfit_batches: int | float
+    """Overfit a fraction of training/validation data (float) or a set number of batches (int).
+    Default: ``0.0``.
+    """
+
+    val_check_interval: int | float | None
+    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
+    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
+    batches. An ``int`` value can only be higher than the number of training batches when
+    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
+    across epochs or during iteration-based training.
+    Default: ``1.0``.
+    """
+
+    check_val_every_n_epoch: int | None
+    """Perform a validation loop after every ``N`` training epochs. If ``None``,
+    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+    to be an integer value.
+    Default: ``1``.
+    """
+
+    num_sanity_val_steps: int | None
+    """Sanity check runs n validation batches before starting the training routine.
+    Set it to `-1` to run all batches in all validation dataloaders.
+    Default: ``2``.
+    """
+
+    log_every_n_steps: int | None
+    """How often to log within steps.
+    Default: ``50``.
+    """
+
+    enable_checkpointing: bool | None
+    """If ``True``, enable checkpointing.
+    It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
+    Default: ``True``.
+    """
+
+    enable_progress_bar: bool | None
+    """Whether to enable the progress bar by default.
+    Default: ``True``.
+    """
+
+    enable_model_summary: bool | None
+    """Whether to enable model summarization by default.
+    Default: ``True``.
+    """
+
+    accumulate_grad_batches: int
+    """Accumulates gradients over k batches before stepping the optimizer.
+    Default: 1.
+    """
+
+    gradient_clip_val: int | float | None
+    """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
+    gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
+    Default: ``None``.
+    """
+
+    gradient_clip_algorithm: str | None
+    """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
+    to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
+    be set to ``"norm"``.
+    """
+
+    deterministic: bool | Literal["warn"] | None
+    """If ``True``, sets whether PyTorch operations must use deterministic algorithms.
+    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
+    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
+    """
+
+    benchmark: bool | None
+    """The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
+    The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
+    (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
+    is set to ``True``, this will default to ``False``. Override to manually set a different value.
+    Default: ``None``.
+    """
+
+    inference_mode: bool
+    """Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
+    evaluation (``validate``/``test``/``predict``).
+    """
+
+    use_distributed_sampler: bool
+    """Whether to wrap the DataLoader's sampler with
+    :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
+    strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
+    ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
+    ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
+    sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
+    we don't do this automatically.
+    """
+
+    profiler: Profiler | str | None
+    """To profile individual steps during training and assist in identifying bottlenecks.
+    Default: ``None``.
+    """
+
+    detect_anomaly: bool
+    """Enable anomaly detection for the autograd engine.
+    Default: ``False``.
+    """
+
+    barebones: bool
+    """Whether to run in "barebones mode", where all features that may impact raw speed are
+    disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
+    runs. The following features are deactivated:
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
+    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
+    :meth:`~lightning.pytorch.core.LightningModule.log`,
+    :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
+    """
+
+    plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None
+    """Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
+    Default: ``None``.
+    """
+
+    sync_batchnorm: bool
+    """Synchronize batch norm layers between process groups/whole world.
+    Default: ``False``.
+    """
+
+    reload_dataloaders_every_n_epochs: int
+    """Set to a positive integer to reload dataloaders every n epochs.
+    Default: ``0``.
+    """
+
+    default_root_dir: Path | None
+    """Default path for logs and weights when no logger/ckpt_callback passed.
+    Default: ``os.getcwd()``.
+    Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
+    """
+
+
+class SanityCheckingConfig(C.Config):
+    reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
+    """
+    If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+    - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
+    - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
+    Valid values are: "disable", "warn", "error".
+    """
+
+
+class TrainerConfig(C.Config):
+    ckpt_path: Literal["none"] | str | Path | None = None
+    """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
+
+    checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
+    """Checkpoint loading configuration options.
+    `"auto"` will automatically determine the best checkpoint loading strategy based on the provided `ckpt_path`.
+    `"none"` will disable checkpoint loading.
+    """
+
+    checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
+    """Checkpoint saving configuration options."""
+
+    hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
+    """Hugging Face Hub configuration options."""
+
+    logging: LoggingConfig = LoggingConfig()
+    """Logging/experiment tracking (e.g., WandB) configuration options."""
+
+    optimizer: OptimizationConfig = OptimizationConfig()
+    """Optimization configuration options."""
+
+    reproducibility: ReproducibilityConfig = ReproducibilityConfig()
+    """Reproducibility configuration options."""
+
+    reduce_lr_on_plateau_sanity_checking: RLPSanityChecksConfig | None = (
+        RLPSanityChecksConfig()
+    )
+    """
+    If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+    - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
+    - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
+    """
+
+    early_stopping: EarlyStoppingConfig | None = None
+    """Early stopping configuration options."""
+
+    profiler: ProfilerConfig | None = None
+    """
+    To profile individual steps during training and assist in identifying bottlenecks.
+    Default: ``None``.
+    """
+
+    callbacks: list[CallbackConfig] = []
+    """Callbacks to use during training."""
+
+    detect_anomaly: bool | None = None
+    """Enable anomaly detection for the autograd engine.
+    Default: ``False``.
+    """
+
+    plugins: list[PluginConfigProtocol] | None = None
+    """
+    Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
+    Default: ``None``.
+    """
+
+    auto_determine_num_nodes: bool = True
+    """
+    If enabled, will automatically determine the number of nodes for distributed training.
+
+    This will only work on:
+    - SLURM clusters
+    - LSF clusters
+    """
+
+    fast_dev_run: int | bool = False
+    """Runs ``n`` batch(es) of train, val and test if set to ``n`` (int), or 1 batch if set to ``True``,
+    to find any bugs (i.e., a sort of unit test).
+    Default: ``False``.
+    """
+
+    precision: (
+        Literal[
+            "64-true",
+            "32-true",
+            "fp16-mixed",
+            "bf16-mixed",
+            "16-mixed-auto",
+        ]
+        | None
+    ) = None
+    """
+    Training precision. Can be one of:
+    - "64-true": Double precision (64-bit).
+    - "32-true": Full precision (32-bit).
+    - "fp16-mixed": Float16 mixed precision.
+    - "bf16-mixed": BFloat16 mixed precision.
+    - "16-mixed-auto": Automatic 16-bit: Uses bfloat16 if available, otherwise float16.
+    """
+
+    max_epochs: int | None = None
+    """Stop training once this number of epochs is reached. Disabled by default (None).
+    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+    To enable infinite training, set ``max_epochs = -1``.
+    """
+
+    min_epochs: int | None = None
+    """Force training for at least this many epochs. Disabled by default (None).
+    """
+
+    max_steps: int = -1
+    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+    and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+    ``max_epochs`` to ``-1``.
+    """
+
+    min_steps: int | None = None
+    """Force training for at least this number of steps. Disabled by default (``None``).
+    """
+
+    max_time: str | timedelta | dict[str, int] | None = None
+    """Stop training after this amount of time has passed. Disabled by default (``None``).
+    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
+    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
+    :class:`datetime.timedelta`.
+    """
+
+    limit_train_batches: int | float | None = None
+    """How much of training dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    limit_val_batches: int | float | None = None
+    """How much of validation dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    limit_test_batches: int | float | None = None
+    """How much of test dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    limit_predict_batches: int | float | None = None
+    """How much of prediction dataset to check (float = fraction, int = num_batches).
+    Default: ``1.0``.
+    """
+
+    overfit_batches: int | float = 0.0
+    """Overfit a fraction of training/validation data (float) or a set number of batches (int).
+    Default: ``0.0``.
+    """
+
+    val_check_interval: int | float | None = None
+    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
+    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
+    batches. An ``int`` value can only be higher than the number of training batches when
+    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
+    across epochs or during iteration-based training.
+    Default: ``1.0``.
+    """
+
+    check_val_every_n_epoch: int | None = 1
+    """Perform a validation loop after every ``N`` training epochs. If ``None``,
+    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+    to be an integer value.
+    Default: ``1``.
+    """
+
+    num_sanity_val_steps: int | None = None
+    """Sanity check runs n validation batches before starting the training routine.
+    Set it to `-1` to run all batches in all validation dataloaders.
+    Default: ``2``.
+    """
+
+    log_every_n_steps: int | None = None
+    """How often to log within steps.
+    Default: ``50``.
+    """
+
+    inference_mode: bool = True
+    """Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during evaluation (``validate``/``test``/``predict``).
+    Default: ``True``.
+    """
+
+    use_distributed_sampler: bool | None = None
+    """Whether to wrap the DataLoader's sampler with
+    :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
+    strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
+    ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
+    ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
+    sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
+    we don't do this automatically.
+    Default: ``True``.
+    """
+
+    accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
+    """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
+    as well as custom accelerator instances.
+    Default: ``"auto"``.
+    """
+
+    strategy: StrategyConfigProtocol | StrategyLiteral | None = None
+    """Supports different training strategies with aliases, as well as custom strategies.
+    Default: ``"auto"``.
+    """
+
+    devices: tuple[int, ...] | Sequence[int] | Literal["auto", "all"] | None = None
+    """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
+    automatic selection based on the chosen accelerator. Default: ``"auto"``.
+    """
+
+    shared_parameters: SharedParametersConfig | None = SharedParametersConfig()
+    """If enabled, the model supports scaling the gradients of shared parameters that
+    are registered in the self.shared_parameters list. This is useful for models that
+    share parameters across multiple modules (e.g., in a GPT model) and want to
+    downscale the gradients of these parameters to avoid overfitting.
+    """
+
+    auto_set_default_root_dir: bool = True
+    """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
+    save_checkpoint_metadata: bool = True
+    """If enabled, will save additional metadata whenever a checkpoint is saved."""
+
+    lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
+    """
+    Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
+
+    Please refer to the Lightning documentation for a list of valid keyword arguments.
+    """
+
+    additional_lightning_kwargs: dict[str, Any] = {}
+    """
+    Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
+
+    This is essentially a non-type-checked version of `lightning_kwargs`.
+    """
+
+    set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
+    """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
+
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+        yield self.early_stopping
+        yield self.checkpoint_saving
+        yield self.logging
+        yield self.optimizer
+        yield self.hf_hub
+        yield from self.callbacks
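
The new `nshtrainer/trainer/_config.py` gathers the trainer options that previously lived in `nshtrainer/model/config.py` (note the -1025 lines there). Below is a minimal sketch of how the config classes in this file compose. The field names come from the diff above, but the keyword-argument construction (nshconfig/pydantic-style) and the import locations used here are assumptions; the package's actual public entry points and `Trainer` wiring may differ.

# Hedged usage sketch; field names are from the diff, import paths and
# keyword-style construction are assumed, not taken from nshtrainer docs.
from nshtrainer.loggers import CSVLoggerConfig, WandbLoggerConfig
from nshtrainer.trainer._config import (
    CheckpointSavingConfig,
    GradientClippingConfig,
    LoggingConfig,
    OptimizationConfig,
    TrainerConfig,
)

trainer_config = TrainerConfig(
    precision="bf16-mixed",  # one of the new precision literals
    max_epochs=100,
    logging=LoggingConfig(
        loggers=[WandbLoggerConfig(), CSVLoggerConfig()],
        log_lr="step",  # registers a LearningRateMonitor with per-step logging
    ),
    optimizer=OptimizationConfig(
        log_grad_norm=True,
        gradient_clipping=GradientClippingConfig(value=1.0, algorithm="norm"),
    ),
    checkpoint_saving=CheckpointSavingConfig(),  # Best/Last/OnException callbacks by default
)

The `_nshtrainer_all_callback_configs` hook at the bottom of the file is what lets these sub-configs (early stopping, checkpoint saving, logging, optimization, HF Hub, plus any user-supplied callback configs) contribute Lightning callbacks, presumably when the trainer is constructed.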