nshtrainer 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +8 -0
  4. nshtrainer/callbacks/directory_setup.py +85 -0
  5. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  6. nshtrainer/callbacks/shared_parameters.py +87 -0
  7. nshtrainer/config.py +67 -0
  8. nshtrainer/ll/__init__.py +5 -4
  9. nshtrainer/ll/model.py +7 -0
  10. nshtrainer/loggers/wandb.py +1 -1
  11. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  12. nshtrainer/model/__init__.py +0 -21
  13. nshtrainer/model/base.py +139 -44
  14. nshtrainer/model/config.py +7 -1025
  15. nshtrainer/model/{modules → mixins}/callback.py +2 -2
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +778 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/debug.py +0 -42
  28. nshtrainer/model/modules/distributed.py +0 -70
  29. nshtrainer/model/modules/profiler.py +0 -24
  30. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  31. nshtrainer/model/modules/shared_parameters.py +0 -72
  32. nshtrainer/{config → util/config}/duration.py +0 -0
  33. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
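
The moves listed above reshape the package layout: the monolithic nshtrainer/model/config.py (the file whose hunks appear to be shown below) shrinks from roughly a thousand lines to a couple dozen, with the trainer, profiler, and directory configuration split out into nshtrainer/trainer/_config.py, the new nshtrainer/profiler package, and nshtrainer/_directory.py. The sketch below illustrates how downstream imports would likely change as a result; the new paths are inferred from the renames and from the `from ..trainer._config import TrainerConfig` import added in the diff, so treat anything not shown in the diff itself as an assumption.

    # Hypothetical import migration for downstream code (0.30.1 -> 0.31.0).
    # Old location (0.30.1): these classes were defined in nshtrainer.model.config.
    # from nshtrainer.model.config import TrainerConfig, SimpleProfilerConfig

    # New locations (0.31.0), inferred from the file moves in this diff:
    from nshtrainer.trainer._config import TrainerConfig   # confirmed by the import added in the diff below
    from nshtrainer.profiler import SimpleProfilerConfig   # assumes profiler/__init__.py re-exports it

    # Assumes the relocated TrainerConfig keeps the `profiler` field of the old one.
    config = TrainerConfig(profiler=SimpleProfilerConfig())
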
@@ -3,1037 +3,24 @@ import logging
3
3
  import os
4
4
  import string
5
5
  import time
6
- from abc import ABC, abstractmethod
7
- from collections.abc import Iterable, Sequence
8
- from datetime import timedelta
6
+ from collections.abc import Iterable
9
7
  from pathlib import Path
10
- from typing import (
11
- Annotated,
12
- Any,
13
- ClassVar,
14
- Literal,
15
- Protocol,
16
- TypeAlias,
17
- runtime_checkable,
18
- )
8
+ from typing import Annotated, Any, ClassVar
19
9
 
20
10
  import nshconfig as C
21
11
  import numpy as np
22
12
  import torch
23
- from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
24
- from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
25
- from lightning.pytorch.accelerators import Accelerator
26
- from lightning.pytorch.callbacks.callback import Callback
27
- from lightning.pytorch.loggers import Logger
28
- from lightning.pytorch.plugins import _PLUGIN_INPUT
29
- from lightning.pytorch.plugins.layer_sync import LayerSync
30
- from lightning.pytorch.plugins.precision.precision import Precision
31
- from lightning.pytorch.profilers import Profiler
32
- from lightning.pytorch.strategies.strategy import Strategy
33
- from typing_extensions import Self, TypedDict, TypeVar, override
13
+ from typing_extensions import Self
34
14
 
35
- from .._checkpoint.loader import CheckpointLoadingConfig
36
- from .._hf_hub import HuggingFaceHubConfig
37
- from ..callbacks import (
38
- BestCheckpointCallbackConfig,
39
- CallbackConfig,
40
- EarlyStoppingConfig,
41
- LastCheckpointCallbackConfig,
42
- OnExceptionCheckpointCallbackConfig,
43
- )
15
+ from .._directory import DirectoryConfig
44
16
  from ..callbacks.base import CallbackConfigBase
45
- from ..loggers import (
46
- CSVLoggerConfig,
47
- LoggerConfig,
48
- TensorboardLoggerConfig,
49
- WandbLoggerConfig,
50
- )
51
17
  from ..metrics import MetricConfig
18
+ from ..trainer._config import TrainerConfig
52
19
  from ..util._environment_info import EnvironmentConfig
53
20
 
54
21
  log = logging.getLogger(__name__)
55
22
 
56
23
 
57
- class BaseProfilerConfig(C.Config, ABC):
58
- dirpath: str | Path | None = None
59
- """
60
- Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
61
- ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
62
- will be used.
63
- """
64
- filename: str | None = None
65
- """
66
- If present, filename where the profiler results will be saved instead of printing to stdout.
67
- The ``.txt`` extension will be used automatically.
68
- """
69
-
70
- @abstractmethod
71
- def create_profiler(self, root_config: "BaseConfig") -> Profiler: ...
72
-
73
-
74
- class SimpleProfilerConfig(BaseProfilerConfig):
75
- name: Literal["simple"] = "simple"
76
-
77
- extended: bool = True
78
- """
79
- If ``True``, adds extra columns representing number of calls and percentage of
80
- total time spent onrespective action.
81
- """
82
-
83
- @override
84
- def create_profiler(self, root_config):
85
- from lightning.pytorch.profilers.simple import SimpleProfiler
86
-
87
- if (dirpath := self.dirpath) is None:
88
- dirpath = root_config.directory.resolve_subdirectory(
89
- root_config.id, "profile"
90
- )
91
-
92
- if (filename := self.filename) is None:
93
- filename = f"{root_config.id}_profile.txt"
94
-
95
- return SimpleProfiler(
96
- extended=self.extended,
97
- dirpath=dirpath,
98
- filename=filename,
99
- )
100
-
101
-
102
- class AdvancedProfilerConfig(BaseProfilerConfig):
103
- name: Literal["advanced"] = "advanced"
104
-
105
- line_count_restriction: float = 1.0
106
- """
107
- This can be used to limit the number of functions
108
- reported for each action. either an integer (to select a count of lines),
109
- or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
110
- """
111
-
112
- @override
113
- def create_profiler(self, root_config):
114
- from lightning.pytorch.profilers.advanced import AdvancedProfiler
115
-
116
- if (dirpath := self.dirpath) is None:
117
- dirpath = root_config.directory.resolve_subdirectory(
118
- root_config.id, "profile"
119
- )
120
-
121
- if (filename := self.filename) is None:
122
- filename = f"{root_config.id}_profile.txt"
123
-
124
- return AdvancedProfiler(
125
- line_count_restriction=self.line_count_restriction,
126
- dirpath=dirpath,
127
- filename=filename,
128
- )
129
-
130
-
131
- class PyTorchProfilerConfig(BaseProfilerConfig):
132
- name: Literal["pytorch"] = "pytorch"
133
-
134
- group_by_input_shapes: bool = False
135
- """Include operator input shapes and group calls by shape."""
136
-
137
- emit_nvtx: bool = False
138
- """
139
- Context manager that makes every autograd operation emit an NVTX range
140
- Run::
141
-
142
- nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
143
-
144
- To visualize, you can either use::
145
-
146
- nvvp trace_name.prof
147
- torch.autograd.profiler.load_nvprof(path)
148
- """
149
-
150
- export_to_chrome: bool = True
151
- """
152
- Whether to export the sequence of profiled operators for Chrome.
153
- It will generate a ``.json`` file which can be read by Chrome.
154
- """
155
-
156
- row_limit: int = 20
157
- """
158
- Limit the number of rows in a table, ``-1`` is a special value that
159
- removes the limit completely.
160
- """
161
-
162
- sort_by_key: str | None = None
163
- """
164
- Attribute used to sort entries. By default
165
- they are printed in the same order as they were registered.
166
- Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
167
- ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
168
- ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
169
- """
170
-
171
- record_module_names: bool = True
172
- """Whether to add module names while recording autograd operation."""
173
-
174
- table_kwargs: dict[str, Any] | None = None
175
- """Dictionary with keyword arguments for the summary table."""
176
-
177
- additional_profiler_kwargs: dict[str, Any] = {}
178
- """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
179
-
180
- @override
181
- def create_profiler(self, root_config):
182
- from lightning.pytorch.profilers.pytorch import PyTorchProfiler
183
-
184
- if (dirpath := self.dirpath) is None:
185
- dirpath = root_config.directory.resolve_subdirectory(
186
- root_config.id, "profile"
187
- )
188
-
189
- if (filename := self.filename) is None:
190
- filename = f"{root_config.id}_profile.txt"
191
-
192
- return PyTorchProfiler(
193
- group_by_input_shapes=self.group_by_input_shapes,
194
- emit_nvtx=self.emit_nvtx,
195
- export_to_chrome=self.export_to_chrome,
196
- row_limit=self.row_limit,
197
- sort_by_key=self.sort_by_key,
198
- record_module_names=self.record_module_names,
199
- table_kwargs=self.table_kwargs,
200
- dirpath=dirpath,
201
- filename=filename,
202
- **self.additional_profiler_kwargs,
203
- )
204
-
205
-
206
- ProfilerConfig: TypeAlias = Annotated[
207
- SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
208
- C.Field(discriminator="name"),
209
- ]
210
-
211
-
212
- class LoggingConfig(CallbackConfigBase):
213
- enabled: bool = True
214
- """Enable experiment tracking."""
215
-
216
- loggers: Sequence[LoggerConfig] = [
217
- WandbLoggerConfig(),
218
- CSVLoggerConfig(),
219
- TensorboardLoggerConfig(),
220
- ]
221
- """Loggers to use for experiment tracking."""
222
-
223
- log_lr: bool | Literal["step", "epoch"] = True
224
- """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
225
- log_epoch: bool = True
226
- """If enabled, will log the fractional epoch number to the logger."""
227
-
228
- actsave_logged_metrics: bool = False
229
- """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
230
-
231
- @property
232
- def wandb(self):
233
- return next(
234
- (
235
- logger
236
- for logger in self.loggers
237
- if isinstance(logger, WandbLoggerConfig)
238
- ),
239
- None,
240
- )
241
-
242
- @property
243
- def csv(self):
244
- return next(
245
- (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
246
- None,
247
- )
248
-
249
- @property
250
- def tensorboard(self):
251
- return next(
252
- (
253
- logger
254
- for logger in self.loggers
255
- if isinstance(logger, TensorboardLoggerConfig)
256
- ),
257
- None,
258
- )
259
-
260
- def create_loggers(self, root_config: "BaseConfig"):
261
- """
262
- Constructs and returns a list of loggers based on the provided root configuration.
263
-
264
- Args:
265
- root_config (BaseConfig): The root configuration object.
266
-
267
- Returns:
268
- list[Logger]: A list of constructed loggers.
269
- """
270
- if not self.enabled:
271
- return
272
-
273
- for logger_config in sorted(
274
- self.loggers,
275
- key=lambda x: x.priority,
276
- reverse=True,
277
- ):
278
- if not logger_config.enabled:
279
- continue
280
- if (logger := logger_config.create_logger(root_config)) is None:
281
- continue
282
- yield logger
283
-
284
- @override
285
- def create_callbacks(self, root_config):
286
- if self.log_lr:
287
- from lightning.pytorch.callbacks import LearningRateMonitor
288
-
289
- logging_interval: str | None = None
290
- if isinstance(self.log_lr, str):
291
- logging_interval = self.log_lr
292
-
293
- yield LearningRateMonitor(logging_interval=logging_interval)
294
-
295
- if self.log_epoch:
296
- from ..callbacks.log_epoch import LogEpochCallback
297
-
298
- yield LogEpochCallback()
299
-
300
- for logger in self.loggers:
301
- if not logger or not isinstance(logger, CallbackConfigBase):
302
- continue
303
-
304
- yield from logger.create_callbacks(root_config)
305
-
306
-
307
- class GradientClippingConfig(C.Config):
308
- enabled: bool = True
309
- """Enable gradient clipping."""
310
- value: int | float
311
- """Value to use for gradient clipping."""
312
- algorithm: Literal["value", "norm"] = "norm"
313
- """Norm type to use for gradient clipping."""
314
-
315
-
316
- class OptimizationConfig(CallbackConfigBase):
317
- log_grad_norm: bool | str | float = False
318
- """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
319
- log_grad_norm_per_param: bool | str | float = False
320
- """If enabled, will log the gradient norm for each model parameter to the logger."""
321
-
322
- log_param_norm: bool | str | float = False
323
- """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
324
- log_param_norm_per_param: bool | str | float = False
325
- """If enabled, will log the parameter norm for each model parameter to the logger."""
326
-
327
- gradient_clipping: GradientClippingConfig | None = None
328
- """Gradient clipping configuration, or None to disable gradient clipping."""
329
-
330
- @override
331
- def create_callbacks(self, root_config):
332
- from ..callbacks.norm_logging import NormLoggingConfig
333
-
334
- yield from NormLoggingConfig(
335
- log_grad_norm=self.log_grad_norm,
336
- log_grad_norm_per_param=self.log_grad_norm_per_param,
337
- log_param_norm=self.log_param_norm,
338
- log_param_norm_per_param=self.log_param_norm_per_param,
339
- ).create_callbacks(root_config)
340
-
341
-
342
- TPlugin = TypeVar(
343
- "TPlugin",
344
- Precision,
345
- ClusterEnvironment,
346
- CheckpointIO,
347
- LayerSync,
348
- infer_variance=True,
349
- )
350
-
351
-
352
- @runtime_checkable
353
- class PluginConfigProtocol(Protocol[TPlugin]):
354
- def create_plugin(self) -> TPlugin: ...
355
-
356
-
357
- @runtime_checkable
358
- class AcceleratorConfigProtocol(Protocol):
359
- def create_accelerator(self) -> Accelerator: ...
360
-
361
-
362
- @runtime_checkable
363
- class StrategyConfigProtocol(Protocol):
364
- def create_strategy(self) -> Strategy: ...
365
-
366
-
367
- AcceleratorLiteral: TypeAlias = Literal[
368
- "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
369
- ]
370
-
371
- StrategyLiteral: TypeAlias = Literal[
372
- "auto",
373
- "ddp",
374
- "ddp_find_unused_parameters_false",
375
- "ddp_find_unused_parameters_true",
376
- "ddp_spawn",
377
- "ddp_spawn_find_unused_parameters_false",
378
- "ddp_spawn_find_unused_parameters_true",
379
- "ddp_fork",
380
- "ddp_fork_find_unused_parameters_false",
381
- "ddp_fork_find_unused_parameters_true",
382
- "ddp_notebook",
383
- "dp",
384
- "deepspeed",
385
- "deepspeed_stage_1",
386
- "deepspeed_stage_1_offload",
387
- "deepspeed_stage_2",
388
- "deepspeed_stage_2_offload",
389
- "deepspeed_stage_3",
390
- "deepspeed_stage_3_offload",
391
- "deepspeed_stage_3_offload_nvme",
392
- "fsdp",
393
- "fsdp_cpu_offload",
394
- "single_xla",
395
- "xla_fsdp",
396
- "xla",
397
- "single_tpu",
398
- ]
399
-
400
-
401
- def _create_symlink_to_nshrunner(base_dir: Path):
402
- # Resolve the current nshrunner session directory
403
- if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
404
- log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
405
- return
406
- session_dir = Path(session_dir)
407
- if not session_dir.exists() or not session_dir.is_dir():
408
- log.warning(
409
- f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
410
- "Skipping symlink creation."
411
- )
412
- return
413
-
414
- # Create the symlink
415
- symlink_path = base_dir / "nshrunner"
416
- if symlink_path.exists():
417
- # If it already points to the correct directory, we're done
418
- if symlink_path.resolve() == session_dir.resolve():
419
- return
420
-
421
- # Otherwise, we should log a warning and remove the existing symlink
422
- log.warning(
423
- f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
424
- "Removing the existing symlink."
425
- )
426
- symlink_path.unlink()
427
-
428
- symlink_path.symlink_to(session_dir)
429
-
430
-
431
- class DirectoryConfig(C.Config):
432
- project_root: Path | None = None
433
- """
434
- Root directory for this project.
435
-
436
- This isn't specific to the run; it is the parent directory of all runs.
437
- """
438
-
439
- create_symlink_to_nshrunner_root: bool = True
440
- """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
441
-
442
- log: Path | None = None
443
- """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
444
-
445
- stdio: Path | None = None
446
- """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
447
-
448
- checkpoint: Path | None = None
449
- """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
450
-
451
- activation: Path | None = None
452
- """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
453
-
454
- profile: Path | None = None
455
- """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
456
-
457
- def resolve_run_root_directory(self, run_id: str) -> Path:
458
- if (project_root_dir := self.project_root) is None:
459
- project_root_dir = Path.cwd()
460
-
461
- # The default base dir is $CWD/nshtrainer/{id}/
462
- base_dir = project_root_dir / "nshtrainer"
463
- base_dir.mkdir(exist_ok=True)
464
-
465
- # Add a .gitignore file to the nshtrainer directory
466
- # which will ignore all files except for the .gitignore file itself
467
- gitignore_path = base_dir / ".gitignore"
468
- if not gitignore_path.exists():
469
- gitignore_path.touch()
470
- gitignore_path.write_text("*\n")
471
-
472
- base_dir = base_dir / run_id
473
- base_dir.mkdir(exist_ok=True)
474
-
475
- # Create a symlink to the root folder for the Runner
476
- if self.create_symlink_to_nshrunner_root:
477
- _create_symlink_to_nshrunner(base_dir)
478
-
479
- return base_dir
480
-
481
- def resolve_subdirectory(
482
- self,
483
- run_id: str,
484
- # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
485
- subdirectory: str,
486
- ) -> Path:
487
- # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
488
- if (subdir := getattr(self, subdirectory, None)) is not None:
489
- assert isinstance(
490
- subdir, Path
491
- ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
492
- return subdir
493
-
494
- dir = self.resolve_run_root_directory(run_id)
495
- dir = dir / subdirectory
496
- dir.mkdir(exist_ok=True)
497
- return dir
498
-
499
- def _resolve_log_directory_for_logger(
500
- self,
501
- run_id: str,
502
- logger: LoggerConfig,
503
- ) -> Path:
504
- if (log_dir := logger.log_dir) is not None:
505
- return log_dir
506
-
507
- # Save to nshtrainer/{id}/log/{logger name}
508
- log_dir = self.resolve_subdirectory(run_id, "log")
509
- log_dir = log_dir / logger.name
510
- log_dir.mkdir(exist_ok=True)
511
-
512
- return log_dir
513
-
514
-
515
- class ReproducibilityConfig(C.Config):
516
- deterministic: bool | Literal["warn"] | None = None
517
- """
518
- If ``True``, sets whether PyTorch operations must use deterministic algorithms.
519
- Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
520
- that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
521
- """
522
-
523
-
524
- CheckpointCallbackConfig: TypeAlias = Annotated[
525
- BestCheckpointCallbackConfig
526
- | LastCheckpointCallbackConfig
527
- | OnExceptionCheckpointCallbackConfig,
528
- C.Field(discriminator="name"),
529
- ]
530
-
531
-
532
- class CheckpointSavingConfig(CallbackConfigBase):
533
- enabled: bool = True
534
- """Enable checkpoint saving."""
535
-
536
- checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
537
- BestCheckpointCallbackConfig(),
538
- LastCheckpointCallbackConfig(),
539
- OnExceptionCheckpointCallbackConfig(),
540
- ]
541
- """Checkpoint callback configurations."""
542
-
543
- def disable_(self):
544
- self.enabled = False
545
- return self
546
-
547
- def should_save_checkpoints(self, root_config: "BaseConfig"):
548
- if not self.enabled:
549
- return False
550
-
551
- if root_config.trainer.fast_dev_run:
552
- return False
553
-
554
- return True
555
-
556
- @override
557
- def create_callbacks(self, root_config: "BaseConfig"):
558
- if not self.should_save_checkpoints(root_config):
559
- return
560
-
561
- for callback_config in self.checkpoint_callbacks:
562
- yield from callback_config.create_callbacks(root_config)
563
-
564
-
565
- class LightningTrainerKwargs(TypedDict, total=False):
566
- accelerator: str | Accelerator
567
- """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
568
- as well as custom accelerator instances."""
569
-
570
- strategy: str | Strategy
571
- """Supports different training strategies with aliases as well custom strategies.
572
- Default: ``"auto"``.
573
- """
574
-
575
- devices: list[int] | str | int
576
- """The devices to use. Can be set to a positive number (int or str), a sequence of device indices
577
- (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
578
- automatic selection based on the chosen accelerator. Default: ``"auto"``.
579
- """
580
-
581
- num_nodes: int
582
- """Number of GPU nodes for distributed training.
583
- Default: ``1``.
584
- """
585
-
586
- precision: _PRECISION_INPUT | None
587
- """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
588
- 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
589
- Can be used on CPU, GPU, TPUs, HPUs or IPUs.
590
- Default: ``'32-true'``.
591
- """
592
-
593
- logger: Logger | Iterable[Logger] | bool | None
594
- """Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
595
- the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
596
- ``False`` will disable logging. If multiple loggers are provided, local files
597
- (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
598
- Default: ``True``.
599
- """
600
-
601
- callbacks: list[Callback] | Callback | None
602
- """Add a callback or list of callbacks.
603
- Default: ``None``.
604
- """
605
-
606
- fast_dev_run: int | bool
607
- """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
608
- of train, val and test to find any bugs (ie: a sort of unit test).
609
- Default: ``False``.
610
- """
611
-
612
- max_epochs: int | None
613
- """Stop training once this number of epochs is reached. Disabled by default (None).
614
- If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
615
- To enable infinite training, set ``max_epochs = -1``.
616
- """
617
-
618
- min_epochs: int | None
619
- """Force training for at least these many epochs. Disabled by default (None).
620
- """
621
-
622
- max_steps: int
623
- """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
624
- and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
625
- ``max_epochs`` to ``-1``.
626
- """
627
-
628
- min_steps: int | None
629
- """Force training for at least these number of steps. Disabled by default (``None``).
630
- """
631
-
632
- max_time: str | timedelta | dict[str, int] | None
633
- """Stop training after this amount of time has passed. Disabled by default (``None``).
634
- The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
635
- :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
636
- :class:`datetime.timedelta`.
637
- """
638
-
639
- limit_train_batches: int | float | None
640
- """How much of training dataset to check (float = fraction, int = num_batches).
641
- Default: ``1.0``.
642
- """
643
-
644
- limit_val_batches: int | float | None
645
- """How much of validation dataset to check (float = fraction, int = num_batches).
646
- Default: ``1.0``.
647
- """
648
-
649
- limit_test_batches: int | float | None
650
- """How much of test dataset to check (float = fraction, int = num_batches).
651
- Default: ``1.0``.
652
- """
653
-
654
- limit_predict_batches: int | float | None
655
- """How much of prediction dataset to check (float = fraction, int = num_batches).
656
- Default: ``1.0``.
657
- """
658
-
659
- overfit_batches: int | float
660
- """Overfit a fraction of training/validation data (float) or a set number of batches (int).
661
- Default: ``0.0``.
662
- """
663
-
664
- val_check_interval: int | float | None
665
- """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
666
- after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
667
- batches. An ``int`` value can only be higher than the number of training batches when
668
- ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
669
- across epochs or during iteration-based training.
670
- Default: ``1.0``.
671
- """
672
-
673
- check_val_every_n_epoch: int | None
674
- """Perform a validation loop every after every `N` training epochs. If ``None``,
675
- validation will be done solely based on the number of training batches, requiring ``val_check_interval``
676
- to be an integer value.
677
- Default: ``1``.
678
- """
679
-
680
- num_sanity_val_steps: int | None
681
- """Sanity check runs n validation batches before starting the training routine.
682
- Set it to `-1` to run all batches in all validation dataloaders.
683
- Default: ``2``.
684
- """
685
-
686
- log_every_n_steps: int | None
687
- """How often to log within steps.
688
- Default: ``50``.
689
- """
690
-
691
- enable_checkpointing: bool | None
692
- """If ``True``, enable checkpointing.
693
- It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
694
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
695
- Default: ``True``.
696
- """
697
-
698
- enable_progress_bar: bool | None
699
- """Whether to enable to progress bar by default.
700
- Default: ``True``.
701
- """
702
-
703
- enable_model_summary: bool | None
704
- """Whether to enable model summarization by default.
705
- Default: ``True``.
706
- """
707
-
708
- accumulate_grad_batches: int
709
- """Accumulates gradients over k batches before stepping the optimizer.
710
- Default: 1.
711
- """
712
-
713
- gradient_clip_val: int | float | None
714
- """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
715
- gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
716
- Default: ``None``.
717
- """
718
-
719
- gradient_clip_algorithm: str | None
720
- """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
721
- to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
722
- be set to ``"norm"``.
723
- """
724
-
725
- deterministic: bool | Literal["warn"] | None
726
- """If ``True``, sets whether PyTorch operations must use deterministic algorithms.
727
- Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
728
- that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
729
- """
730
-
731
- benchmark: bool | None
732
- """The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
733
- The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
734
- (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
735
- is set to ``True``, this will default to ``False``. Override to manually set a different value.
736
- Default: ``None``.
737
- """
738
-
739
- inference_mode: bool
740
- """Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
741
- evaluation (``validate``/``test``/``predict``).
742
- """
743
-
744
- use_distributed_sampler: bool
745
- """Whether to wrap the DataLoader's sampler with
746
- :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
747
- strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
748
- ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
749
- ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
750
- sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
751
- we don't do this automatically.
752
- """
753
-
754
- profiler: Profiler | str | None
755
- """To profile individual steps during training and assist in identifying bottlenecks.
756
- Default: ``None``.
757
- """
758
-
759
- detect_anomaly: bool
760
- """Enable anomaly detection for the autograd engine.
761
- Default: ``False``.
762
- """
763
-
764
- barebones: bool
765
- """Whether to run in "barebones mode", where all features that may impact raw speed are
766
- disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
767
- runs. The following features are deactivated:
768
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
769
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
770
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
771
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
772
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
773
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
774
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
775
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
776
- :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
777
- :meth:`~lightning.pytorch.core.LightningModule.log`,
778
- :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
779
- """
780
-
781
- plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None
782
- """Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
783
- Default: ``None``.
784
- """
785
-
786
- sync_batchnorm: bool
787
- """Synchronize batch norm layers between process groups/whole world.
788
- Default: ``False``.
789
- """
790
-
791
- reload_dataloaders_every_n_epochs: int
792
- """Set to a positive integer to reload dataloaders every n epochs.
793
- Default: ``0``.
794
- """
795
-
796
- default_root_dir: Path | None
797
- """Default path for logs and weights when no logger/ckpt_callback passed.
798
- Default: ``os.getcwd()``.
799
- Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
800
- """
801
-
802
-
803
- class SanityCheckingConfig(C.Config):
804
- reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
805
- """
806
- If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
807
- - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
808
- - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
809
- Valid values are: "disable", "warn", "error".
810
- """
811
-
812
-
813
- class TrainerConfig(C.Config):
814
- ckpt_path: Literal["none"] | str | Path | None = None
815
- """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
816
-
817
- checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
818
- """Checkpoint loading configuration options.
819
- `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
820
- `"none"` will disable checkpoint loading.
821
- """
822
-
823
- checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
824
- """Checkpoint saving configuration options."""
825
-
826
- hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
827
- """Hugging Face Hub configuration options."""
828
-
829
- logging: LoggingConfig = LoggingConfig()
830
- """Logging/experiment tracking (e.g., WandB) configuration options."""
831
-
832
- optimizer: OptimizationConfig = OptimizationConfig()
833
- """Optimization configuration options."""
834
-
835
- reproducibility: ReproducibilityConfig = ReproducibilityConfig()
836
- """Reproducibility configuration options."""
837
-
838
- sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
839
- """Sanity checking configuration options."""
840
-
841
- early_stopping: EarlyStoppingConfig | None = None
842
- """Early stopping configuration options."""
843
-
844
- profiler: ProfilerConfig | None = None
845
- """
846
- To profile individual steps during training and assist in identifying bottlenecks.
847
- Default: ``None``.
848
- """
849
-
850
- callbacks: list[CallbackConfig] = []
851
- """Callbacks to use during training."""
852
-
853
- detect_anomaly: bool | None = None
854
- """Enable anomaly detection for the autograd engine.
855
- Default: ``False``.
856
- """
857
-
858
- plugins: list[PluginConfigProtocol] | None = None
859
- """
860
- Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
861
- Default: ``None``.
862
- """
863
-
864
- auto_determine_num_nodes: bool = True
865
- """
866
- If enabled, will automatically determine the number of nodes for distributed training.
867
-
868
- This will only work on:
869
- - SLURM clusters
870
- - LSF clusters
871
- """
872
-
873
- fast_dev_run: int | bool = False
874
- """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
875
- of train, val and test to find any bugs (ie: a sort of unit test).
876
- Default: ``False``.
877
- """
878
-
879
- precision: (
880
- Literal[
881
- "64-true",
882
- "32-true",
883
- "fp16-mixed",
884
- "bf16-mixed",
885
- "16-mixed-auto",
886
- ]
887
- | None
888
- ) = None
889
- """
890
- Training precision. Can be one of:
891
- - "64-true": Double precision (64-bit).
892
- - "32-true": Full precision (32-bit).
893
- - "fp16-mixed": Float16 mixed precision.
894
- - "bf16-mixed": BFloat16 mixed precision.
895
- - "16-mixed-auto": Automatic 16-bit: Uses bfloat16 if available, otherwise float16.
896
- """
897
-
898
- max_epochs: int | None = None
899
- """Stop training once this number of epochs is reached. Disabled by default (None).
900
- If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
901
- To enable infinite training, set ``max_epochs = -1``.
902
- """
903
-
904
- min_epochs: int | None = None
905
- """Force training for at least these many epochs. Disabled by default (None).
906
- """
907
-
908
- max_steps: int = -1
909
- """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
910
- and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
911
- ``max_epochs`` to ``-1``.
912
- """
913
-
914
- min_steps: int | None = None
915
- """Force training for at least these number of steps. Disabled by default (``None``).
916
- """
917
-
918
- max_time: str | timedelta | dict[str, int] | None = None
919
- """Stop training after this amount of time has passed. Disabled by default (``None``).
920
- The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
921
- :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
922
- :class:`datetime.timedelta`.
923
- """
924
-
925
- limit_train_batches: int | float | None = None
926
- """How much of training dataset to check (float = fraction, int = num_batches).
927
- Default: ``1.0``.
928
- """
929
-
930
- limit_val_batches: int | float | None = None
931
- """How much of validation dataset to check (float = fraction, int = num_batches).
932
- Default: ``1.0``.
933
- """
934
-
935
- limit_test_batches: int | float | None = None
936
- """How much of test dataset to check (float = fraction, int = num_batches).
937
- Default: ``1.0``.
938
- """
939
-
940
- limit_predict_batches: int | float | None = None
941
- """How much of prediction dataset to check (float = fraction, int = num_batches).
942
- Default: ``1.0``.
943
- """
944
-
945
- overfit_batches: int | float = 0.0
946
- """Overfit a fraction of training/validation data (float) or a set number of batches (int).
947
- Default: ``0.0``.
948
- """
949
-
950
- val_check_interval: int | float | None = None
951
- """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
952
- after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
953
- batches. An ``int`` value can only be higher than the number of training batches when
954
- ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
955
- across epochs or during iteration-based training.
956
- Default: ``1.0``.
957
- """
958
-
959
- check_val_every_n_epoch: int | None = 1
960
- """Perform a validation loop every after every `N` training epochs. If ``None``,
961
- validation will be done solely based on the number of training batches, requiring ``val_check_interval``
962
- to be an integer value.
963
- Default: ``1``.
964
- """
965
-
966
- num_sanity_val_steps: int | None = None
967
- """Sanity check runs n validation batches before starting the training routine.
968
- Set it to `-1` to run all batches in all validation dataloaders.
969
- Default: ``2``.
970
- """
971
-
972
- log_every_n_steps: int | None = None
973
- """How often to log within steps.
974
- Default: ``50``.
975
- """
976
-
977
- inference_mode: bool = True
978
- """Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during evaluation (``validate``/``test``/``predict``).
979
- Default: ``True``.
980
- """
981
-
982
- use_distributed_sampler: bool | None = None
983
- """Whether to wrap the DataLoader's sampler with
984
- :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
985
- strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
986
- ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
987
- ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
988
- sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
989
- we don't do this automatically.
990
- Default: ``True``.
991
- """
992
-
993
- accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
994
- """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
995
- as well as custom accelerator instances.
996
- Default: ``"auto"``.
997
- """
998
-
999
- strategy: StrategyConfigProtocol | StrategyLiteral | None = None
1000
- """Supports different training strategies with aliases as well custom strategies.
1001
- Default: ``"auto"``.
1002
- """
1003
-
1004
- devices: tuple[int, ...] | Sequence[int] | Literal["auto", "all"] | None = None
1005
- """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
1006
- automatic selection based on the chosen accelerator. Default: ``"auto"``.
1007
- """
1008
-
1009
- auto_set_default_root_dir: bool = True
1010
- """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
1011
- supports_shared_parameters: bool = True
1012
- """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
1013
- save_checkpoint_metadata: bool = True
1014
- """If enabled, will save additional metadata whenever a checkpoint is saved."""
1015
-
1016
- lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
1017
- """
1018
- Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
1019
-
1020
- Please refer to the Lightning documentation for a list of valid keyword arguments.
1021
- """
1022
-
1023
- additional_lightning_kwargs: dict[str, Any] = {}
1024
- """
1025
- Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
1026
-
1027
- This is essentially a non-type-checked version of `lightning_kwargs`.
1028
- """
1029
-
1030
- set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
1031
- """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
1032
-
1033
-
1034
- PrimaryMetricConfig: TypeAlias = MetricConfig
1035
-
1036
-
1037
24
  class BaseConfig(C.Config):
1038
25
  id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
1039
26
  """ID of the run."""
@@ -1060,7 +47,7 @@ class BaseConfig(C.Config):
1060
47
  trainer: TrainerConfig = TrainerConfig()
1061
48
  """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
1062
49
 
1063
- primary_metric: PrimaryMetricConfig | None = None
50
+ primary_metric: MetricConfig | None = None
1064
51
  """Primary metric configuration options. This is used in the following ways:
1065
52
  - To determine the best model checkpoint to save with the ModelCheckpoint callback.
1066
53
  - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
@@ -1216,9 +203,4 @@ class BaseConfig(C.Config):
1216
203
  return cls.model_validate(hparams)
1217
204
 
1218
205
  def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
1219
- yield self.trainer.early_stopping
1220
- yield self.trainer.checkpoint_saving
1221
- yield self.trainer.logging
1222
- yield self.trainer.optimizer
1223
- yield self.trainer.hf_hub
1224
- yield from self.trainer.callbacks
206
+ yield from self.trainer._nshtrainer_all_callback_configs()
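
The final hunk replaces BaseConfig's hand-written list of trainer-owned callback configs with a single delegation to the trainer config. The delegated method lives in the new nshtrainer/trainer/_config.py, which this diff does not display, so the following is only a plausible reconstruction based on the fields the removed implementation used to yield.

    # Plausible shape of the delegated method on the relocated TrainerConfig
    # (assumption: the field names match the old nshtrainer.model.config.TrainerConfig).
    from collections.abc import Iterable

    import nshconfig as C

    from nshtrainer.callbacks.base import CallbackConfigBase


    class TrainerConfig(C.Config):
        # ... early_stopping, checkpoint_saving, logging, optimizer, hf_hub, callbacks, etc. ...

        def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
            yield self.early_stopping
            yield self.checkpoint_saving
            yield self.logging
            yield self.optimizer
            yield self.hf_hub
            yield from self.callbacks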