nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/model/config.py
@@ -0,0 +1,2064 @@
import copy
import os
import re
import signal
import socket
import string
import time
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from datetime import timedelta
from logging import getLogger
from pathlib import Path
from typing import (
    Annotated,
    Any,
    ClassVar,
    Literal,
    Protocol,
    TypeAlias,
    runtime_checkable,
)

import numpy as np
import torch
from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.loggers import Logger
from lightning.pytorch.plugins import _PLUGIN_INPUT
from lightning.pytorch.plugins.layer_sync import LayerSync
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.profilers import Profiler
from lightning.pytorch.strategies.strategy import Strategy
from pydantic import DirectoryPath
from typing_extensions import Self, TypedDict, TypeVar, override

from ..callbacks import CallbackConfig
from ..callbacks.base import CallbackConfigBase
from ..callbacks.wandb_watch import WandbWatchConfig
from ..config import Field, TypedConfig
from ..util.slurm import parse_slurm_node_list

log = getLogger(__name__)


class IdSeedWarning(Warning):
    pass


class BaseProfilerConfig(TypedConfig, ABC):
    dirpath: str | Path | None = None
    """
    Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
    ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
    will be used.
    """
    filename: str | None = None
    """
    If present, filename where the profiler results will be saved instead of printing to stdout.
    The ``.txt`` extension will be used automatically.
    """

    @abstractmethod
    def construct_profiler(self, root_config: "BaseConfig") -> Profiler: ...


class SimpleProfilerConfig(BaseProfilerConfig):
    kind: Literal["simple"] = "simple"

    extended: bool = True
    """
    If ``True``, adds extra columns representing number of calls and percentage of
    total time spent onrespective action.
    """

    @override
    def construct_profiler(self, root_config):
        from lightning.pytorch.profilers.simple import SimpleProfiler

        if (dirpath := self.dirpath) is None:
            dirpath = root_config.directory.resolve_subdirectory(
                root_config.id, "profile"
            )

        if (filename := self.filename) is None:
            filename = f"{root_config.id}_profile.txt"

        return SimpleProfiler(
            extended=self.extended,
            dirpath=dirpath,
            filename=filename,
        )


class AdvancedProfilerConfig(BaseProfilerConfig):
    kind: Literal["advanced"] = "advanced"

    line_count_restriction: float = 1.0
    """
    This can be used to limit the number of functions
    reported for each action. either an integer (to select a count of lines),
    or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
    """

    @override
    def construct_profiler(self, root_config):
        from lightning.pytorch.profilers.advanced import AdvancedProfiler

        if (dirpath := self.dirpath) is None:
            dirpath = root_config.directory.resolve_subdirectory(
                root_config.id, "profile"
            )

        if (filename := self.filename) is None:
            filename = f"{root_config.id}_profile.txt"

        return AdvancedProfiler(
            line_count_restriction=self.line_count_restriction,
            dirpath=dirpath,
            filename=filename,
        )


class PyTorchProfilerConfig(BaseProfilerConfig):
    kind: Literal["pytorch"] = "pytorch"

    group_by_input_shapes: bool = False
    """Include operator input shapes and group calls by shape."""

    emit_nvtx: bool = False
    """
    Context manager that makes every autograd operation emit an NVTX range
    Run::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    To visualize, you can either use::

        nvvp trace_name.prof
        torch.autograd.profiler.load_nvprof(path)
    """

    export_to_chrome: bool = True
    """
    Whether to export the sequence of profiled operators for Chrome.
    It will generate a ``.json`` file which can be read by Chrome.
    """

    row_limit: int = 20
    """
    Limit the number of rows in a table, ``-1`` is a special value that
    removes the limit completely.
    """

    sort_by_key: str | None = None
    """
    Attribute used to sort entries. By default
    they are printed in the same order as they were registered.
    Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
    ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
    ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
    """

    record_module_names: bool = True
    """Whether to add module names while recording autograd operation."""

    table_kwargs: dict[str, Any] | None = None
    """Dictionary with keyword arguments for the summary table."""

    additional_profiler_kwargs: dict[str, Any] = {}
    """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""

    @override
    def construct_profiler(self, root_config):
        from lightning.pytorch.profilers.pytorch import PyTorchProfiler

        if (dirpath := self.dirpath) is None:
            dirpath = root_config.directory.resolve_subdirectory(
                root_config.id, "profile"
            )

        if (filename := self.filename) is None:
            filename = f"{root_config.id}_profile.txt"

        return PyTorchProfiler(
            group_by_input_shapes=self.group_by_input_shapes,
            emit_nvtx=self.emit_nvtx,
            export_to_chrome=self.export_to_chrome,
            row_limit=self.row_limit,
            sort_by_key=self.sort_by_key,
            record_module_names=self.record_module_names,
            table_kwargs=self.table_kwargs,
            dirpath=dirpath,
            filename=filename,
            **self.additional_profiler_kwargs,
        )


ProfilerConfig: TypeAlias = Annotated[
    SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
    Field(discriminator="kind"),
]


class EnvironmentClassInformationConfig(TypedConfig):
    name: str
    module: str
    full_name: str

    file_path: Path
    source_file_path: Path | None = None


class EnvironmentSLURMInformationConfig(TypedConfig):
    hostname: str
    hostnames: list[str]
    job_id: str
    raw_job_id: str
    array_job_id: str | None
    array_task_id: str | None
    num_tasks: int
    num_nodes: int
    node: str | int | None
    global_rank: int
    local_rank: int

    @classmethod
    def from_current_environment(cls):
        try:
            from lightning.fabric.plugins.environments.slurm import SLURMEnvironment

            if not SLURMEnvironment.detect():
                return None

            hostname = socket.gethostname()
            hostnames = [hostname]
            if node_list := os.environ.get("SLURM_JOB_NODELIST", ""):
                hostnames = parse_slurm_node_list(node_list)

            raw_job_id = os.environ["SLURM_JOB_ID"]
            job_id = raw_job_id
            array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID")
            array_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
            if array_job_id and array_task_id:
                job_id = f"{array_job_id}_{array_task_id}"

            num_tasks = int(os.environ["SLURM_NTASKS"])
            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])

            node_id = os.environ.get("SLURM_NODEID")

            global_rank = int(os.environ["SLURM_PROCID"])
            local_rank = int(os.environ["SLURM_LOCALID"])

            return cls(
                hostname=hostname,
                hostnames=hostnames,
                job_id=job_id,
                raw_job_id=raw_job_id,
                array_job_id=array_job_id,
                array_task_id=array_task_id,
                num_tasks=num_tasks,
                num_nodes=num_nodes,
                node=node_id,
                global_rank=global_rank,
                local_rank=local_rank,
            )
        except (ImportError, RuntimeError, ValueError, KeyError):
            return None


class EnvironmentLSFInformationConfig(TypedConfig):
    hostname: str
    hostnames: list[str]
    job_id: str
    array_job_id: str | None
    array_task_id: str | None
    num_tasks: int
    num_nodes: int
    node: str | int | None
    global_rank: int
    local_rank: int

    @classmethod
    def from_current_environment(cls):
        try:
            import os
            import socket

            hostname = socket.gethostname()
            hostnames = [hostname]
            if node_list := os.environ.get("LSB_HOSTS", ""):
                hostnames = node_list.split()

            job_id = os.environ["LSB_JOBID"]
            array_job_id = os.environ.get("LSB_JOBINDEX")
            array_task_id = os.environ.get("LSB_JOBINDEX")

            num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
            num_nodes = len(set(hostnames))

            node_id = (
                os.environ.get("LSB_HOSTS", "").split().index(hostname)
                if "LSB_HOSTS" in os.environ
                else None
            )

            # LSF doesn't have direct equivalents for global_rank and local_rank
            # You might need to calculate these based on your specific setup
            global_rank = int(os.environ.get("PMI_RANK", 0))
            local_rank = int(os.environ.get("LSB_RANK", 0))

            return cls(
                hostname=hostname,
                hostnames=hostnames,
                job_id=job_id,
                array_job_id=array_job_id,
                array_task_id=array_task_id,
                num_tasks=num_tasks,
                num_nodes=num_nodes,
                node=node_id,
                global_rank=global_rank,
                local_rank=local_rank,
            )
        except (ImportError, RuntimeError, ValueError, KeyError):
            return None


class EnvironmentLinuxEnvironmentConfig(TypedConfig):
    """
    Information about the Linux environment (e.g., current user, hostname, etc.)
    """

    user: str | None = None
    hostname: str | None = None
    system: str | None = None
    release: str | None = None
    version: str | None = None
    machine: str | None = None
    processor: str | None = None
    cpu_count: int | None = None
    memory: int | None = None
    uptime: timedelta | None = None
    boot_time: float | None = None
    load_avg: tuple[float, float, float] | None = None


class EnvironmentConfig(TypedConfig):
    cwd: Path | None = None

    python_executable: Path | None = None
    python_path: list[Path] | None = None
    python_version: str | None = None

    config: EnvironmentClassInformationConfig | None = None
    model: EnvironmentClassInformationConfig | None = None
    data: EnvironmentClassInformationConfig | None = None

    linux: EnvironmentLinuxEnvironmentConfig | None = None

    slurm: EnvironmentSLURMInformationConfig | None = None
    lsf: EnvironmentLSFInformationConfig | None = None

    base_dir: Path | None = None
    log_dir: Path | None = None
    checkpoint_dir: Path | None = None
    stdio_dir: Path | None = None

    seed: int | None = None
    seed_workers: bool | None = None


class BaseLoggerConfig(TypedConfig, ABC):
    enabled: bool = True
    """Enable this logger."""

    priority: int = 0
    """Priority of the logger. Higher values are logged first."""

    log_dir: DirectoryPath | None = None
    """Directory to save the logs to. If None, will use the default log directory for the trainer."""

    @abstractmethod
    def construct_logger(self, root_config: "BaseConfig") -> Logger | None: ...

    def disable_(self):
        self.enabled = False
        return self


def _project_name(
    root_config: "BaseConfig",
    default_project: str = "lightning_logs",
):
    # If the config has a project name, use that.
    if project := root_config.project:
        return project

    # Otherwise, we should use the name of the module that the config is defined in,
    # if we can find it.
    # If this isn't in a module, use the default project name.
    if not (module := root_config.__module__):
        return default_project

    # If the module is a package, use the package name.
    if not (module := module.split(".", maxsplit=1)[0].strip()):
        return default_project

    return module


def _wandb_available():
    try:
        from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE

        if not _WANDB_AVAILABLE:
            log.warning("WandB not found. Disabling WandbLogger.")
            return False
        return True
    except ImportError:
        return False


class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
    kind: Literal["wandb"] = "wandb"

    enabled: bool = Field(default_factory=lambda: _wandb_available())
    """Enable WandB logging."""

    priority: int = 2
    """Priority of the logger. Higher values are logged first."""

    project: str | None = None
    """WandB project name to use for the logger. If None, will use the root config's project name."""

    log_model: bool | Literal["all"] = False
    """
    Whether to log the model checkpoints to wandb.
    Valid values are:
    - False: Do not log the model checkpoints.
    - True: Log the latest model checkpoint.
    - "all": Log all model checkpoints.
    """

    watch: WandbWatchConfig = WandbWatchConfig()
    """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""

    offline: bool = False
    """Whether to run WandB in offline mode."""

    @override
    def construct_logger(self, root_config):
        if not self.enabled:
            return None

        from lightning.pytorch.loggers.wandb import WandbLogger

        save_dir = root_config.directory.resolve_log_directory_for_logger(
            root_config.id,
            self,
        )
        save_dir = save_dir / "wandb"
        save_dir.mkdir(parents=True, exist_ok=True)
        return WandbLogger(
            save_dir=save_dir,
            project=self.project or _project_name(root_config),
            name=root_config.run_name,
            version=root_config.id,
            log_model=self.log_model,
            notes=(
                "\n".join(f"- {note}" for note in root_config.notes)
                if root_config.notes
                else None
            ),
            tags=root_config.tags,
            offline=self.offline,
        )

    @override
    def construct_callbacks(self, root_config):
        if self.watch:
            yield from self.watch.construct_callbacks(root_config)


class CSVLoggerConfig(BaseLoggerConfig):
    kind: Literal["csv"] = "csv"

    enabled: bool = True
    """Enable CSV logging."""

    priority: int = 0
    """Priority of the logger. Higher values are logged first."""

    prefix: str = ""
    """A string to put at the beginning of metric keys."""

    flush_logs_every_n_steps: int = 100
    """How often to flush logs to disk."""

    @override
    def construct_logger(self, root_config):
        if not self.enabled:
            return None

        from lightning.pytorch.loggers.csv_logs import CSVLogger

        save_dir = root_config.directory.resolve_log_directory_for_logger(
            root_config.id,
            self,
        )
        save_dir = save_dir / "csv"
        save_dir.mkdir(parents=True, exist_ok=True)
        return CSVLogger(
            save_dir=save_dir,
            name=root_config.run_name,
            version=root_config.id,
            prefix=self.prefix,
            flush_logs_every_n_steps=self.flush_logs_every_n_steps,
        )


def _tensorboard_available():
    try:
        from lightning.fabric.loggers.tensorboard import (
            _TENSORBOARD_AVAILABLE,
            _TENSORBOARDX_AVAILABLE,
        )

        if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
            log.warning(
                "TensorBoard/TensorBoardX not found. Disabling TensorBoardLogger. "
                "Please install TensorBoard with `pip install tensorboard` or "
                "TensorBoardX with `pip install tensorboardx` to enable TensorBoard logging."
            )
            return False
        return True
    except ImportError:
        return False


class TensorboardLoggerConfig(BaseLoggerConfig):
    kind: Literal["tensorboard"] = "tensorboard"

    enabled: bool = Field(default_factory=lambda: _tensorboard_available())
    """Enable TensorBoard logging."""

    priority: int = 2
    """Priority of the logger. Higher values are logged first."""

    log_graph: bool = False
    """
    Adds the computational graph to tensorboard. This requires that
    the user has defined the `self.example_input_array` attribute in their
    model.
    """

    default_hp_metric: bool = True
    """
    Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
    called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
    """

    prefix: str = ""
    """A string to put at the beginning of metric keys."""

    @override
    def construct_logger(self, root_config):
        if not self.enabled:
            return None

        from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

        save_dir = root_config.directory.resolve_log_directory_for_logger(
            root_config.id,
            self,
        )
        save_dir = save_dir / "tensorboard"
        save_dir.mkdir(parents=True, exist_ok=True)
        return TensorBoardLogger(
            save_dir=save_dir,
            name=root_config.run_name,
            version=root_config.id,
            log_graph=self.log_graph,
            default_hp_metric=self.default_hp_metric,
        )


LoggerConfig: TypeAlias = Annotated[
    WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
    Field(discriminator="kind"),
]


class LoggingConfig(CallbackConfigBase):
    enabled: bool = True
    """Enable experiment tracking."""

    loggers: Sequence[LoggerConfig] = [
        WandbLoggerConfig(),
        CSVLoggerConfig(),
        TensorboardLoggerConfig(),
    ]
    """Loggers to use for experiment tracking."""

    log_lr: bool | Literal["step", "epoch"] = True
    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
    log_epoch: bool = True
    """If enabled, will log the fractional epoch number to the logger."""

    @property
    def wandb(self) -> WandbLoggerConfig | None:
        return next(
            (
                logger
                for logger in self.loggers
                if isinstance(logger, WandbLoggerConfig)
            ),
        )

    @property
    def csv(self) -> CSVLoggerConfig | None:
        return next(
            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
        )

    @property
    def tensorboard(self) -> TensorboardLoggerConfig | None:
        return next(
            (
                logger
                for logger in self.loggers
                if isinstance(logger, TensorboardLoggerConfig)
            ),
        )

    def construct_loggers(self, root_config: "BaseConfig"):
        """
        Constructs and returns a list of loggers based on the provided root configuration.

        Args:
            root_config (BaseConfig): The root configuration object.

        Returns:
            list[Logger]: A list of constructed loggers.
        """
        loggers: list[Logger] = []
        if not self.enabled:
            return loggers

        for logger_config in sorted(
            self.loggers,
            key=lambda x: x.priority,
            reverse=True,
        ):
            if not logger_config.enabled:
                continue
            if (logger := logger_config.construct_logger(root_config)) is None:
                continue
            loggers.append(logger)
        return loggers

    @override
    def construct_callbacks(self, root_config):
        if self.log_lr:
            from lightning.pytorch.callbacks import LearningRateMonitor

            logging_interval: str | None = None
            if isinstance(self.log_lr, str):
                logging_interval = self.log_lr

            yield LearningRateMonitor(logging_interval=logging_interval)

        if self.log_epoch:
            from ..callbacks.log_epoch import LogEpochCallback

            yield LogEpochCallback()

        for logger in self.loggers:
            if not logger or not isinstance(logger, CallbackConfigBase):
                continue

            yield from logger.construct_callbacks(root_config)


class GradientClippingConfig(TypedConfig):
    enabled: bool = True
    """Enable gradient clipping."""
    value: int | float
    """Value to use for gradient clipping."""
    algorithm: Literal["value", "norm"] = "norm"
    """Norm type to use for gradient clipping."""


class OptimizationConfig(CallbackConfigBase):
    log_grad_norm: bool | str | float = False
    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
    log_grad_norm_per_param: bool | str | float = False
    """If enabled, will log the gradient norm for each model parameter to the logger."""

    log_param_norm: bool | str | float = False
    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
    log_param_norm_per_param: bool | str | float = False
    """If enabled, will log the parameter norm for each model parameter to the logger."""

    gradient_clipping: GradientClippingConfig | None = None
    """Gradient clipping configuration, or None to disable gradient clipping."""

    @override
    def construct_callbacks(self, root_config):
        from ..callbacks.norm_logging import NormLoggingConfig

        yield from NormLoggingConfig(
            log_grad_norm=self.log_grad_norm,
            log_grad_norm_per_param=self.log_grad_norm_per_param,
            log_param_norm=self.log_param_norm,
            log_param_norm_per_param=self.log_param_norm_per_param,
        ).construct_callbacks(root_config)


LogLevel: TypeAlias = Literal[
    "CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"
]


class PythonLogging(TypedConfig):
    log_level: LogLevel | None = None
    """Log level to use for the Python logger (or None to use the default)."""

    rich: bool = False
    """If enabled, will use the rich library to format the Python logger output."""
    rich_tracebacks: bool = True
    """If enabled, will use the rich library to format the Python logger tracebacks."""

    lovely_tensors: bool = False
    """If enabled, will use the lovely-tensors library to format PyTorch tensors. False by default as it causes issues when used with `torch.vmap`."""
    lovely_numpy: bool = False
    """If enabled, will use the lovely-numpy library to format numpy arrays. False by default as it causes some issues with other libaries."""

    def pretty_(
        self,
        *,
        log_level: LogLevel | None = "INFO",
        torch: bool = True,
        numpy: bool = True,
        rich: bool = True,
        rich_tracebacks: bool = True,
    ):
        self.log_level = log_level
        self.lovely_tensors = torch
        self.lovely_numpy = numpy
        self.rich = rich
        self.rich_tracebacks = rich_tracebacks


TPlugin = TypeVar(
    "TPlugin",
    Precision,
    ClusterEnvironment,
    CheckpointIO,
    LayerSync,
    infer_variance=True,
)


@runtime_checkable
class PluginConfigProtocol(Protocol[TPlugin]):
    def construct_plugin(self) -> TPlugin: ...


@runtime_checkable
class AcceleratorConfigProtocol(Protocol):
    def construct_accelerator(self) -> Accelerator: ...


@runtime_checkable
class StrategyConfigProtocol(Protocol):
    def construct_strategy(self) -> Strategy: ...


AcceleratorLiteral: TypeAlias = Literal[
    "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
]

StrategyLiteral: TypeAlias = Literal[
    "auto",
    "ddp",
    "ddp_find_unused_parameters_false",
    "ddp_find_unused_parameters_true",
    "ddp_spawn",
    "ddp_spawn_find_unused_parameters_false",
    "ddp_spawn_find_unused_parameters_true",
    "ddp_fork",
    "ddp_fork_find_unused_parameters_false",
    "ddp_fork_find_unused_parameters_true",
    "ddp_notebook",
    "dp",
    "deepspeed",
    "deepspeed_stage_1",
    "deepspeed_stage_1_offload",
    "deepspeed_stage_2",
    "deepspeed_stage_2_offload",
    "deepspeed_stage_3",
    "deepspeed_stage_3_offload",
    "deepspeed_stage_3_offload_nvme",
    "fsdp",
    "fsdp_cpu_offload",
    "single_xla",
    "xla_fsdp",
    "xla",
    "single_tpu",
]


class CheckpointLoadingConfig(TypedConfig):
    path: Literal["best", "last", "hpc"] | str | Path | None = None
    """
    Checkpoint path to use when loading a checkpoint.

    - "best" will load the best checkpoint.
    - "last" will load the last checkpoint.
    - "hpc" will load the SLURM pre-empted checkpoint.
    - Any other string or Path will load the checkpoint from the specified path.
    """


class DirectoryConfig(TypedConfig):
    project_root: Path | None = None
    """
    Root directory for this project.

    This isn't specific to the run; it is the parent directory of all runs.
    """

    log: Path | None = None
    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use lltrainer/{id}/log/."""

    stdio: Path | None = None
    """stdout/stderr log directory to use for the trainer. If None, will use lltrainer/{id}/stdio/."""

    checkpoint: Path | None = None
    """Checkpoint directory to use for the trainer. If None, will use lltrainer/{id}/checkpoint/."""

    activation: Path | None = None
    """Activation directory to use for the trainer. If None, will use lltrainer/{id}/activation/."""

    profile: Path | None = None
    """Directory to save profiling information to. If None, will use lltrainer/{id}/profile/."""

    def resolve_run_root_directory(self, run_id: str) -> Path:
        if (project_root_dir := self.project_root) is None:
            project_root_dir = Path.cwd()

        # The default base dir is $CWD/lltrainer/{id}/
        base_dir = project_root_dir / "lltrainer"
        base_dir.mkdir(exist_ok=True)

        # Add a .gitignore file to the lltrainer directory
        # which will ignore all files except for the .gitignore file itself
        gitignore_path = base_dir / ".gitignore"
        if not gitignore_path.exists():
            gitignore_path.touch()
            gitignore_path.write_text("*\n")

        base_dir = base_dir / run_id
        base_dir.mkdir(exist_ok=True)

        return base_dir

    def resolve_subdirectory(
        self,
        run_id: str,
        # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
        subdirectory: str,
    ) -> Path:
        # The subdir will be $CWD/lltrainer/{id}/{log, stdio, checkpoint, activation}/
        if (subdir := getattr(self, subdirectory, None)) is not None:
            assert isinstance(
                subdir, Path
            ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
            return subdir

        dir = self.resolve_run_root_directory(run_id)
        dir = dir / subdirectory
        dir.mkdir(exist_ok=True)
        return dir

    def resolve_log_directory_for_logger(
        self,
        run_id: str,
        logger: LoggerConfig,
    ) -> Path:
        if (log_dir := logger.log_dir) is not None:
            return log_dir

        # Save to lltrainer/{id}/log/{logger kind}/{id}/
        log_dir = self.resolve_subdirectory(run_id, "log")
        log_dir = log_dir / logger.kind

        return log_dir


class ReproducibilityConfig(TypedConfig):
    deterministic: bool | Literal["warn"] | None = None
    """
    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
    """


class ModelCheckpointCallbackConfig(CallbackConfigBase):
    """Arguments for the ModelCheckpoint callback."""

    kind: Literal["model_checkpoint"] = "model_checkpoint"

    dirpath: str | Path | None = None
    """
    Directory path to save the model file. If `None`, we save to the checkpoint directory set in `config.directory`.
    """

    filename: str | None = None
    """
    Checkpoint filename.
    If None, a default template is used (see :attr:`ModelCheckpoint.CHECKPOINT_JOIN_CHAR`).
    """

    monitor: str | None = None
    """
    Quantity to monitor for saving checkpoints.
    If None, no metric is monitored and checkpoints are saved at the end of every epoch.
    """

    verbose: bool = False
    """Verbosity mode. If True, print additional information about checkpoints."""

    save_last: Literal[True, False, "link"] | None = "link"
    """
    Whether to save the last checkpoint.
    If True, saves a copy of the last checkpoint separately.
    If "link", creates a symbolic link to the last checkpoint.
    """

    save_top_k: int = 1
    """
    Number of best models to save.
    If -1, all models are saved.
    If 0, no models are saved.
    """

    save_weights_only: bool = False
    """Whether to save only the model's weights or the entire model object."""

    mode: str = "min"
    """
    One of "min" or "max".
    If "min", training will stop when the metric monitored has stopped decreasing.
    If "max", training will stop when the metric monitored has stopped increasing.
    """

    auto_insert_metric_name: bool = True
    """Whether to automatically insert the metric name in the checkpoint filename."""

    every_n_train_steps: int | None = None
    """
    Number of training steps between checkpoints.
    If None or 0, no checkpoints are saved during training.
    """

    train_time_interval: timedelta | None = None
    """
    Time interval between checkpoints during training.
    If None, no checkpoints are saved during training based on time.
    """

    every_n_epochs: int | None = None
    """
    Number of epochs between checkpoints.
    If None or 0, no checkpoints are saved at the end of epochs.
    """

    save_on_train_epoch_end: bool | None = None
    """
    Whether to run checkpointing at the end of the training epoch.
    If False, checkpointing runs at the end of the validation.
    """

    enable_version_counter: bool = True
    """Whether to append a version to the existing file name."""

    auto_append_metric: bool = True
    """If enabled, this will automatically add "-{monitor}" to the filename."""

    @staticmethod
    def _convert_string(input_string: str):
        # Find all variables enclosed in curly braces
        variables = re.findall(r"\{(.*?)\}", input_string)

        # Replace each variable with its corresponding key-value pair
        output_string = input_string
        for variable in variables:
            # If the name is something like {variable:format}, we shouldn't process the format.
            key_name = variable
            if ":" in variable:
                key_name, _ = variable.split(":", 1)
                continue

            # Replace '/' with '_' in the key name
            key_name = key_name.replace("/", "_")
            output_string = output_string.replace(
                f"{{{variable}}}", f"{key_name}={{{variable}}}"
            )

        return output_string

    @override
    def construct_callbacks(self, root_config):
        from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
            root_config.id, "checkpoint"
        )

        # If `monitor` is not provided, we can use `config.primary_metric` if it is set.
        monitor = self.monitor
        mode = self.mode
        if (
            monitor is None
            and (primary_metric := root_config.primary_metric) is not None
        ):
            monitor = primary_metric.validation_monitor
            mode = primary_metric.mode

        filename = self.filename
        if self.auto_append_metric:
            if not filename:
                filename = "{epoch}-{step}"
            filename = f"{filename}-{{{monitor}}}"

        if self.auto_insert_metric_name and filename:
            new_filename = self._convert_string(filename)
            log.critical(
                f"Updated ModelCheckpoint filename: {filename} -> {new_filename}"
            )
            filename = new_filename

        yield ModelCheckpoint(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            mode=mode,
            verbose=self.verbose,
            save_last=self.save_last,
            save_top_k=self.save_top_k,
            save_weights_only=self.save_weights_only,
            auto_insert_metric_name=False,
            every_n_train_steps=self.every_n_train_steps,
            train_time_interval=self.train_time_interval,
            every_n_epochs=self.every_n_epochs,
            save_on_train_epoch_end=self.save_on_train_epoch_end,
            enable_version_counter=self.enable_version_counter,
        )


class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
    kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"

    dirpath: str | Path | None = None
    """Directory path to save the checkpoint file."""

    filename: str | None = None
    """Checkpoint filename. This must not include the extension. If `None`, `latest_epoch_{id}_{timestamp}` is used."""

    save_weights_only: bool = False
    """Whether to save only the model's weights or the entire model object."""

    @override
    def construct_callbacks(self, root_config):
        from ..callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint

        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
            root_config.id, "checkpoint"
        )

        yield LatestEpochCheckpoint(
            dirpath=dirpath,
            filename=self.filename,
            save_weights_only=self.save_weights_only,
        )


class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
    kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"

    dirpath: str | Path | None = None
    """Directory path to save the checkpoint file."""

    filename: str | None = None
    """Checkpoint filename. This must not include the extension. If `None`, `on_exception_{id}_{timestamp}` is used."""

    @override
    def construct_callbacks(self, root_config):
        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint

        dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
            root_config.id, "checkpoint"
        )

        if not (filename := self.filename):
            filename = f"on_exception_{root_config.id}"
        yield OnExceptionCheckpoint(dirpath=dirpath, filename=filename)


CheckpointCallbackConfig: TypeAlias = Annotated[
    ModelCheckpointCallbackConfig
    | LatestEpochCheckpointCallbackConfig
    | OnExceptionCheckpointCallbackConfig,
    Field(discriminator="kind"),
]


class CheckpointSavingConfig(CallbackConfigBase):
    enabled: bool = True
    """Enable checkpoint saving."""

    checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
        ModelCheckpointCallbackConfig(),
        LatestEpochCheckpointCallbackConfig(),
        OnExceptionCheckpointCallbackConfig(),
    ]
    """Checkpoint callback configurations."""

    def disable_(self):
        self.enabled = False
        return self

    def should_save_checkpoints(self, root_config: "BaseConfig"):
        if not self.enabled:
            return False

        if root_config.trainer.fast_dev_run:
            return False

        return True

    @property
    def model_checkpoint(self) -> ModelCheckpointCallbackConfig | None:
        return next(
            (
                callback
                for callback in self.checkpoint_callbacks
                if isinstance(callback, ModelCheckpointCallbackConfig)
            ),
        )

    @property
    def latest_epoch_checkpoint(self) -> LatestEpochCheckpointCallbackConfig | None:
        return next(
            (
                callback
                for callback in self.checkpoint_callbacks
                if isinstance(callback, LatestEpochCheckpointCallbackConfig)
            ),
        )

    @property
    def on_exception_checkpoint(self) -> OnExceptionCheckpointCallbackConfig | None:
        return next(
            (
                callback
                for callback in self.checkpoint_callbacks
                if isinstance(callback, OnExceptionCheckpointCallbackConfig)
            ),
        )

    @override
    def construct_callbacks(self, root_config: "BaseConfig"):
        if not self.should_save_checkpoints(root_config):
            return

        for callback_config in self.checkpoint_callbacks:
            yield from callback_config.construct_callbacks(root_config)


class LightningTrainerKwargs(TypedDict, total=False):
    accelerator: str | Accelerator
    """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
    as well as custom accelerator instances."""

    strategy: str | Strategy
    """Supports different training strategies with aliases as well custom strategies.
    Default: ``"auto"``.
    """

    devices: list[int] | str | int
    """The devices to use. Can be set to a positive number (int or str), a sequence of device indices
    (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
    automatic selection based on the chosen accelerator. Default: ``"auto"``.
    """

    num_nodes: int
    """Number of GPU nodes for distributed training.
    Default: ``1``.
    """

    precision: _PRECISION_INPUT | None
    """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
    16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
    Can be used on CPU, GPU, TPUs, HPUs or IPUs.
    Default: ``'32-true'``.
    """

    logger: Logger | Iterable[Logger] | bool | None
    """Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
    the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
    ``False`` will disable logging. If multiple loggers are provided, local files
    (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
    Default: ``True``.
    """

    callbacks: list[Callback] | Callback | None
    """Add a callback or list of callbacks.
    Default: ``None``.
    """

    fast_dev_run: int | bool
    """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
    of train, val and test to find any bugs (ie: a sort of unit test).
    Default: ``False``.
    """

    max_epochs: int | None
    """Stop training once this number of epochs is reached. Disabled by default (None).
    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
    To enable infinite training, set ``max_epochs = -1``.
    """

    min_epochs: int | None
    """Force training for at least these many epochs. Disabled by default (None).
    """

    max_steps: int
    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
    and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
    ``max_epochs`` to ``-1``.
    """

    min_steps: int | None
    """Force training for at least these number of steps. Disabled by default (``None``).
    """

    max_time: str | timedelta | dict[str, int] | None
    """Stop training after this amount of time has passed. Disabled by default (``None``).
    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
    :class:`datetime.timedelta`.
    """

    limit_train_batches: int | float | None
    """How much of training dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    limit_val_batches: int | float | None
    """How much of validation dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    limit_test_batches: int | float | None
    """How much of test dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    limit_predict_batches: int | float | None
    """How much of prediction dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    overfit_batches: int | float
    """Overfit a fraction of training/validation data (float) or a set number of batches (int).
    Default: ``0.0``.
    """

    val_check_interval: int | float | None
    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
    batches. An ``int`` value can only be higher than the number of training batches when
    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
    across epochs or during iteration-based training.
    Default: ``1.0``.
    """

    check_val_every_n_epoch: int | None
    """Perform a validation loop every after every `N` training epochs. If ``None``,
    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
    to be an integer value.
    Default: ``1``.
    """

    num_sanity_val_steps: int | None
    """Sanity check runs n validation batches before starting the training routine.
    Set it to `-1` to run all batches in all validation dataloaders.
    Default: ``2``.
    """

    log_every_n_steps: int | None
    """How often to log within steps.
    Default: ``50``.
    """

    enable_checkpointing: bool | None
    """If ``True``, enable checkpointing.
    It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
    Default: ``True``.
    """

    enable_progress_bar: bool | None
    """Whether to enable to progress bar by default.
    Default: ``True``.
    """

    enable_model_summary: bool | None
    """Whether to enable model summarization by default.
    Default: ``True``.
    """

    accumulate_grad_batches: int
    """Accumulates gradients over k batches before stepping the optimizer.
    Default: 1.
    """

    gradient_clip_val: int | float | None
    """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
    gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
    Default: ``None``.
    """

    gradient_clip_algorithm: str | None
    """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
    to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
    be set to ``"norm"``.
    """

    deterministic: bool | Literal["warn"] | None
    """If ``True``, sets whether PyTorch operations must use deterministic algorithms.
    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
    """

    benchmark: bool | None
    """The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
    The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
    (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
    is set to ``True``, this will default to ``False``. Override to manually set a different value.
    Default: ``None``.
    """

    inference_mode: bool
    """Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
    evaluation (``validate``/``test``/``predict``).
    """

    use_distributed_sampler: bool
    """Whether to wrap the DataLoader's sampler with
    :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
    strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
    ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
    ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
    sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
    we don't do this automatically.
    """

    profiler: Profiler | str | None
    """To profile individual steps during training and assist in identifying bottlenecks.
    Default: ``None``.
    """

    detect_anomaly: bool
    """Enable anomaly detection for the autograd engine.
    Default: ``False``.
    """

    barebones: bool
    """Whether to run in "barebones mode", where all features that may impact raw speed are
    disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
    runs. The following features are deactivated:
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
    :meth:`~lightning.pytorch.core.LightningModule.log`,
    :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
    """

    plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None
    """Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
    Default: ``None``.
    """

    sync_batchnorm: bool
    """Synchronize batch norm layers between process groups/whole world.
    Default: ``False``.
    """

    reload_dataloaders_every_n_epochs: int
|
|
1413
|
+
"""Set to a positive integer to reload dataloaders every n epochs.
|
|
1414
|
+
Default: ``0``.
|
|
1415
|
+
"""
|
|
1416
|
+
|
|
1417
|
+
default_root_dir: Path | None
|
|
1418
|
+
"""Default path for logs and weights when no logger/ckpt_callback passed.
|
|
1419
|
+
Default: ``os.getcwd()``.
|
|
1420
|
+
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
|
|
1421
|
+
"""
|
|
1422
|
+
|
|
1423
|
+
|
|
1424
|
+
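These entries mirror the keyword arguments of Lightning's `pl.Trainer` one-to-one. As a rough illustration (an editor's sketch, assuming `LightningTrainerKwargs` accepts keyword construction like a `total=False` TypedDict, and using the `TrainerConfig.lightning_kwargs` field defined further below), a subset of raw Trainer kwargs can be forwarded like this:

# Editor's sketch (hypothetical usage, not part of the packaged file):
trainer_config = TrainerConfig(
    lightning_kwargs=LightningTrainerKwargs(
        gradient_clip_val=1.0,        # clip gradients (by norm, the default algorithm)
        accumulate_grad_batches=4,    # step the optimizer every 4 batches
        log_every_n_steps=10,         # log more often than the default of 50
    )
)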
class EarlyStoppingConfig(CallbackConfigBase):
    monitor: str | None = None
    """
    The metric to monitor for early stopping.
    If None, the primary metric will be used.
    """

    mode: Literal["min", "max"] | None = None
    """
    The mode for the metric to monitor for early stopping.
    If None, the primary metric mode will be used.
    """

    patience: int
    """
    Number of epochs with no improvement after which training will be stopped.
    """

    min_delta: float = 1.0e-8
    """
    Minimum change in the monitored quantity to qualify as an improvement.
    """

    min_lr: float | None = None
    """
    Minimum learning rate. If the learning rate of the model is less than this value,
    the training will be stopped.
    """

    strict: bool = True
    """
    Whether to enforce that the monitored quantity must improve by at least `min_delta`
    to qualify as an improvement.
    """

    @override
    def construct_callbacks(self, root_config: "BaseConfig"):
        from ..callbacks.early_stopping import EarlyStopping

        monitor = self.monitor
        mode = self.mode
        if monitor is None:
            assert mode is None, "If `monitor` is not provided, `mode` must be None."

            primary_metric = root_config.primary_metric
            if primary_metric is None:
                raise ValueError(
                    "No primary metric is set, so `monitor` must be provided in `early_stopping`."
                )
            monitor = primary_metric.validation_monitor
            mode = primary_metric.mode

        if mode is None:
            mode = "min"

        return [
            EarlyStopping(
                monitor=monitor,
                mode=mode,
                patience=self.patience,
                min_delta=self.min_delta,
                min_lr=self.min_lr,
                strict=self.strict,
            )
        ]

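When `monitor`/`mode` are left unset, they fall back to the run's primary metric. A minimal sketch of that fallback (hypothetical usage, relying on the `BaseConfig` and `MetricConfig` classes defined later in this same file):

# Editor's sketch (hypothetical usage, not part of the packaged file):
config = BaseConfig()
config.primary_metric = MetricConfig(name="loss", mode="min")
config.trainer.early_stopping = EarlyStoppingConfig(patience=10)

# construct_callbacks() resolves monitor="val/loss" and mode="min" from the
# primary metric, since neither was given explicitly above.
(early_stopping_callback,) = config.trainer.early_stopping.construct_callbacks(config)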
class ActSaveConfig(CallbackConfigBase):
    enabled: bool = True
    """Enable activation saving."""

    auto_save_logged_metrics: bool = False
    """If enabled, will automatically save logged metrics (using `LightningModule.log`) as activations."""

    save_dir: Path | None = None
    """Directory to save activations to. If None, will use the activation directory set in `config.directory`."""

    def __bool__(self):
        return self.enabled

    def resolve_save_dir(self, root_config: "BaseConfig"):
        if self.save_dir is not None:
            return self.save_dir

        return root_config.directory.resolve_subdirectory(root_config.id, "activation")

    @override
    def construct_callbacks(self, root_config):
        from ..actsave import ActSaveCallback

        return [ActSaveCallback()]


class SanityCheckingConfig(TypedConfig):
    reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
    """
    If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
    - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
    - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
    Valid values are: "disable", "warn", "error".
    """

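Activation saving is off by default (see the `TrainerConfig.actsave` field below) and is enabled by swapping in an enabled config. A short sketch (hypothetical usage, not from the package docs):

# Editor's sketch (hypothetical usage, not part of the packaged file):
config = BaseConfig()
config.trainer.actsave = ActSaveConfig(enabled=True)
# With save_dir=None, resolve_save_dir() falls back to the run's
# "activation" subdirectory resolved through config.directory.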
class TrainerConfig(TypedConfig):
    checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
    """Checkpoint loading configuration options."""

    checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
    """Checkpoint saving configuration options."""

    logging: LoggingConfig = LoggingConfig()
    """Logging/experiment tracking (e.g., WandB) configuration options."""

    optimizer: OptimizationConfig = OptimizationConfig()
    """Optimization configuration options."""

    reproducibility: ReproducibilityConfig = ReproducibilityConfig()
    """Reproducibility configuration options."""

    sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
    """Sanity checking configuration options."""

    actsave: ActSaveConfig | None = ActSaveConfig(enabled=False)
    """Activation saving configuration options."""

    early_stopping: EarlyStoppingConfig | None = None
    """Early stopping configuration options."""

    profiler: ProfilerConfig | None = None
    """
    To profile individual steps during training and assist in identifying bottlenecks.
    Default: ``None``.
    """

    callbacks: list[CallbackConfig] = []
    """Callbacks to use during training."""

    detect_anomaly: bool | None = None
    """Enable anomaly detection for the autograd engine.
    Default: ``False``.
    """

    plugins: list[PluginConfigProtocol] | None = None
    """
    Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
    Default: ``None``.
    """

    auto_determine_num_nodes: bool = True
    """
    If enabled, will automatically determine the number of nodes for distributed training.

    This will only work on:
    - SLURM clusters
    - LSF clusters
    """

    fast_dev_run: int | bool = False
    """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
    of train, val and test to find any bugs (i.e., a sort of unit test).
    Default: ``False``.
    """

    precision: (
        Literal[
            "64-true",
            "32-true",
            "fp16-mixed",
            "bf16-mixed",
            "16-mixed-auto",
        ]
        | None
    ) = None
    """
    Training precision. Can be one of:
    - "64-true": Double precision (64-bit).
    - "32-true": Full precision (32-bit).
    - "fp16-mixed": Float16 mixed precision.
    - "bf16-mixed": BFloat16 mixed precision.
    - "16-mixed-auto": Automatic 16-bit: Uses bfloat16 if available, otherwise float16.
    """

    max_epochs: int | None = None
    """Stop training once this number of epochs is reached. Disabled by default (None).
    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
    To enable infinite training, set ``max_epochs = -1``.
    """

    min_epochs: int | None = None
    """Force training for at least this many epochs. Disabled by default (None).
    """

    max_steps: int = -1
    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
    and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
    ``max_epochs`` to ``-1``.
    """

    min_steps: int | None = None
    """Force training for at least this number of steps. Disabled by default (``None``).
    """

    max_time: str | timedelta | dict[str, int] | None = None
    """Stop training after this amount of time has passed. Disabled by default (``None``).
    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
    :class:`datetime.timedelta`.
    """

    limit_train_batches: int | float | None = None
    """How much of training dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    limit_val_batches: int | float | None = None
    """How much of validation dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    limit_test_batches: int | float | None = None
    """How much of test dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    limit_predict_batches: int | float | None = None
    """How much of prediction dataset to check (float = fraction, int = num_batches).
    Default: ``1.0``.
    """

    overfit_batches: int | float = 0.0
    """Overfit a fraction of training/validation data (float) or a set number of batches (int).
    Default: ``0.0``.
    """

    val_check_interval: int | float | None = None
    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
    batches. An ``int`` value can only be higher than the number of training batches when
    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
    across epochs or during iteration-based training.
    Default: ``1.0``.
    """

    check_val_every_n_epoch: int | None = 1
    """Perform a validation loop after every `N` training epochs. If ``None``,
    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
    to be an integer value.
    Default: ``1``.
    """

    num_sanity_val_steps: int | None = None
    """Sanity check runs n validation batches before starting the training routine.
    Set it to `-1` to run all batches in all validation dataloaders.
    Default: ``2``.
    """

    log_every_n_steps: int | None = None
    """How often to log within steps.
    Default: ``50``.
    """

    inference_mode: bool = True
    """Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during evaluation (``validate``/``test``/``predict``).
    Default: ``True``.
    """

    use_distributed_sampler: bool | None = None
    """Whether to wrap the DataLoader's sampler with
    :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
    strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
    ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
    ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
    sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
    we don't do this automatically.
    Default: ``True``.
    """

    accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
    """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
    as well as custom accelerator instances.
    Default: ``"auto"``.
    """

    strategy: StrategyConfigProtocol | StrategyLiteral | None = None
    """Supports different training strategies with aliases as well as custom strategies.
    Default: ``"auto"``.
    """

    devices: tuple[int, ...] | Sequence[int] | Literal["auto", "all"] | None = None
    """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
    automatic selection based on the chosen accelerator. Default: ``"auto"``.
    """

    auto_wrap_trainer: bool = True
    """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time."""
    auto_set_default_root_dir: bool = True
    """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
    supports_shared_parameters: bool = True
    """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""

    lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
    """
    Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.

    Please refer to the Lightning documentation for a list of valid keyword arguments.
    """

    additional_lightning_kwargs: dict[str, Any] = {}
    """
    Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.

    This is essentially a non-type-checked version of `lightning_kwargs`.
    """

    set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
    """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""

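As a quick illustration of how the typed fields above combine (an editor's sketch only; the field names are exactly those defined above, but the values are arbitrary examples):

# Editor's sketch (hypothetical usage, not part of the packaged file):
trainer = TrainerConfig(
    precision="bf16-mixed",                 # one of the literal values listed above
    set_float32_matmul_precision="high",    # faster float32 matmuls on Ampere+ GPUs
    max_epochs=100,
    early_stopping=EarlyStoppingConfig(patience=25),
)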
class SeedConfig(TypedConfig):
    seed: int
    """Seed for the random number generator."""

    seed_workers: bool = False
    """Whether to seed the workers of the dataloader."""


Signal: TypeAlias = Literal[
    "SIGHUP",
    "SIGINT",
    "SIGQUIT",
    "SIGILL",
    "SIGTRAP",
    "SIGABRT",
    "SIGBUS",
    "SIGFPE",
    "SIGKILL",
    "SIGUSR1",
    "SIGSEGV",
    "SIGUSR2",
    "SIGPIPE",
    "SIGALRM",
    "SIGTERM",
    "SIGCHLD",
    "SIGCONT",
    "SIGSTOP",
    "SIGTSTP",
    "SIGTTIN",
    "SIGTTOU",
    "SIGURG",
    "SIGXCPU",
    "SIGXFSZ",
    "SIGVTALRM",
    "SIGPROF",
    "SIGWINCH",
    "SIGIO",
    "SIGPWR",
    "SIGSYS",
    "SIGRTMIN",
    "SIGRTMAX",
]


class SubmitConfig(TypedConfig):
    auto_requeue_signals: list[Signal] = [
        # "SIGUSR1",
        # On SIGURG:
        # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
        # We can also ask the job manager to send a warning signal some amount of time before the allocation expires by passing -wa 'signal' and -wt '[hour:]minute' to bsub. We can then have bash create a dump_and_stop file when it receives the signal, which will tell Castro to output a checkpoint file and exit cleanly after it finishes the current timestep. An important detail that I couldn't find documented anywhere is that the job manager sends the signal to all the processes in the job, not just the submission script, and we have to use a signal that is ignored by default so Castro doesn't immediately crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals that fit this requirement and of these, SIGURG is the least likely to be triggered by other events.
        "SIGURG"
    ]
    """Signals that will trigger an automatic requeue of the job."""

    def _resolved_auto_requeue_signals(self) -> list[signal.Signals]:
        return [getattr(signal.Signals, sig) for sig in self.auto_requeue_signals]


class RunnerConfig(TypedConfig):
    python_logging: PythonLogging = PythonLogging()
    """Python logging configuration options."""

    seed: SeedConfig = SeedConfig(seed=0)
    """Seed everything configuration options."""

    submit: SubmitConfig = SubmitConfig()
    """Submit (e.g., SLURM or LSF) configuration options."""

    dump_run_information: bool = True
    """
    If enabled, will dump different bits of run information to the output directory before starting the run.
    This includes:
    - Run config
    - Full set of environment variables
    """

    additional_env_vars: dict[str, str] = {}
    """Additional environment variables to set when running the script."""

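The requeue signals are stored by name and resolved to `signal.Signals` members at runtime; a small sketch of the equivalent lookup on a POSIX system:

# Editor's sketch: the lookup performed by _resolved_auto_requeue_signals()
# for the default value ("SIGURG" is a POSIX-only signal name).
import signal

resolved = [getattr(signal.Signals, name) for name in ["SIGURG"]]
assert resolved == [signal.SIGURG]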
class MetricConfig(TypedConfig):
    name: str
    """The name of the primary metric."""

    mode: Literal["min", "max"]
    """
    The mode of the primary metric:
    - "min" for metrics that should be minimized (e.g., loss)
    - "max" for metrics that should be maximized (e.g., accuracy)
    """

    @property
    def validation_monitor(self) -> str:
        return f"val/{self.name}"

    def __post_init__(self):
        for split in ("train", "val", "test", "predict"):
            if self.name.startswith(f"{split}/"):
                raise ValueError(
                    f"Primary metric name should not start with '{split}/'. "
                    f"Just use '{self.name[len(split) + 1:]}' instead. "
                    "The split name is automatically added depending on the context."
                )

    @classmethod
    def loss(cls, mode: Literal["min", "max"] = "min"):
        return cls(name="loss", mode=mode)


PrimaryMetricConfig: TypeAlias = MetricConfig

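A short sketch of the naming rule enforced by `__post_init__` (hypothetical usage, not from the package docs):

# Editor's sketch (hypothetical usage, not part of the packaged file):
metric = MetricConfig(name="accuracy", mode="max")
# The validation monitor is derived by prefixing the split automatically:
# metric.validation_monitor == "val/accuracy"

# Passing an already-prefixed name is rejected:
# MetricConfig(name="val/accuracy", mode="max")  -> ValueError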
class BaseConfig(TypedConfig):
    id: str = Field(default_factory=lambda: BaseConfig.generate_id())
    """ID of the run."""
    name: str | None = None
    """Run name."""
    name_parts: list[str] = []
    """A list of parts used to construct the run name. This is useful for constructing the run name dynamically."""
    project: str | None = None
    """Project name."""
    tags: list[str] = []
    """Tags for the run."""
    notes: list[str] = []
    """Human readable notes for the run."""

    debug: bool = False
    """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
    environment: Annotated[EnvironmentConfig, Field(repr=False)] = EnvironmentConfig()
    """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""

    directory: DirectoryConfig = DirectoryConfig()
    """Directory configuration options."""
    trainer: TrainerConfig = TrainerConfig()
    """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
    runner: RunnerConfig = RunnerConfig()
    """`ll.Runner` configuration options."""

    primary_metric: PrimaryMetricConfig | None = None
    """Primary metric configuration options. This is used in the following ways:
    - To determine the best model checkpoint to save with the ModelCheckpoint callback.
    - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
    - For the ReduceLROnPlateau scheduler.
    """

    meta: dict[str, Any] = {}
    """Additional metadata for this run. This can be used to store arbitrary data that is not part of the config schema."""

    @property
    def run_name(self) -> str:
        parts = self.name_parts.copy()
        if self.name is not None:
            parts = [self.name] + parts
        name = "-".join(parts)
        if not name:
            name = self.id
        return name

    def clone(self, with_new_id: bool = True) -> Self:
        c = copy.deepcopy(self)
        if with_new_id:
            c.id = BaseConfig.generate_id()
        return c

    def subdirectory(self, subdirectory: str) -> Path:
        return self.directory.resolve_subdirectory(self.id, subdirectory)

    # region Helper methods
    def with_project_root_(self, project_root: str | Path | os.PathLike) -> Self:
        """
        Set the project root directory for the trainer.

        Args:
            project_root (Path): The base directory to use.

        Returns:
            self: The current instance of the class.
        """
        self.directory.project_root = Path(project_root)
        return self

    def reset_(
        self,
        *,
        id: bool = True,
        basic: bool = True,
        project_root: bool = True,
        environment: bool = True,
        meta: bool = True,
    ):
        """
        Reset the configuration object to its initial state.

        Parameters:
        - id (bool): If True, generate a new ID for the configuration object.
        - basic (bool): If True, reset basic attributes like name, project, tags, and notes.
        - project_root (bool): If True, reset the directory configuration to its initial state.
        - environment (bool): If True, reset the environment configuration to its initial state.
        - meta (bool): If True, reset the meta dictionary to an empty dictionary.

        Returns:
        - self: The updated configuration object.

        """
        if id:
            self.id = self.generate_id()

        if basic:
            self.name = None
            self.name_parts = []
            self.project = None
            self.tags = []
            self.notes = []

        if project_root:
            self.directory = DirectoryConfig()

        if environment:
            self.environment = EnvironmentConfig()

        if meta:
            self.meta = {}

        return self

    def concise_repr(self) -> str:
        """Get a concise representation of the configuration object."""

        def _truncate(s: str, max_len: int = 50):
            return s if len(s) <= max_len else f"{s[:max_len - 3]}..."

        cls_name = self.__class__.__name__

        parts: list[str] = []
        parts.append(f"name={self.run_name}")
        if self.project:
            parts.append(f"project={_truncate(self.project)}")

        return f"{cls_name}({', '.join(parts)})"

    # endregion

    # region Seeding

    _rng: ClassVar[np.random.Generator | None] = None

    @staticmethod
    def generate_id(
        *,
        length: int = 8,
        ignore_rng: bool = False,
    ) -> str:
        """
        Generate a random ID of specified length.

        Args:
            length (int): The length of the generated ID. Default is 8.
            ignore_rng (bool): If True, ignore the global random number generator and use a new one. Default is False.

        Returns:
            str: The generated random ID.

        Raises:
            IdSeedWarning: If the global random number generator is None and ignore_rng is False.

        Notes:
            - The generated IDs will not be reproducible if the global random number generator is None and ignore_rng is False.
            - To ensure reproducibility, call BaseConfig.set_seed(...) before generating any IDs.
        """
        rng = BaseConfig._rng if not ignore_rng else np.random.default_rng()
        if rng is None:
            warnings.warn(
                "BaseConfig._rng is None. The generated IDs will not be reproducible. "
                + "To fix this, call BaseConfig.set_seed(...) before generating any IDs.",
                category=IdSeedWarning,
            )
            rng = np.random.default_rng()

        alphabet = list(string.ascii_lowercase + string.digits)

        id = "".join(rng.choice(alphabet) for _ in range(length))
        return id

    @staticmethod
    def set_seed(seed: int | None = None) -> None:
        """
        Set the seed for the random number generator.

        Args:
            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.

        Returns:
            None
        """
        if seed is None:
            seed = int(time.time() * 1000)
        log.critical(f"Seeding BaseConfig with seed {seed}")
        BaseConfig._rng = np.random.default_rng(seed)

    # endregion

    @classmethod
    def from_checkpoint(
        cls,
        path: str | Path,
        *,
        hparams_key: str = "hyper_parameters",
    ):
        ckpt = torch.load(path)
        if (hparams := ckpt.get(hparams_key)) is None:
            raise ValueError(
                f"The checkpoint does not contain the `{hparams_key}` attribute. "
                "Are you sure this is a valid Lightning checkpoint?"
            )
        return cls.model_validate(hparams)

    def ll_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
        yield self.trainer.actsave
        yield self.trainer.early_stopping
        yield self.trainer.checkpoint_saving
        yield self.trainer.logging
        yield self.trainer.optimizer
        yield from self.trainer.callbacks
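Finally, a sketch of the ID/seeding contract documented in `generate_id` and `set_seed` (hypothetical usage, not from the package docs):

# Editor's sketch (hypothetical usage, not part of the packaged file):
BaseConfig.set_seed(0)            # seed the class-level RNG so run IDs are reproducible
config = BaseConfig()             # id is drawn from the seeded RNG
config_copy = config.clone()      # deep copy; with_new_id=True draws a fresh id
# config_copy.id differs from config.id, while everything else is copied.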