nshtrainer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,2064 @@
1
+ import copy
2
+ import os
3
+ import re
4
+ import signal
5
+ import socket
6
+ import string
7
+ import time
8
+ import warnings
9
+ from abc import ABC, abstractmethod
10
+ from collections.abc import Iterable, Sequence
11
+ from datetime import timedelta
12
+ from logging import getLogger
13
+ from pathlib import Path
14
+ from typing import (
15
+ Annotated,
16
+ Any,
17
+ ClassVar,
18
+ Literal,
19
+ Protocol,
20
+ TypeAlias,
21
+ runtime_checkable,
22
+ )
23
+
24
+ import numpy as np
25
+ import torch
26
+ from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
27
+ from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
28
+ from lightning.pytorch.accelerators import Accelerator
29
+ from lightning.pytorch.callbacks.callback import Callback
30
+ from lightning.pytorch.loggers import Logger
31
+ from lightning.pytorch.plugins import _PLUGIN_INPUT
32
+ from lightning.pytorch.plugins.layer_sync import LayerSync
33
+ from lightning.pytorch.plugins.precision.precision import Precision
34
+ from lightning.pytorch.profilers import Profiler
35
+ from lightning.pytorch.strategies.strategy import Strategy
36
+ from pydantic import DirectoryPath
37
+ from typing_extensions import Self, TypedDict, TypeVar, override
38
+
39
+ from ..callbacks import CallbackConfig
40
+ from ..callbacks.base import CallbackConfigBase
41
+ from ..callbacks.wandb_watch import WandbWatchConfig
42
+ from ..config import Field, TypedConfig
43
+ from ..util.slurm import parse_slurm_node_list
44
+
45
+ log = getLogger(__name__)
46
+
47
+
48
+ class IdSeedWarning(Warning):
49
+ pass
50
+
51
+
52
+ class BaseProfilerConfig(TypedConfig, ABC):
53
+ dirpath: str | Path | None = None
54
+ """
55
+ Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
56
+ ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
57
+ will be used.
58
+ """
59
+ filename: str | None = None
60
+ """
61
+ If present, filename where the profiler results will be saved instead of printing to stdout.
62
+ The ``.txt`` extension will be used automatically.
63
+ """
64
+
65
+ @abstractmethod
66
+ def construct_profiler(self, root_config: "BaseConfig") -> Profiler: ...
67
+
68
+
69
+ class SimpleProfilerConfig(BaseProfilerConfig):
70
+ kind: Literal["simple"] = "simple"
71
+
72
+ extended: bool = True
73
+ """
74
+ If ``True``, adds extra columns representing number of calls and percentage of
75
+ total time spent on the respective action.
76
+ """
77
+
78
+ @override
79
+ def construct_profiler(self, root_config):
80
+ from lightning.pytorch.profilers.simple import SimpleProfiler
81
+
82
+ if (dirpath := self.dirpath) is None:
83
+ dirpath = root_config.directory.resolve_subdirectory(
84
+ root_config.id, "profile"
85
+ )
86
+
87
+ if (filename := self.filename) is None:
88
+ filename = f"{root_config.id}_profile.txt"
89
+
90
+ return SimpleProfiler(
91
+ extended=self.extended,
92
+ dirpath=dirpath,
93
+ filename=filename,
94
+ )
95
+
96
+
97
+ class AdvancedProfilerConfig(BaseProfilerConfig):
98
+ kind: Literal["advanced"] = "advanced"
99
+
100
+ line_count_restriction: float = 1.0
101
+ """
102
+ This can be used to limit the number of functions
103
+ reported for each action. Either an integer (to select a count of lines),
104
+ or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
105
+ """
106
+
107
+ @override
108
+ def construct_profiler(self, root_config):
109
+ from lightning.pytorch.profilers.advanced import AdvancedProfiler
110
+
111
+ if (dirpath := self.dirpath) is None:
112
+ dirpath = root_config.directory.resolve_subdirectory(
113
+ root_config.id, "profile"
114
+ )
115
+
116
+ if (filename := self.filename) is None:
117
+ filename = f"{root_config.id}_profile.txt"
118
+
119
+ return AdvancedProfiler(
120
+ line_count_restriction=self.line_count_restriction,
121
+ dirpath=dirpath,
122
+ filename=filename,
123
+ )
124
+
125
+
126
+ class PyTorchProfilerConfig(BaseProfilerConfig):
127
+ kind: Literal["pytorch"] = "pytorch"
128
+
129
+ group_by_input_shapes: bool = False
130
+ """Include operator input shapes and group calls by shape."""
131
+
132
+ emit_nvtx: bool = False
133
+ """
134
+ Context manager that makes every autograd operation emit an NVTX range.
135
+ Run::
136
+
137
+ nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
138
+
139
+ To visualize, you can either use::
140
+
141
+ nvvp trace_name.prof
142
+ torch.autograd.profiler.load_nvprof(path)
143
+ """
144
+
145
+ export_to_chrome: bool = True
146
+ """
147
+ Whether to export the sequence of profiled operators for Chrome.
148
+ It will generate a ``.json`` file which can be read by Chrome.
149
+ """
150
+
151
+ row_limit: int = 20
152
+ """
153
+ Limit the number of rows in a table, ``-1`` is a special value that
154
+ removes the limit completely.
155
+ """
156
+
157
+ sort_by_key: str | None = None
158
+ """
159
+ Attribute used to sort entries. By default
160
+ they are printed in the same order as they were registered.
161
+ Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
162
+ ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
163
+ ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
164
+ """
165
+
166
+ record_module_names: bool = True
167
+ """Whether to add module names while recording autograd operation."""
168
+
169
+ table_kwargs: dict[str, Any] | None = None
170
+ """Dictionary with keyword arguments for the summary table."""
171
+
172
+ additional_profiler_kwargs: dict[str, Any] = {}
173
+ """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
174
+
175
+ @override
176
+ def construct_profiler(self, root_config):
177
+ from lightning.pytorch.profilers.pytorch import PyTorchProfiler
178
+
179
+ if (dirpath := self.dirpath) is None:
180
+ dirpath = root_config.directory.resolve_subdirectory(
181
+ root_config.id, "profile"
182
+ )
183
+
184
+ if (filename := self.filename) is None:
185
+ filename = f"{root_config.id}_profile.txt"
186
+
187
+ return PyTorchProfiler(
188
+ group_by_input_shapes=self.group_by_input_shapes,
189
+ emit_nvtx=self.emit_nvtx,
190
+ export_to_chrome=self.export_to_chrome,
191
+ row_limit=self.row_limit,
192
+ sort_by_key=self.sort_by_key,
193
+ record_module_names=self.record_module_names,
194
+ table_kwargs=self.table_kwargs,
195
+ dirpath=dirpath,
196
+ filename=filename,
197
+ **self.additional_profiler_kwargs,
198
+ )
199
+
200
+
201
+ ProfilerConfig: TypeAlias = Annotated[
202
+ SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
203
+ Field(discriminator="kind"),
204
+ ]
205
+
206
+
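A minimal sketch of how the ``kind`` discriminator above selects a concrete profiler config, assuming the annotated union can be validated with a pydantic v2 ``TypeAdapter`` (the exact ``TypedConfig`` machinery in this package may differ)::

    from pydantic import TypeAdapter

    # The "kind" field routes a plain dict to the matching config class.
    cfg = TypeAdapter(ProfilerConfig).validate_python(
        {"kind": "advanced", "line_count_restriction": 0.5}
    )
    assert isinstance(cfg, AdvancedProfilerConfig)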
207
+ class EnvironmentClassInformationConfig(TypedConfig):
208
+ name: str
209
+ module: str
210
+ full_name: str
211
+
212
+ file_path: Path
213
+ source_file_path: Path | None = None
214
+
215
+
216
+ class EnvironmentSLURMInformationConfig(TypedConfig):
217
+ hostname: str
218
+ hostnames: list[str]
219
+ job_id: str
220
+ raw_job_id: str
221
+ array_job_id: str | None
222
+ array_task_id: str | None
223
+ num_tasks: int
224
+ num_nodes: int
225
+ node: str | int | None
226
+ global_rank: int
227
+ local_rank: int
228
+
229
+ @classmethod
230
+ def from_current_environment(cls):
231
+ try:
232
+ from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
233
+
234
+ if not SLURMEnvironment.detect():
235
+ return None
236
+
237
+ hostname = socket.gethostname()
238
+ hostnames = [hostname]
239
+ if node_list := os.environ.get("SLURM_JOB_NODELIST", ""):
240
+ hostnames = parse_slurm_node_list(node_list)
241
+
242
+ raw_job_id = os.environ["SLURM_JOB_ID"]
243
+ job_id = raw_job_id
244
+ array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
245
+ array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
246
+ if array_job_id and array_task_id:
247
+ job_id = f"{array_job_id}_{array_task_id}"
248
+
249
+ num_tasks = int(os.environ["SLURM_NTASKS"])
250
+ num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
251
+
252
+ node_id = os.environ.get("SLURM_NODEID")
253
+
254
+ global_rank = int(os.environ["SLURM_PROCID"])
255
+ local_rank = int(os.environ["SLURM_LOCALID"])
256
+
257
+ return cls(
258
+ hostname=hostname,
259
+ hostnames=hostnames,
260
+ job_id=job_id,
261
+ raw_job_id=raw_job_id,
262
+ array_job_id=array_job_id,
263
+ array_task_id=array_task_id,
264
+ num_tasks=num_tasks,
265
+ num_nodes=num_nodes,
266
+ node=node_id,
267
+ global_rank=global_rank,
268
+ local_rank=local_rank,
269
+ )
270
+ except (ImportError, RuntimeError, ValueError, KeyError):
271
+ return None
272
+
273
+
274
+ class EnvironmentLSFInformationConfig(TypedConfig):
275
+ hostname: str
276
+ hostnames: list[str]
277
+ job_id: str
278
+ array_job_id: str | None
279
+ array_task_id: str | None
280
+ num_tasks: int
281
+ num_nodes: int
282
+ node: str | int | None
283
+ global_rank: int
284
+ local_rank: int
285
+
286
+ @classmethod
287
+ def from_current_environment(cls):
288
+ try:
289
+ import os
290
+ import socket
291
+
292
+ hostname = socket.gethostname()
293
+ hostnames = [hostname]
294
+ if node_list := os.environ.get("LSB_HOSTS", ""):
295
+ hostnames = node_list.split()
296
+
297
+ job_id = os.environ["LSB_JOBID"]
298
+ array_job_id = os.environ.get("LSB_JOBINDEX")
299
+ array_task_id = os.environ.get("LSB_JOBINDEX")
300
+
301
+ num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
302
+ num_nodes = len(set(hostnames))
303
+
304
+ node_id = (
305
+ os.environ.get("LSB_HOSTS", "").split().index(hostname)
306
+ if "LSB_HOSTS" in os.environ
307
+ else None
308
+ )
309
+
310
+ # LSF doesn't have direct equivalents for global_rank and local_rank
311
+ # You might need to calculate these based on your specific setup
312
+ global_rank = int(os.environ.get("PMI_RANK", 0))
313
+ local_rank = int(os.environ.get("LSB_RANK", 0))
314
+
315
+ return cls(
316
+ hostname=hostname,
317
+ hostnames=hostnames,
318
+ job_id=job_id,
319
+ array_job_id=array_job_id,
320
+ array_task_id=array_task_id,
321
+ num_tasks=num_tasks,
322
+ num_nodes=num_nodes,
323
+ node=node_id,
324
+ global_rank=global_rank,
325
+ local_rank=local_rank,
326
+ )
327
+ except (ImportError, RuntimeError, ValueError, KeyError):
328
+ return None
329
+
330
+
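Both ``from_current_environment`` helpers are meant to be safe to call anywhere: outside the corresponding scheduler, or on any parsing error, they simply return ``None``. A usage sketch::

    slurm = EnvironmentSLURMInformationConfig.from_current_environment()
    lsf = EnvironmentLSFInformationConfig.from_current_environment()
    if slurm is not None:
        print(f"SLURM rank {slurm.global_rank} on {slurm.hostname} ({slurm.num_nodes} node(s))")
    elif lsf is not None:
        print(f"LSF job {lsf.job_id} on {lsf.hostname}")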
331
+ class EnvironmentLinuxEnvironmentConfig(TypedConfig):
332
+ """
333
+ Information about the Linux environment (e.g., current user, hostname, etc.)
334
+ """
335
+
336
+ user: str | None = None
337
+ hostname: str | None = None
338
+ system: str | None = None
339
+ release: str | None = None
340
+ version: str | None = None
341
+ machine: str | None = None
342
+ processor: str | None = None
343
+ cpu_count: int | None = None
344
+ memory: int | None = None
345
+ uptime: timedelta | None = None
346
+ boot_time: float | None = None
347
+ load_avg: tuple[float, float, float] | None = None
348
+
349
+
350
+ class EnvironmentConfig(TypedConfig):
351
+ cwd: Path | None = None
352
+
353
+ python_executable: Path | None = None
354
+ python_path: list[Path] | None = None
355
+ python_version: str | None = None
356
+
357
+ config: EnvironmentClassInformationConfig | None = None
358
+ model: EnvironmentClassInformationConfig | None = None
359
+ data: EnvironmentClassInformationConfig | None = None
360
+
361
+ linux: EnvironmentLinuxEnvironmentConfig | None = None
362
+
363
+ slurm: EnvironmentSLURMInformationConfig | None = None
364
+ lsf: EnvironmentLSFInformationConfig | None = None
365
+
366
+ base_dir: Path | None = None
367
+ log_dir: Path | None = None
368
+ checkpoint_dir: Path | None = None
369
+ stdio_dir: Path | None = None
370
+
371
+ seed: int | None = None
372
+ seed_workers: bool | None = None
373
+
374
+
375
+ class BaseLoggerConfig(TypedConfig, ABC):
376
+ enabled: bool = True
377
+ """Enable this logger."""
378
+
379
+ priority: int = 0
380
+ """Priority of the logger. Higher values are logged first."""
381
+
382
+ log_dir: DirectoryPath | None = None
383
+ """Directory to save the logs to. If None, will use the default log directory for the trainer."""
384
+
385
+ @abstractmethod
386
+ def construct_logger(self, root_config: "BaseConfig") -> Logger | None: ...
387
+
388
+ def disable_(self):
389
+ self.enabled = False
390
+ return self
391
+
392
+
393
+ def _project_name(
394
+ root_config: "BaseConfig",
395
+ default_project: str = "lightning_logs",
396
+ ):
397
+ # If the config has a project name, use that.
398
+ if project := root_config.project:
399
+ return project
400
+
401
+ # Otherwise, we should use the name of the module that the config is defined in,
402
+ # if we can find it.
403
+ # If this isn't in a module, use the default project name.
404
+ if not (module := root_config.__module__):
405
+ return default_project
406
+
407
+ # If the module is a package, use the package name.
408
+ if not (module := module.split(".", maxsplit=1)[0].strip()):
409
+ return default_project
410
+
411
+ return module
412
+
413
+
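In short, the project name falls back from ``root_config.project`` to the top-level package that defines the config class, and finally to ``"lightning_logs"``. A sketch, where ``my_project`` is a hypothetical package containing the config class::

    _project_name(cfg)  # cfg.project == "ablations"       -> "ablations"
    _project_name(cfg)  # cfg.project unset, class defined
                        # in my_project.configs             -> "my_project"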
414
+ def _wandb_available():
415
+ try:
416
+ from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
417
+
418
+ if not _WANDB_AVAILABLE:
419
+ log.warning("WandB not found. Disabling WandbLogger.")
420
+ return False
421
+ return True
422
+ except ImportError:
423
+ return False
424
+
425
+
426
+ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
427
+ kind: Literal["wandb"] = "wandb"
428
+
429
+ enabled: bool = Field(default_factory=lambda: _wandb_available())
430
+ """Enable WandB logging."""
431
+
432
+ priority: int = 2
433
+ """Priority of the logger. Higher values are logged first."""
434
+
435
+ project: str | None = None
436
+ """WandB project name to use for the logger. If None, will use the root config's project name."""
437
+
438
+ log_model: bool | Literal["all"] = False
439
+ """
440
+ Whether to log the model checkpoints to wandb.
441
+ Valid values are:
442
+ - False: Do not log the model checkpoints.
443
+ - True: Log the latest model checkpoint.
444
+ - "all": Log all model checkpoints.
445
+ """
446
+
447
+ watch: WandbWatchConfig = WandbWatchConfig()
448
+ """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
449
+
450
+ offline: bool = False
451
+ """Whether to run WandB in offline mode."""
452
+
453
+ @override
454
+ def construct_logger(self, root_config):
455
+ if not self.enabled:
456
+ return None
457
+
458
+ from lightning.pytorch.loggers.wandb import WandbLogger
459
+
460
+ save_dir = root_config.directory.resolve_log_directory_for_logger(
461
+ root_config.id,
462
+ self,
463
+ )
464
+ save_dir = save_dir / "wandb"
465
+ save_dir.mkdir(parents=True, exist_ok=True)
466
+ return WandbLogger(
467
+ save_dir=save_dir,
468
+ project=self.project or _project_name(root_config),
469
+ name=root_config.run_name,
470
+ version=root_config.id,
471
+ log_model=self.log_model,
472
+ notes=(
473
+ "\n".join(f"- {note}" for note in root_config.notes)
474
+ if root_config.notes
475
+ else None
476
+ ),
477
+ tags=root_config.tags,
478
+ offline=self.offline,
479
+ )
480
+
481
+ @override
482
+ def construct_callbacks(self, root_config):
483
+ if self.watch:
484
+ yield from self.watch.construct_callbacks(root_config)
485
+
486
+
487
+ class CSVLoggerConfig(BaseLoggerConfig):
488
+ kind: Literal["csv"] = "csv"
489
+
490
+ enabled: bool = True
491
+ """Enable CSV logging."""
492
+
493
+ priority: int = 0
494
+ """Priority of the logger. Higher values are logged first."""
495
+
496
+ prefix: str = ""
497
+ """A string to put at the beginning of metric keys."""
498
+
499
+ flush_logs_every_n_steps: int = 100
500
+ """How often to flush logs to disk."""
501
+
502
+ @override
503
+ def construct_logger(self, root_config):
504
+ if not self.enabled:
505
+ return None
506
+
507
+ from lightning.pytorch.loggers.csv_logs import CSVLogger
508
+
509
+ save_dir = root_config.directory.resolve_log_directory_for_logger(
510
+ root_config.id,
511
+ self,
512
+ )
513
+ save_dir = save_dir / "csv"
514
+ save_dir.mkdir(parents=True, exist_ok=True)
515
+ return CSVLogger(
516
+ save_dir=save_dir,
517
+ name=root_config.run_name,
518
+ version=root_config.id,
519
+ prefix=self.prefix,
520
+ flush_logs_every_n_steps=self.flush_logs_every_n_steps,
521
+ )
522
+
523
+
524
+ def _tensorboard_available():
525
+ try:
526
+ from lightning.fabric.loggers.tensorboard import (
527
+ _TENSORBOARD_AVAILABLE,
528
+ _TENSORBOARDX_AVAILABLE,
529
+ )
530
+
531
+ if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
532
+ log.warning(
533
+ "TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. "
534
+ "Please install TensorBoard with `pip install tensorboard` or "
535
+ "TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging."
536
+ )
537
+ return False
538
+ return True
539
+ except ImportError:
540
+ return False
541
+
542
+
543
+ class TensorboardLoggerConfig(BaseLoggerConfig):
544
+ kind: Literal["tensorboard"] = "tensorboard"
545
+
546
+ enabled: bool = Field(default_factory=lambda: _tensorboard_available())
547
+ """Enable TensorBoard logging."""
548
+
549
+ priority: int = 2
550
+ """Priority of the logger. Higher values are logged first."""
551
+
552
+ log_graph: bool = False
553
+ """
554
+ Adds the computational graph to tensorboard. This requires that
555
+ the user has defined the `self.example_input_array` attribute in their
556
+ model.
557
+ """
558
+
559
+ default_hp_metric: bool = True
560
+ """
561
+ Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
562
+ called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
563
+ """
564
+
565
+ prefix: str = ""
566
+ """A string to put at the beginning of metric keys."""
567
+
568
+ @override
569
+ def construct_logger(self, root_config):
570
+ if not self.enabled:
571
+ return None
572
+
573
+ from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
574
+
575
+ save_dir = root_config.directory.resolve_log_directory_for_logger(
576
+ root_config.id,
577
+ self,
578
+ )
579
+ save_dir = save_dir / "tensorboard"
580
+ save_dir.mkdir(parents=True, exist_ok=True)
581
+ return TensorBoardLogger(
582
+ save_dir=save_dir,
583
+ name=root_config.run_name,
584
+ version=root_config.id,
585
+ log_graph=self.log_graph,
586
+ default_hp_metric=self.default_hp_metric,
587
+ )
588
+
589
+
590
+ LoggerConfig: TypeAlias = Annotated[
591
+ WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
592
+ Field(discriminator="kind"),
593
+ ]
594
+
595
+
596
+ class LoggingConfig(CallbackConfigBase):
597
+ enabled: bool = True
598
+ """Enable experiment tracking."""
599
+
600
+ loggers: Sequence[LoggerConfig] = [
601
+ WandbLoggerConfig(),
602
+ CSVLoggerConfig(),
603
+ TensorboardLoggerConfig(),
604
+ ]
605
+ """Loggers to use for experiment tracking."""
606
+
607
+ log_lr: bool | Literal["step", "epoch"] = True
608
+ """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
609
+ log_epoch: bool = True
610
+ """If enabled, will log the fractional epoch number to the logger."""
611
+
612
+ @property
613
+ def wandb(self) -> WandbLoggerConfig | None:
614
+ return next(
615
+ (
616
+ logger
617
+ for logger in self.loggers
618
+ if isinstance(logger, WandbLoggerConfig)
619
+ ),
620
+ None,
+ )
621
+
622
+ @property
623
+ def csv(self) -> CSVLoggerConfig | None:
624
+ return next(
625
+ (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
626
+ None,
+ )
627
+
628
+ @property
629
+ def tensorboard(self) -> TensorboardLoggerConfig | None:
630
+ return next(
631
+ (
632
+ logger
633
+ for logger in self.loggers
634
+ if isinstance(logger, TensorboardLoggerConfig)
635
+ ),
636
+ None,
+ )
637
+
638
+ def construct_loggers(self, root_config: "BaseConfig"):
639
+ """
640
+ Constructs and returns a list of loggers based on the provided root configuration.
641
+
642
+ Args:
643
+ root_config (BaseConfig): The root configuration object.
644
+
645
+ Returns:
646
+ list[Logger]: A list of constructed loggers.
647
+ """
648
+ loggers: list[Logger] = []
649
+ if not self.enabled:
650
+ return loggers
651
+
652
+ for logger_config in sorted(
653
+ self.loggers,
654
+ key=lambda x: x.priority,
655
+ reverse=True,
656
+ ):
657
+ if not logger_config.enabled:
658
+ continue
659
+ if (logger := logger_config.construct_logger(root_config)) is None:
660
+ continue
661
+ loggers.append(logger)
662
+ return loggers
663
+
664
+ @override
665
+ def construct_callbacks(self, root_config):
666
+ if self.log_lr:
667
+ from lightning.pytorch.callbacks import LearningRateMonitor
668
+
669
+ logging_interval: str | None = None
670
+ if isinstance(self.log_lr, str):
671
+ logging_interval = self.log_lr
672
+
673
+ yield LearningRateMonitor(logging_interval=logging_interval)
674
+
675
+ if self.log_epoch:
676
+ from ..callbacks.log_epoch import LogEpochCallback
677
+
678
+ yield LogEpochCallback()
679
+
680
+ for logger in self.loggers:
681
+ if not logger or not isinstance(logger, CallbackConfigBase):
682
+ continue
683
+
684
+ yield from logger.construct_callbacks(root_config)
685
+
686
+
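Since ``construct_loggers`` sorts by ``priority`` (highest first) and Python's sort is stable, the default logger list is constructed in the order WandB, TensorBoard, CSV. A small sketch of that ordering::

    cfg = LoggingConfig()  # defaults: WandB (2), CSV (0), TensorBoard (2)
    ordered = sorted(cfg.loggers, key=lambda x: x.priority, reverse=True)
    [type(logger).__name__ for logger in ordered]
    # -> ['WandbLoggerConfig', 'TensorboardLoggerConfig', 'CSVLoggerConfig']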
687
+ class GradientClippingConfig(TypedConfig):
688
+ enabled: bool = True
689
+ """Enable gradient clipping."""
690
+ value: int | float
691
+ """Value to use for gradient clipping."""
692
+ algorithm: Literal["value", "norm"] = "norm"
693
+ """Norm type to use for gradient clipping."""
694
+
695
+
696
+ class OptimizationConfig(CallbackConfigBase):
697
+ log_grad_norm: bool | str | float = False
698
+ """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
699
+ log_grad_norm_per_param: bool | str | float = False
700
+ """If enabled, will log the gradient norm for each model parameter to the logger."""
701
+
702
+ log_param_norm: bool | str | float = False
703
+ """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
704
+ log_param_norm_per_param: bool | str | float = False
705
+ """If enabled, will log the parameter norm for each model parameter to the logger."""
706
+
707
+ gradient_clipping: GradientClippingConfig | None = None
708
+ """Gradient clipping configuration, or None to disable gradient clipping."""
709
+
710
+ @override
711
+ def construct_callbacks(self, root_config):
712
+ from ..callbacks.norm_logging import NormLoggingConfig
713
+
714
+ yield from NormLoggingConfig(
715
+ log_grad_norm=self.log_grad_norm,
716
+ log_grad_norm_per_param=self.log_grad_norm_per_param,
717
+ log_param_norm=self.log_param_norm,
718
+ log_param_norm_per_param=self.log_param_norm_per_param,
719
+ ).construct_callbacks(root_config)
720
+
721
+
722
+ LogLevel: TypeAlias = Literal[
723
+ "CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"
724
+ ]
725
+
726
+
727
+ class PythonLogging(TypedConfig):
728
+ log_level: LogLevel | None = None
729
+ """Log level to use for the Python logger (or None to use the default)."""
730
+
731
+ rich: bool = False
732
+ """If enabled, will use the rich library to format the Python logger output."""
733
+ rich_tracebacks: bool = True
734
+ """If enabled, will use the rich library to format the Python logger tracebacks."""
735
+
736
+ lovely_tensors: bool = False
737
+ """If enabled, will use the lovely-tensors library to format PyTorch tensors. False by default as it causes issues when used with `torch.vmap`."""
738
+ lovely_numpy: bool = False
739
+ """If enabled, will use the lovely-numpy library to format numpy arrays. False by default as it causes some issues with other libaries."""
740
+
741
+ def pretty_(
742
+ self,
743
+ *,
744
+ log_level: LogLevel | None = "INFO",
745
+ torch: bool = True,
746
+ numpy: bool = True,
747
+ rich: bool = True,
748
+ rich_tracebacks: bool = True,
749
+ ):
750
+ self.log_level = log_level
751
+ self.lovely_tensors = torch
752
+ self.lovely_numpy = numpy
753
+ self.rich = rich
754
+ self.rich_tracebacks = rich_tracebacks
755
+
756
+
757
+ TPlugin = TypeVar(
758
+ "TPlugin",
759
+ Precision,
760
+ ClusterEnvironment,
761
+ CheckpointIO,
762
+ LayerSync,
763
+ infer_variance=True,
764
+ )
765
+
766
+
767
+ @runtime_checkable
768
+ class PluginConfigProtocol(Protocol[TPlugin]):
769
+ def construct_plugin(self) -> TPlugin: ...
770
+
771
+
772
+ @runtime_checkable
773
+ class AcceleratorConfigProtocol(Protocol):
774
+ def construct_accelerator(self) -> Accelerator: ...
775
+
776
+
777
+ @runtime_checkable
778
+ class StrategyConfigProtocol(Protocol):
779
+ def construct_strategy(self) -> Strategy: ...
780
+
781
+
782
+ AcceleratorLiteral: TypeAlias = Literal[
783
+ "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
784
+ ]
785
+
786
+ StrategyLiteral: TypeAlias = Literal[
787
+ "auto",
788
+ "ddp",
789
+ "ddp_find_unused_parameters_false",
790
+ "ddp_find_unused_parameters_true",
791
+ "ddp_spawn",
792
+ "ddp_spawn_find_unused_parameters_false",
793
+ "ddp_spawn_find_unused_parameters_true",
794
+ "ddp_fork",
795
+ "ddp_fork_find_unused_parameters_false",
796
+ "ddp_fork_find_unused_parameters_true",
797
+ "ddp_notebook",
798
+ "dp",
799
+ "deepspeed",
800
+ "deepspeed_stage_1",
801
+ "deepspeed_stage_1_offload",
802
+ "deepspeed_stage_2",
803
+ "deepspeed_stage_2_offload",
804
+ "deepspeed_stage_3",
805
+ "deepspeed_stage_3_offload",
806
+ "deepspeed_stage_3_offload_nvme",
807
+ "fsdp",
808
+ "fsdp_cpu_offload",
809
+ "single_xla",
810
+ "xla_fsdp",
811
+ "xla",
812
+ "single_tpu",
813
+ ]
814
+
815
+
816
+ class CheckpointLoadingConfig(TypedConfig):
817
+ path: Literal["best", "last", "hpc"] | str | Path | None = None
818
+ """
819
+ Checkpoint path to use when loading a checkpoint.
820
+
821
+ - "best" will load the best checkpoint.
822
+ - "last" will load the last checkpoint.
823
+ - "hpc" will load the SLURM pre-empted checkpoint.
824
+ - Any other string or Path will load the checkpoint from the specified path.
825
+ """
826
+
827
+
828
+ class DirectoryConfig(TypedConfig):
829
+ project_root: Path | None = None
830
+ """
831
+ Root directory for this project.
832
+
833
+ This isn't specific to the run; it is the parent directory of all runs.
834
+ """
835
+
836
+ log: Path | None = None
837
+ """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use lltrainer/{id}/log/."""
838
+
839
+ stdio: Path | None = None
840
+ """stdout/stderr log directory to use for the trainer. If None, will use lltrainer/{id}/stdio/."""
841
+
842
+ checkpoint: Path | None = None
843
+ """Checkpoint directory to use for the trainer. If None, will use lltrainer/{id}/checkpoint/."""
844
+
845
+ activation: Path | None = None
846
+ """Activation directory to use for the trainer. If None, will use lltrainer/{id}/activation/."""
847
+
848
+ profile: Path | None = None
849
+ """Directory to save profiling information to. If None, will use lltrainer/{id}/profile/."""
850
+
851
+ def resolve_run_root_directory(self, run_id: str) -> Path:
852
+ if (project_root_dir := self.project_root) is None:
853
+ project_root_dir = Path.cwd()
854
+
855
+ # The default base dir is $CWD/lltrainer/{id}/
856
+ base_dir = project_root_dir / "lltrainer"
857
+ base_dir.mkdir(exist_ok=True)
858
+
859
+ # Add a .gitignore file to the lltrainer directory
860
+ # which will ignore all files except for the .gitignore file itself
861
+ gitignore_path = base_dir / ".gitignore"
862
+ if not gitignore_path.exists():
863
+ gitignore_path.touch()
864
+ gitignore_path.write_text("*\n")
865
+
866
+ base_dir = base_dir / run_id
867
+ base_dir.mkdir(exist_ok=True)
868
+
869
+ return base_dir
870
+
871
+ def resolve_subdirectory(
872
+ self,
873
+ run_id: str,
874
+ # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
875
+ subdirectory: str,
876
+ ) -> Path:
877
+ # The subdir will be $CWD/lltrainer/{id}/{log, stdio, checkpoint, activation}/
878
+ if (subdir := getattr(self, subdirectory, None)) is not None:
879
+ assert isinstance(
880
+ subdir, Path
881
+ ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
882
+ return subdir
883
+
884
+ dir = self.resolve_run_root_directory(run_id)
885
+ dir = dir / subdirectory
886
+ dir.mkdir(exist_ok=True)
887
+ return dir
888
+
889
+ def resolve_log_directory_for_logger(
890
+ self,
891
+ run_id: str,
892
+ logger: LoggerConfig,
893
+ ) -> Path:
894
+ if (log_dir := logger.log_dir) is not None:
895
+ return log_dir
896
+
897
+ # Save to lltrainer/{id}/log/{logger kind}/{id}/
898
+ log_dir = self.resolve_subdirectory(run_id, "log")
899
+ log_dir = log_dir / logger.kind
900
+
901
+ return log_dir
902
+
903
+
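The resulting on-disk layout, for a hypothetical run id ``"abc123"`` with the current working directory as the project root (``resolve_subdirectory`` also creates the directory it returns)::

    dirs = DirectoryConfig()
    dirs.resolve_subdirectory("abc123", "checkpoint")
    # -> $CWD/lltrainer/abc123/checkpoint/
    dirs.resolve_log_directory_for_logger("abc123", CSVLoggerConfig())
    # -> $CWD/lltrainer/abc123/log/csv/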
904
+ class ReproducibilityConfig(TypedConfig):
905
+ deterministic: bool | Literal["warn"] | None = None
906
+ """
907
+ If ``True``, sets whether PyTorch operations must use deterministic algorithms.
908
+ Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
909
+ that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
910
+ """
911
+
912
+
913
+ class ModelCheckpointCallbackConfig(CallbackConfigBase):
914
+ """Arguments for the ModelCheckpoint callback."""
915
+
916
+ kind: Literal["model_checkpoint"] = "model_checkpoint"
917
+
918
+ dirpath: str | Path | None = None
919
+ """
920
+ Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
921
+ """
922
+
923
+ filename: str | None = None
924
+ """
925
+ Checkpoint filename.
926
+ If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
927
+ """
928
+
929
+ monitor: str | None = None
930
+ """
931
+ Quantity to monitor for saving checkpoints.
932
+ If None, no metric is monitored and checkpoints are saved at the end of every epoch.
933
+ """
934
+
935
+ verbose: bool = False
936
+ """Verbosity mode. If True, print additional information about checkpoints."""
937
+
938
+ save_last: Literal[True, False, "link"] | None = "link"
939
+ """
940
+ Whether to save the last checkpoint.
941
+ If True, saves a copy of the last checkpoint separately.
942
+ If "link", creates a symbolic link to the last checkpoint.
943
+ """
944
+
945
+ save_top_k: int = 1
946
+ """
947
+ Number of best models to save.
948
+ If -1, all models are saved.
949
+ If 0, no models are saved.
950
+ """
951
+
952
+ save_weights_only: bool = False
953
+ """Whether to save only the model's weights or the entire model object."""
954
+
955
+ mode: str = "min"
956
+ """
957
+ One of "min" or "max".
958
+ If "min", training will stop when the metric monitored has stopped decreasing.
959
+ If "max", training will stop when the metric monitored has stopped increasing.
960
+ """
961
+
962
+ auto_insert_metric_name: bool = True
963
+ """Whether to automatically insert the metric name in the checkpoint filename."""
964
+
965
+ every_n_train_steps: int | None = None
966
+ """
967
+ Number of training steps between checkpoints.
968
+ If None or 0, no checkpoints are saved during training.
969
+ """
970
+
971
+ train_time_interval: timedelta | None = None
972
+ """
973
+ Time interval between checkpoints during training.
974
+ If None, no checkpoints are saved during training based on time.
975
+ """
976
+
977
+ every_n_epochs: int | None = None
978
+ """
979
+ Number of epochs between checkpoints.
980
+ If None or 0, no checkpoints are saved at the end of epochs.
981
+ """
982
+
983
+ save_on_train_epoch_end: bool | None = None
984
+ """
985
+ Whether to run checkpointing at the end of the training epoch.
986
+ If False, checkpointing runs at the end of the validation.
987
+ """
988
+
989
+ enable_version_counter: bool = True
990
+ """Whether to append a version to the existing file name."""
991
+
992
+ auto_append_metric: bool = True
993
+ """If enabled, this will automatically add "-{monitor}" to the filename."""
994
+
995
+ @staticmethod
996
+ def _convert_string(input_string: str):
997
+ # Find all variables enclosed in curly braces
998
+ variables = re.findall(r"\{(.*?)\}", input_string)
999
+
1000
+ # Replace each variable with its corresponding key-value pair
1001
+ output_string = input_string
1002
+ for variable in variables:
1003
+ # If the name is something like {variable:format}, we shouldn't process the format.
1004
+ key_name = variable
1005
+ if ":" in variable:
1006
+ key_name, _ = variable.split(":", 1)
1007
+ continue
1008
+
1009
+ # Replace '/' with '_' in the key name
1010
+ key_name = key_name.replace("/", "_")
1011
+ output_string = output_string.replace(
1012
+ f"{{{variable}}}", f"{key_name}={{{variable}}}"
1013
+ )
1014
+
1015
+ return output_string
1016
+
1017
+ @override
1018
+ def construct_callbacks(self, root_config):
1019
+ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
1020
+
1021
+ dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
1022
+ root_config.id, "checkpoint"
1023
+ )
1024
+
1025
+ # If `monitor` is not provided, we can use `config.primary_metric` if it is set.
1026
+ monitor = self.monitor
1027
+ mode = self.mode
1028
+ if (
1029
+ monitor is None
1030
+ and (primary_metric := root_config.primary_metric) is not None
1031
+ ):
1032
+ monitor = primary_metric.validation_monitor
1033
+ mode = primary_metric.mode
1034
+
1035
+ filename = self.filename
1036
+ if self.auto_append_metric:
1037
+ if not filename:
1038
+ filename = "{epoch}-{step}"
1039
+ filename = f"{filename}-{{{monitor}}}"
1040
+
1041
+ if self.auto_insert_metric_name and filename:
1042
+ new_filename = self._convert_string(filename)
1043
+ log.critical(
1044
+ f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
1045
+ )
1046
+ filename = new_filename
1047
+
1048
+ yield ModelCheckpoint(
1049
+ dirpath=dirpath,
1050
+ filename=filename,
1051
+ monitor=monitor,
1052
+ mode=mode,
1053
+ verbose=self.verbose,
1054
+ save_last=self.save_last,
1055
+ save_top_k=self.save_top_k,
1056
+ save_weights_only=self.save_weights_only,
1057
+ auto_insert_metric_name=False,
1058
+ every_n_train_steps=self.every_n_train_steps,
1059
+ train_time_interval=self.train_time_interval,
1060
+ every_n_epochs=self.every_n_epochs,
1061
+ save_on_train_epoch_end=self.save_on_train_epoch_end,
1062
+ enable_version_counter=self.enable_version_counter,
1063
+ )
1064
+
1065
+
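A worked example of the filename handling above, assuming the monitored metric is ``val/loss``: ``auto_append_metric`` turns an empty filename into ``"{epoch}-{step}-{val/loss}"``, and ``_convert_string`` then prefixes each placeholder with a sanitized key name, leaving placeholders that carry a format spec (e.g. ``"{val/loss:.2f}"``) untouched::

    ModelCheckpointCallbackConfig._convert_string("{epoch}-{step}-{val/loss}")
    # -> "epoch={epoch}-step={step}-val_loss={val/loss}"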
1066
+ class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
1067
+ kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
1068
+
1069
+ dirpath: str | Path | None = None
1070
+ """Directory path to save the checkpoint file."""
1071
+
1072
+ filename: str | None = None
1073
+ """Checkpoint filename. This must not include the extension. If `None`, `latest_epoch_{id}_{timestamp}` is used."""
1074
+
1075
+ save_weights_only: bool = False
1076
+ """Whether to save only the model's weights or the entire model object."""
1077
+
1078
+ @override
1079
+ def construct_callbacks(self, root_config):
1080
+ from ..callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint
1081
+
1082
+ dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
1083
+ root_config.id, "checkpoint"
1084
+ )
1085
+
1086
+ yield LatestEpochCheckpoint(
1087
+ dirpath=dirpath,
1088
+ filename=self.filename,
1089
+ save_weights_only=self.save_weights_only,
1090
+ )
1091
+
1092
+
1093
+ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
1094
+ kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
1095
+
1096
+ dirpath: str | Path | None = None
1097
+ """Directory path to save the checkpoint file."""
1098
+
1099
+ filename: str | None = None
1100
+ """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""
1101
+
1102
+ @override
1103
+ def construct_callbacks(self, root_config):
1104
+ from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
1105
+
1106
+ dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
1107
+ root_config.id, "checkpoint"
1108
+ )
1109
+
1110
+ if not (filename := self.filename):
1111
+ filename = f"on_exception_{root_config.id}"
1112
+ yield OnExceptionCheckpoint(dirpath=dirpath, filename=filename)
1113
+
1114
+
1115
+ CheckpointCallbackConfig: TypeAlias = Annotated[
1116
+ ModelCheckpointCallbackConfig
1117
+ | LatestEpochCheckpointCallbackConfig
1118
+ | OnExceptionCheckpointCallbackConfig,
1119
+ Field(discriminator="kind"),
1120
+ ]
1121
+
1122
+
1123
+ class CheckpointSavingConfig(CallbackConfigBase):
1124
+ enabled: bool = True
1125
+ """Enable checkpoint saving."""
1126
+
1127
+ checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
1128
+ ModelCheckpointCallbackConfig(),
1129
+ LatestEpochCheckpointCallbackConfig(),
1130
+ OnExceptionCheckpointCallbackConfig(),
1131
+ ]
1132
+ """Checkpoint callback configurations."""
1133
+
1134
+ def disable_(self):
1135
+ self.enabled = False
1136
+ return self
1137
+
1138
+ def should_save_checkpoints(self, root_config: "BaseConfig"):
1139
+ if not self.enabled:
1140
+ return False
1141
+
1142
+ if root_config.trainer.fast_dev_run:
1143
+ return False
1144
+
1145
+ return True
1146
+
1147
+ @property
1148
+ def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
1149
+ return next(
1150
+ (
1151
+ callback
1152
+ for callback in self.checkpoint_callbacks
1153
+ if isinstance(callback, ModelCheckpointCallbackConfig)
1154
+ ),
1155
+ None,
+ )
1156
+
1157
+ @property
1158
+ def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
1159
+ return next(
1160
+ (
1161
+ callback
1162
+ for callback in self.checkpoint_callbacks
1163
+ if isinstance(callback, LatestEpochCheckpointCallbackConfig)
1164
+ ),
1165
+ None,
+ )
1166
+
1167
+ @property
1168
+ def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
1169
+ return next(
1170
+ (
1171
+ callback
1172
+ for callback in self.checkpoint_callbacks
1173
+ if isinstance(callback, OnExceptionCheckpointCallbackConfig)
1174
+ ),
1175
+ None,
+ )
1176
+
1177
+ @override
1178
+ def construct_callbacks(self, root_config: "BaseConfig"):
1179
+ if not self.should_save_checkpoints(root_config):
1180
+ return
1181
+
1182
+ for callback_config in self.checkpoint_callbacks:
1183
+ yield from callback_config.construct_callbacks(root_config)
1184
+
1185
+
1186
+ class LightningTrainerKwargs(TypedDict, total=False):
1187
+ accelerator: str | Accelerator
1188
+ """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
1189
+ as well as custom accelerator instances."""
1190
+
1191
+ strategy: str | Strategy
1192
+ """Supports different training strategies with aliases as well custom strategies.
1193
+ Default: ``"auto"``.
1194
+ """
1195
+
1196
+ devices: list[int] | str | int
1197
+ """The devices to use. Can be set to a positive number (int or str), a sequence of device indices
1198
+ (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
1199
+ automatic selection based on the chosen accelerator. Default: ``"auto"``.
1200
+ """
1201
+
1202
+ num_nodes: int
1203
+ """Number of GPU nodes for distributed training.
1204
+ Default: ``1``.
1205
+ """
1206
+
1207
+ precision: _PRECISION_INPUT | None
1208
+ """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
1209
+ 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
1210
+ Can be used on CPU, GPU, TPUs, HPUs or IPUs.
1211
+ Default: ``'32-true'``.
1212
+ """
1213
+
1214
+ logger: Logger | Iterable[Logger] | bool | None
1215
+ """Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
1216
+ the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
1217
+ ``False`` will disable logging. If multiple loggers are provided, local files
1218
+ (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
1219
+ Default: ``True``.
1220
+ """
1221
+
1222
+ callbacks: list[Callback] | Callback | None
1223
+ """Add a callback or list of callbacks.
1224
+ Default: ``None``.
1225
+ """
1226
+
1227
+ fast_dev_run: int | bool
1228
+ """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
1229
+ of train, val and test to find any bugs (ie: a sort of unit test).
1230
+ Default: ``False``.
1231
+ """
1232
+
1233
+ max_epochs: int | None
1234
+ """Stop training once this number of epochs is reached. Disabled by default (None).
1235
+ If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
1236
+ To enable infinite training, set ``max_epochs = -1``.
1237
+ """
1238
+
1239
+ min_epochs: int | None
1240
+ """Force training for at least these many epochs. Disabled by default (None).
1241
+ """
1242
+
1243
+ max_steps: int
1244
+ """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
1245
+ and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
1246
+ ``max_epochs`` to ``-1``.
1247
+ """
1248
+
1249
+ min_steps: int | None
1250
+ """Force training for at least these number of steps. Disabled by default (``None``).
1251
+ """
1252
+
1253
+ max_time: str | timedelta | dict[str, int] | None
1254
+ """Stop training after this amount of time has passed. Disabled by default (``None``).
1255
+ The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
1256
+ :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
1257
+ :class:`datetime.timedelta`.
1258
+ """
1259
+
1260
+ limit_train_batches: int | float | None
1261
+ """How much of training dataset to check (float = fraction, int = num_batches).
1262
+ Default: ``1.0``.
1263
+ """
1264
+
1265
+ limit_val_batches: int | float | None
1266
+ """How much of validation dataset to check (float = fraction, int = num_batches).
1267
+ Default: ``1.0``.
1268
+ """
1269
+
1270
+ limit_test_batches: int | float | None
1271
+ """How much of test dataset to check (float = fraction, int = num_batches).
1272
+ Default: ``1.0``.
1273
+ """
1274
+
1275
+ limit_predict_batches: int | float | None
1276
+ """How much of prediction dataset to check (float = fraction, int = num_batches).
1277
+ Default: ``1.0``.
1278
+ """
1279
+
1280
+ overfit_batches: int | float
1281
+ """Overfit a fraction of training/validation data (float) or a set number of batches (int).
1282
+ Default: ``0.0``.
1283
+ """
1284
+
1285
+ val_check_interval: int | float | None
1286
+ """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
1287
+ after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
1288
+ batches. An ``int`` value can only be higher than the number of training batches when
1289
+ ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
1290
+ across epochs or during iteration-based training.
1291
+ Default: ``1.0``.
1292
+ """
1293
+
1294
+ check_val_every_n_epoch: int | None
1295
+ """Perform a validation loop every after every `N` training epochs. If ``None``,
1296
+ validation will be done solely based on the number of training batches, requiring ``val_check_interval``
1297
+ to be an integer value.
1298
+ Default: ``1``.
1299
+ """
1300
+
1301
+ num_sanity_val_steps: int | None
1302
+ """Sanity check runs n validation batches before starting the training routine.
1303
+ Set it to `-1` to run all batches in all validation dataloaders.
1304
+ Default: ``2``.
1305
+ """
1306
+
1307
+ log_every_n_steps: int | None
1308
+ """How often to log within steps.
1309
+ Default: ``50``.
1310
+ """
1311
+
1312
+ enable_checkpointing: bool | None
1313
+ """If ``True``, enable checkpointing.
1314
+ It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
1315
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
1316
+ Default: ``True``.
1317
+ """
1318
+
1319
+ enable_progress_bar: bool | None
1320
+ """Whether to enable to progress bar by default.
1321
+ Default: ``True``.
1322
+ """
1323
+
1324
+ enable_model_summary: bool | None
1325
+ """Whether to enable model summarization by default.
1326
+ Default: ``True``.
1327
+ """
1328
+
1329
+ accumulate_grad_batches: int
1330
+ """Accumulates gradients over k batches before stepping the optimizer.
1331
+ Default: 1.
1332
+ """
1333
+
1334
+ gradient_clip_val: int | float | None
1335
+ """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
1336
+ gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
1337
+ Default: ``None``.
1338
+ """
1339
+
1340
+ gradient_clip_algorithm: str | None
1341
+ """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
1342
+ to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
1343
+ be set to ``"norm"``.
1344
+ """
1345
+
1346
+ deterministic: bool | Literal["warn"] | None
1347
+ """If ``True``, sets whether PyTorch operations must use deterministic algorithms.
1348
+ Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
1349
+ that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
1350
+ """
1351
+
1352
+ benchmark: bool | None
1353
+ """The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
1354
+ The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
1355
+ (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
1356
+ is set to ``True``, this will default to ``False``. Override to manually set a different value.
1357
+ Default: ``None``.
1358
+ """
1359
+
1360
+ inference_mode: bool
1361
+ """Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
1362
+ evaluation (``validate``/``test``/``predict``).
1363
+ """
1364
+
1365
+ use_distributed_sampler: bool
1366
+ """Whether to wrap the DataLoader's sampler with
1367
+ :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
1368
+ strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
1369
+ ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
1370
+ ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
1371
+ sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
1372
+ we don't do this automatically.
1373
+ """
1374
+
1375
+ profiler: Profiler | str | None
1376
+ """To profile individual steps during training and assist in identifying bottlenecks.
1377
+ Default: ``None``.
1378
+ """
1379
+
1380
+ detect_anomaly: bool
1381
+ """Enable anomaly detection for the autograd engine.
1382
+ Default: ``False``.
1383
+ """
1384
+
1385
+ barebones: bool
1386
+ """Whether to run in "barebones mode", where all features that may impact raw speed are
1387
+ disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
1388
+ runs. The following features are deactivated:
1389
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
1390
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
1391
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
1392
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
1393
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
1394
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
1395
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
1396
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
1397
+ :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
1398
+ :meth:`~lightning.pytorch.core.LightningModule.log`,
1399
+ :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
1400
+ """
1401
+
1402
+ plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None
1403
+ """Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
1404
+ Default: ``None``.
1405
+ """
1406
+
1407
+ sync_batchnorm: bool
1408
+ """Synchronize batch norm layers between process groups/whole world.
1409
+ Default: ``False``.
1410
+ """
1411
+
1412
+ reload_dataloaders_every_n_epochs: int
1413
+ """Set to a positive integer to reload dataloaders every n epochs.
1414
+ Default: ``0``.
1415
+ """
1416
+
1417
+ default_root_dir: Path | None
1418
+ """Default path for logs and weights when no logger/ckpt_callback passed.
1419
+ Default: ``os.getcwd()``.
1420
+ Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
1421
+ """
1422
+
1423
+
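These entries mirror the keyword arguments of ``lightning.pytorch.Trainer``, so a populated dict can be splatted straight into its constructor. A minimal sketch (the values shown are arbitrary)::

    from lightning.pytorch import Trainer

    kwargs: LightningTrainerKwargs = {
        "accelerator": "cpu",
        "devices": 1,
        "max_epochs": 10,
    }
    trainer = Trainer(**kwargs)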
1424
+ class EarlyStoppingConfig(CallbackConfigBase):
1425
+ monitor: str | None = None
1426
+ """
1427
+ The metric to monitor for early stopping.
1428
+ If None, the primary metric will be used.
1429
+ """
1430
+
1431
+ mode: Literal["min", "max"] | None = None
1432
+ """
1433
+ The mode for the metric to monitor for early stopping.
1434
+ If None, the primary metric mode will be used.
1435
+ """
1436
+
1437
+ patience: int
1438
+ """
1439
+ Number of epochs with no improvement after which training will be stopped.
1440
+ """
1441
+
1442
+ min_delta: float = 1.0e-8
1443
+ """
1444
+ Minimum change in the monitored quantity to qualify as an improvement.
1445
+ """
1446
+
1447
+ min_lr: float | None = None
1448
+ """
1449
+ Minimum learning rate. If the learning rate of the model is less than this value,
1450
+ the training will be stopped.
1451
+ """
1452
+
1453
+ strict: bool = True
1454
+ """
1455
+ Whether to enforce that the monitored quantity must improve by at least `min_delta`
1456
+ to qualify as an improvement.
1457
+ """
1458
+
1459
+ @override
1460
+ def construct_callbacks(self, root_config: "BaseConfig"):
1461
+ from ..callbacks.early_stopping import EarlyStopping
1462
+
1463
+ monitor = self.monitor
1464
+ mode = self.mode
1465
+ if monitor is None:
1466
+ assert mode is None, "If `monitor` is not provided, `mode` must be None."
1467
+
1468
+ primary_metric = root_config.primary_metric
1469
+ if primary_metric is None:
1470
+ raise ValueError(
1471
+ "No primary metric is set, so `monitor` must be provided in `early_stopping`."
1472
+ )
1473
+ monitor = primary_metric.validation_monitor
1474
+ mode = primary_metric.mode
1475
+
1476
+ if mode is None:
1477
+ mode = "min"
1478
+
1479
+ return [
1480
+ EarlyStopping(
1481
+ monitor=monitor,
1482
+ mode=mode,
1483
+ patience=self.patience,
1484
+ min_delta=self.min_delta,
1485
+ min_lr=self.min_lr,
1486
+ strict=self.strict,
1487
+ )
1488
+ ]
1489
+
1490
+
1491
+ class ActSaveConfig(CallbackConfigBase):
1492
+ enabled: bool = True
1493
+ """Enable activation saving."""
1494
+
1495
+ auto_save_logged_metrics: bool = False
1496
+ """If enabled, will automatically save logged metrics (using `LightningModule.log`) as activations."""
1497
+
1498
+ save_dir: Path | None = None
1499
+ """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""
1500
+
1501
+ def __bool__(self):
1502
+ return self.enabled
1503
+
1504
+ def resolve_save_dir(self, root_config: "BaseConfig"):
1505
+ if self.save_dir is not None:
1506
+ return self.save_dir
1507
+
1508
+ return root_config.directory.resolve_subdirectory(root_config.id, "activation")
1509
+
1510
+ @override
1511
+ def construct_callbacks(self, root_config):
1512
+ from ..actsave import ActSaveCallback
1513
+
1514
+ return [ActSaveCallback()]
1515
+
1516
+
1517
+ class SanityCheckingConfig(TypedConfig):
1518
+ reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
1519
+ """
1520
+ If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
1521
+ - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
1522
+ - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
1523
+ Valid values are: "disable", "warn", "error".
1524
+ """
1525
+
1526
+
1527
+ class TrainerConfig(TypedConfig):
+ checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
+ """Checkpoint loading configuration options."""
+
+ checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
+ """Checkpoint saving configuration options."""
+
+ logging: LoggingConfig = LoggingConfig()
+ """Logging/experiment tracking (e.g., WandB) configuration options."""
+
+ optimizer: OptimizationConfig = OptimizationConfig()
+ """Optimization configuration options."""
+
+ reproducibility: ReproducibilityConfig = ReproducibilityConfig()
+ """Reproducibility configuration options."""
+
+ sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
+ """Sanity checking configuration options."""
+
+ actsave: ActSaveConfig | None = ActSaveConfig(enabled=False)
+ """Activation saving configuration options."""
+
+ early_stopping: EarlyStoppingConfig | None = None
+ """Early stopping configuration options."""
+
+ profiler: ProfilerConfig | None = None
+ """
+ To profile individual steps during training and assist in identifying bottlenecks.
+ Default: ``None``.
+ """
+
+ callbacks: list[CallbackConfig] = []
+ """Callbacks to use during training."""
+
+ detect_anomaly: bool | None = None
+ """Enable anomaly detection for the autograd engine.
+ Default: ``False``.
+ """
+
+ plugins: list[PluginConfigProtocol] | None = None
+ """
+ Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
+ Default: ``None``.
+ """
+
+ auto_determine_num_nodes: bool = True
+ """
+ If enabled, will automatically determine the number of nodes for distributed training.
+
+ This will only work on:
+ - SLURM clusters
+ - LSF clusters
+ """
+
+ fast_dev_run: int | bool = False
+ """Runs ``n`` batch(es) of train, val, and test if set to an integer ``n``, or 1 batch each
+ if set to ``True``, to find any bugs (i.e., a sort of unit test).
+ Default: ``False``.
+ """
+
+ precision: (
+ Literal[
+ "64-true",
+ "32-true",
+ "fp16-mixed",
+ "bf16-mixed",
+ "16-mixed-auto",
+ ]
+ | None
+ ) = None
+ """
+ Training precision. Can be one of:
+ - "64-true": Double precision (64-bit).
+ - "32-true": Full precision (32-bit).
+ - "fp16-mixed": Float16 mixed precision.
+ - "bf16-mixed": BFloat16 mixed precision.
+ - "16-mixed-auto": Automatic 16-bit: Uses bfloat16 if available, otherwise float16.
+ """
+
+ max_epochs: int | None = None
+ """Stop training once this number of epochs is reached. Disabled by default (None).
+ If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+ To enable infinite training, set ``max_epochs = -1``.
+ """
+
+ min_epochs: int | None = None
+ """Force training for at least this many epochs. Disabled by default (None).
+ """
+
+ max_steps: int = -1
+ """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+ and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+ ``max_epochs`` to ``-1``.
+ """
+
+ min_steps: int | None = None
+ """Force training for at least this number of steps. Disabled by default (``None``).
+ """
+
+ max_time: str | timedelta | dict[str, int] | None = None
+ """Stop training after this amount of time has passed. Disabled by default (``None``).
+ The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
+ :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
+ :class:`datetime.timedelta`.
+ """
+
+ limit_train_batches: int | float | None = None
+ """How much of training dataset to check (float = fraction, int = num_batches).
+ Default: ``1.0``.
+ """
+
+ limit_val_batches: int | float | None = None
+ """How much of validation dataset to check (float = fraction, int = num_batches).
+ Default: ``1.0``.
+ """
+
+ limit_test_batches: int | float | None = None
+ """How much of test dataset to check (float = fraction, int = num_batches).
+ Default: ``1.0``.
+ """
+
+ limit_predict_batches: int | float | None = None
+ """How much of prediction dataset to check (float = fraction, int = num_batches).
+ Default: ``1.0``.
+ """
+
+ overfit_batches: int | float = 0.0
+ """Overfit a fraction of training/validation data (float) or a set number of batches (int).
+ Default: ``0.0``.
+ """
+
+ val_check_interval: int | float | None = None
+ """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
+ after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
+ batches. An ``int`` value can only be higher than the number of training batches when
+ ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
+ across epochs or during iteration-based training.
+ Default: ``1.0``.
+ """
+
+ check_val_every_n_epoch: int | None = 1
+ """Perform a validation loop after every `N` training epochs. If ``None``,
+ validation will be done solely based on the number of training batches, requiring ``val_check_interval``
+ to be an integer value.
+ Default: ``1``.
+ """
+
+ num_sanity_val_steps: int | None = None
+ """Sanity check runs n validation batches before starting the training routine.
+ Set it to `-1` to run all batches in all validation dataloaders.
+ Default: ``2``.
+ """
+
+ log_every_n_steps: int | None = None
+ """How often to log within steps.
+ Default: ``50``.
+ """
+
+ inference_mode: bool = True
+ """Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during evaluation (``validate``/``test``/``predict``).
+ Default: ``True``.
+ """
+
+ use_distributed_sampler: bool | None = None
+ """Whether to wrap the DataLoader's sampler with
+ :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
+ strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
+ ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
+ ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
+ sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
+ we don't do this automatically.
+ Default: ``True``.
+ """
+
+ accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
+ """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
+ as well as custom accelerator instances.
+ Default: ``"auto"``.
+ """
+
+ strategy: StrategyConfigProtocol | StrategyLiteral | None = None
+ """Supports different training strategies with aliases, as well as custom strategies.
+ Default: ``"auto"``.
+ """
+
+ devices: tuple[int, ...] | Sequence[int] | Literal["auto", "all"] | None = None
+ """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
+ automatic selection based on the chosen accelerator. Default: ``"auto"``.
+ """
+
+ auto_wrap_trainer: bool = True
+ """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time."""
+ auto_set_default_root_dir: bool = True
+ """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
+ supports_shared_parameters: bool = True
+ """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
+
+ lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
+ """
+ Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
+
+ Please refer to the Lightning documentation for a list of valid keyword arguments.
+ """
+
+ additional_lightning_kwargs: dict[str, Any] = {}
+ """
+ Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
+
+ This is essentially a non-type-checked version of `lightning_kwargs`.
+ """
+
+ set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
+ """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
+
+
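A hedged sketch of common `TrainerConfig` settings on a `BaseConfig` instance, using only the field names declared above; `gradient_clip_val` is an ordinary Lightning `Trainer` argument forwarded through `additional_lightning_kwargs`:

config.trainer.precision = "16-mixed-auto"        # bf16 if available, otherwise fp16
config.trainer.max_epochs = 100
config.trainer.val_check_interval = 0.25          # validate four times per training epoch
config.trainer.set_float32_matmul_precision = "high"
config.trainer.additional_lightning_kwargs = {"gradient_clip_val": 1.0}
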
+ class SeedConfig(TypedConfig):
+ seed: int
+ """Seed for the random number generator."""
+
+ seed_workers: bool = False
+ """Whether to seed the workers of the dataloader."""
+
+
+ Signal: TypeAlias = Literal[
+ "SIGHUP",
+ "SIGINT",
+ "SIGQUIT",
+ "SIGILL",
+ "SIGTRAP",
+ "SIGABRT",
+ "SIGBUS",
+ "SIGFPE",
+ "SIGKILL",
+ "SIGUSR1",
+ "SIGSEGV",
+ "SIGUSR2",
+ "SIGPIPE",
+ "SIGALRM",
+ "SIGTERM",
+ "SIGCHLD",
+ "SIGCONT",
+ "SIGSTOP",
+ "SIGTSTP",
+ "SIGTTIN",
+ "SIGTTOU",
+ "SIGURG",
+ "SIGXCPU",
+ "SIGXFSZ",
+ "SIGVTALRM",
+ "SIGPROF",
+ "SIGWINCH",
+ "SIGIO",
+ "SIGPWR",
+ "SIGSYS",
+ "SIGRTMIN",
+ "SIGRTMAX",
+ ]
+
+
+ class SubmitConfig(TypedConfig):
+ auto_requeue_signals: list[Signal] = [
+ # "SIGUSR1",
+ # On SIGURG:
+ # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
+ # We can also ask the job manager to send a warning signal some amount of time before
+ # the allocation expires by passing -wa 'signal' and -wt '[hour:]minute' to bsub. We can
+ # then have bash create a dump_and_stop file when it receives the signal, which will tell
+ # Castro to output a checkpoint file and exit cleanly after it finishes the current
+ # timestep. An important detail that I couldn't find documented anywhere is that the job
+ # manager sends the signal to all the processes in the job, not just the submission
+ # script, and we have to use a signal that is ignored by default so Castro doesn't
+ # immediately crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals
+ # that fit this requirement and of these, SIGURG is the least likely to be triggered by
+ # other events.
+ "SIGURG"
+ ]
+ """Signals that will trigger an automatic requeue of the job."""
+
+ def _resolved_auto_requeue_signals(self) -> list[signal.Signals]:
+ return [getattr(signal.Signals, sig) for sig in self.auto_requeue_signals]
+
+
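A small hedged sketch of what the signal-name resolution amounts to (keyword construction of the config is assumed, and signal availability is platform-dependent):

import signal

submit = SubmitConfig(auto_requeue_signals=["SIGURG", "SIGUSR1"])
# getattr(signal.Signals, name) maps each literal onto the corresponding enum member.
assert submit._resolved_auto_requeue_signals() == [signal.SIGURG, signal.SIGUSR1]
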
+ class RunnerConfig(TypedConfig):
+ python_logging: PythonLogging = PythonLogging()
+ """Python logging configuration options."""
+
+ seed: SeedConfig = SeedConfig(seed=0)
+ """Seed everything configuration options."""
+
+ submit: SubmitConfig = SubmitConfig()
+ """Submit (e.g., SLURM or LSF) configuration options."""
+
+ dump_run_information: bool = True
+ """
+ If enabled, will dump different bits of run information to the output directory before starting the run.
+ This includes:
+ - Run config
+ - Full set of environment variables
+ """
+
+ additional_env_vars: dict[str, str] = {}
+ """Additional environment variables to set when running the script."""
+
+
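A minimal hedged sketch of overriding the runner defaults on the same `config` object (`NCCL_DEBUG` is just an example environment variable):

config.runner = RunnerConfig(
    seed=SeedConfig(seed=1234, seed_workers=True),
    additional_env_vars={"NCCL_DEBUG": "INFO"},
)
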
+ class MetricConfig(TypedConfig):
+ name: str
+ """The name of the primary metric."""
+
+ mode: Literal["min", "max"]
+ """
+ The mode of the primary metric:
+ - "min" for metrics that should be minimized (e.g., loss)
+ - "max" for metrics that should be maximized (e.g., accuracy)
+ """
+
+ @property
+ def validation_monitor(self) -> str:
+ return f"val/{self.name}"
+
+ def __post_init__(self):
+ for split in ("train", "val", "test", "predict"):
+ if self.name.startswith(f"{split}/"):
+ raise ValueError(
+ f"Primary metric name should not start with '{split}/'. "
+ f"Just use '{self.name[len(split) + 1:]}' instead. "
+ "The split name is automatically added depending on the context."
+ )
+
+ @classmethod
+ def loss(cls, mode: Literal["min", "max"] = "min"):
+ return cls(name="loss", mode=mode)
+
+
+ PrimaryMetricConfig: TypeAlias = MetricConfig
+
+
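A short sketch of how the metric name and monitor key relate (assuming `__post_init__` runs as the model's validation hook):

metric = MetricConfig(name="accuracy", mode="max")
assert metric.validation_monitor == "val/accuracy"
loss_metric = MetricConfig.loss()  # shorthand for MetricConfig(name="loss", mode="min")
# MetricConfig(name="val/accuracy", mode="max") would raise a ValueError:
# the split prefix ("train/", "val/", ...) is added automatically by the framework.
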
+ class BaseConfig(TypedConfig):
+ id: str = Field(default_factory=lambda: BaseConfig.generate_id())
+ """ID of the run."""
+ name: str | None = None
+ """Run name."""
+ name_parts: list[str] = []
+ """A list of parts used to construct the run name. This is useful for constructing the run name dynamically."""
+ project: str | None = None
+ """Project name."""
+ tags: list[str] = []
+ """Tags for the run."""
+ notes: list[str] = []
+ """Human readable notes for the run."""
+
+ debug: bool = False
+ """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
+ environment: Annotated[EnvironmentConfig, Field(repr=False)] = EnvironmentConfig()
+ """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
+
+ directory: DirectoryConfig = DirectoryConfig()
+ """Directory configuration options."""
+ trainer: TrainerConfig = TrainerConfig()
+ """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
+ runner: RunnerConfig = RunnerConfig()
+ """`ll.Runner` configuration options."""
+
+ primary_metric: PrimaryMetricConfig | None = None
+ """Primary metric configuration options. This is used in the following ways:
+ - To determine the best model checkpoint to save with the ModelCheckpoint callback.
+ - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
+ - For the ReduceLROnPlateau scheduler.
+ """
+
+ meta: dict[str, Any] = {}
+ """Additional metadata for this run. This can be used to store arbitrary data that is not part of the config schema."""
+
+ @property
+ def run_name(self) -> str:
+ parts = self.name_parts.copy()
+ if self.name is not None:
+ parts = [self.name] + parts
+ name = "-".join(parts)
+ if not name:
+ name = self.id
+ return name
+
+ def clone(self, with_new_id: bool = True) -> Self:
+ c = copy.deepcopy(self)
+ if with_new_id:
+ c.id = BaseConfig.generate_id()
+ return c
+
+ def subdirectory(self, subdirectory: str) -> Path:
+ return self.directory.resolve_subdirectory(self.id, subdirectory)
+
+ # region Helper methods
+ def with_project_root_(self, project_root: str | Path | os.PathLike) -> Self:
+ """
+ Set the project root directory for the trainer.
+
+ Args:
+ project_root (Path): The base directory to use.
+
+ Returns:
+ self: The current instance of the class.
+ """
+ self.directory.project_root = Path(project_root)
+ return self
+
+ def reset_(
+ self,
+ *,
+ id: bool = True,
+ basic: bool = True,
+ project_root: bool = True,
+ environment: bool = True,
+ meta: bool = True,
+ ):
+ """
+ Reset the configuration object to its initial state.
+
+ Parameters:
+ - id (bool): If True, generate a new ID for the configuration object.
+ - basic (bool): If True, reset basic attributes like name, project, tags, and notes.
+ - project_root (bool): If True, reset the directory configuration to its initial state.
+ - environment (bool): If True, reset the environment configuration to its initial state.
+ - meta (bool): If True, reset the meta dictionary to an empty dictionary.
+
+ Returns:
+ - self: The updated configuration object.
+
+ """
+ if id:
+ self.id = self.generate_id()
+
+ if basic:
+ self.name = None
+ self.name_parts = []
+ self.project = None
+ self.tags = []
+ self.notes = []
+
+ if project_root:
+ self.directory = DirectoryConfig()
+
+ if environment:
+ self.environment = EnvironmentConfig()
+
+ if meta:
+ self.meta = {}
+
+ return self
+
+ def concise_repr(self) -> str:
+ """Get a concise representation of the configuration object."""
+
+ def _truncate(s: str, max_len: int = 50):
+ return s if len(s) <= max_len else f"{s[:max_len - 3]}..."
+
+ cls_name = self.__class__.__name__
+
+ parts: list[str] = []
+ parts.append(f"name={self.run_name}")
+ if self.project:
+ parts.append(f"project={_truncate(self.project)}")
+
+ return f"{cls_name}({', '.join(parts)})"
+
+ # endregion
+
+ # region Seeding
+
+ _rng: ClassVar[np.random.Generator | None] = None
+
+ @staticmethod
+ def generate_id(
+ *,
+ length: int = 8,
+ ignore_rng: bool = False,
+ ) -> str:
+ """
+ Generate a random ID of specified length.
+
+ Args:
+ length (int): The length of the generated ID. Default is 8.
+ ignore_rng (bool): If True, ignore the global random number generator and use a new one. Default is False.
+
+ Returns:
+ str: The generated random ID.
+
+ Raises:
+ IdSeedWarning: If the global random number generator is None and ignore_rng is False.
+
+ Notes:
+ - The generated IDs will not be reproducible if the global random number generator is None and ignore_rng is False.
+ - To ensure reproducibility, call BaseConfig.set_seed(...) before generating any IDs.
+ """
+ rng = BaseConfig._rng if not ignore_rng else np.random.default_rng()
+ if rng is None:
+ warnings.warn(
+ "BaseConfig._rng is None. The generated IDs will not be reproducible. "
+ + "To fix this, call BaseConfig.set_seed(...) before generating any IDs.",
+ category=IdSeedWarning,
+ )
+ rng = np.random.default_rng()
+
+ alphabet = list(string.ascii_lowercase + string.digits)
+
+ id = "".join(rng.choice(alphabet) for _ in range(length))
+ return id
+
+ @staticmethod
+ def set_seed(seed: int | None = None) -> None:
+ """
+ Set the seed for the random number generator.
+
+ Args:
+ seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
+
+ Returns:
+ None
+ """
+ if seed is None:
+ seed = int(time.time() * 1000)
+ log.critical(f"Seeding BaseConfig with seed {seed}")
+ BaseConfig._rng = np.random.default_rng(seed)
+
+ # endregion
+
+ @classmethod
+ def from_checkpoint(
+ cls,
+ path: str | Path,
+ *,
+ hparams_key: str = "hyper_parameters",
+ ):
+ ckpt = torch.load(path)
+ if (hparams := ckpt.get(hparams_key)) is None:
+ raise ValueError(
+ f"The checkpoint does not contain the `{hparams_key}` attribute. "
+ "Are you sure this is a valid Lightning checkpoint?"
+ )
+ return cls.model_validate(hparams)
+
+ def ll_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+ yield self.trainer.actsave
+ yield self.trainer.early_stopping
+ yield self.trainer.checkpoint_saving
+ yield self.trainer.logging
+ yield self.trainer.optimizer
+ yield from self.trainer.callbacks
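A hedged end-to-end sketch of the `BaseConfig` helpers shown above (keyword construction is assumed, and the checkpoint path is a placeholder):

BaseConfig.set_seed(42)              # makes subsequent generate_id() calls reproducible
run_id = BaseConfig.generate_id()    # 8 lowercase/digit characters

config = BaseConfig(name="baseline", name_parts=["lr1e-3"], tags=["demo"])
print(config.run_name)               # "baseline-lr1e-3"
print(config.concise_repr())         # "BaseConfig(name=baseline-lr1e-3)"

# Rebuild the config from the hyperparameters stored in a Lightning checkpoint:
# restored = BaseConfig.from_checkpoint("path/to/last.ckpt")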