nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +12 -1
- nshtrainer/callbacks/debug_flag.py +72 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +124 -67
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +787 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/callback.py +0 -206
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
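Most of the change volume comes from splitting up the old nshtrainer/model/config.py monolith: the profiler configs now live under nshtrainer/profiler/, the directory handling under nshtrainer/_directory.py, and the trainer options under nshtrainer/trainer/_config.py (see the per-file diff below). For downstream code that imported these classes from nshtrainer.model.config, the moves suggest imports along the following lines; the exact re-export points (especially for the profiler package's __init__.py) are an assumption based on the renamed files, not something this diff confirms.

# Hypothetical import migration sketched from the file moves above (not verified
# against 0.32.0; the profiler re-export location in particular is an assumption).

# 0.30.1: these classes were defined inside nshtrainer/model/config.py, e.g.
# from nshtrainer.model.config import TrainerConfig, DirectoryConfig, SimpleProfilerConfig

# 0.32.0: per the new files shown in this diff
from nshtrainer.trainer._config import TrainerConfig    # trainer options (nshtrainer/trainer/_config.py)
from nshtrainer._directory import DirectoryConfig       # run directories (nshtrainer/_directory.py)
from nshtrainer.profiler import SimpleProfilerConfig    # profiler configs (nshtrainer/profiler/, assumed re-export)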
nshtrainer/model/config.py
CHANGED
@@ -3,1037 +3,24 @@ import logging
 import os
 import string
 import time
-from abc import
-from collections.abc import Iterable, Sequence
-from datetime import timedelta
+from collections.abc import Iterable
 from pathlib import Path
-from typing import
-    Annotated,
-    Any,
-    ClassVar,
-    Literal,
-    Protocol,
-    TypeAlias,
-    runtime_checkable,
-)
+from typing import Annotated, Any, ClassVar

 import nshconfig as C
 import numpy as np
 import torch
-from
-from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
-from lightning.pytorch.accelerators import Accelerator
-from lightning.pytorch.callbacks.callback import Callback
-from lightning.pytorch.loggers import Logger
-from lightning.pytorch.plugins import _PLUGIN_INPUT
-from lightning.pytorch.plugins.layer_sync import LayerSync
-from lightning.pytorch.plugins.precision.precision import Precision
-from lightning.pytorch.profilers import Profiler
-from lightning.pytorch.strategies.strategy import Strategy
-from typing_extensions import Self, TypedDict, TypeVar, override
+from typing_extensions import Self

-from ..
-from .._hf_hub import HuggingFaceHubConfig
-from ..callbacks import (
-    BestCheckpointCallbackConfig,
-    CallbackConfig,
-    EarlyStoppingConfig,
-    LastCheckpointCallbackConfig,
-    OnExceptionCheckpointCallbackConfig,
-)
+from .._directory import DirectoryConfig
 from ..callbacks.base import CallbackConfigBase
-from ..loggers import (
-    CSVLoggerConfig,
-    LoggerConfig,
-    TensorboardLoggerConfig,
-    WandbLoggerConfig,
-)
 from ..metrics import MetricConfig
+from ..trainer._config import TrainerConfig
 from ..util._environment_info import EnvironmentConfig

 log = logging.getLogger(__name__)


-class BaseProfilerConfig(C.Config, ABC):
-    dirpath: str | Path | None = None
-    """
-    Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
-    ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
-    will be used.
-    """
-    filename: str | None = None
-    """
-    If present, filename where the profiler results will be saved instead of printing to stdout.
-    The ``.txt`` extension will be used automatically.
-    """
-
-    @abstractmethod
-    def create_profiler(self, root_config: "BaseConfig") -> Profiler: ...
-
-
-class SimpleProfilerConfig(BaseProfilerConfig):
-    name: Literal["simple"] = "simple"
-
-    extended: bool = True
-    """
-    If ``True``, adds extra columns representing number of calls and percentage of
-    total time spent onrespective action.
-    """
-
-    @override
-    def create_profiler(self, root_config):
-        from lightning.pytorch.profilers.simple import SimpleProfiler
-
-        if (dirpath := self.dirpath) is None:
-            dirpath = root_config.directory.resolve_subdirectory(
-                root_config.id, "profile"
-            )
-
-        if (filename := self.filename) is None:
-            filename = f"{root_config.id}_profile.txt"
-
-        return SimpleProfiler(
-            extended=self.extended,
-            dirpath=dirpath,
-            filename=filename,
-        )
-
-
-class AdvancedProfilerConfig(BaseProfilerConfig):
-    name: Literal["advanced"] = "advanced"
-
-    line_count_restriction: float = 1.0
-    """
-    This can be used to limit the number of functions
-    reported for each action. either an integer (to select a count of lines),
-    or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
-    """
-
-    @override
-    def create_profiler(self, root_config):
-        from lightning.pytorch.profilers.advanced import AdvancedProfiler
-
-        if (dirpath := self.dirpath) is None:
-            dirpath = root_config.directory.resolve_subdirectory(
-                root_config.id, "profile"
-            )
-
-        if (filename := self.filename) is None:
-            filename = f"{root_config.id}_profile.txt"
-
-        return AdvancedProfiler(
-            line_count_restriction=self.line_count_restriction,
-            dirpath=dirpath,
-            filename=filename,
-        )
-
-
-class PyTorchProfilerConfig(BaseProfilerConfig):
-    name: Literal["pytorch"] = "pytorch"
-
-    group_by_input_shapes: bool = False
-    """Include operator input shapes and group calls by shape."""
-
-    emit_nvtx: bool = False
-    """
-    Context manager that makes every autograd operation emit an NVTX range
-    Run::
-
-        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
-
-    To visualize, you can either use::
-
-        nvvp trace_name.prof
-        torch.autograd.profiler.load_nvprof(path)
-    """
-
-    export_to_chrome: bool = True
-    """
-    Whether to export the sequence of profiled operators for Chrome.
-    It will generate a ``.json`` file which can be read by Chrome.
-    """
-
-    row_limit: int = 20
-    """
-    Limit the number of rows in a table, ``-1`` is a special value that
-    removes the limit completely.
-    """
-
-    sort_by_key: str | None = None
-    """
-    Attribute used to sort entries. By default
-    they are printed in the same order as they were registered.
-    Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
-    ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
-    ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
-    """
-
-    record_module_names: bool = True
-    """Whether to add module names while recording autograd operation."""
-
-    table_kwargs: dict[str, Any] | None = None
-    """Dictionary with keyword arguments for the summary table."""
-
-    additional_profiler_kwargs: dict[str, Any] = {}
-    """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
-
-    @override
-    def create_profiler(self, root_config):
-        from lightning.pytorch.profilers.pytorch import PyTorchProfiler
-
-        if (dirpath := self.dirpath) is None:
-            dirpath = root_config.directory.resolve_subdirectory(
-                root_config.id, "profile"
-            )
-
-        if (filename := self.filename) is None:
-            filename = f"{root_config.id}_profile.txt"
-
-        return PyTorchProfiler(
-            group_by_input_shapes=self.group_by_input_shapes,
-            emit_nvtx=self.emit_nvtx,
-            export_to_chrome=self.export_to_chrome,
-            row_limit=self.row_limit,
-            sort_by_key=self.sort_by_key,
-            record_module_names=self.record_module_names,
-            table_kwargs=self.table_kwargs,
-            dirpath=dirpath,
-            filename=filename,
-            **self.additional_profiler_kwargs,
-        )
-
-
-ProfilerConfig: TypeAlias = Annotated[
-    SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
-    C.Field(discriminator="name"),
-]
-
-
-class LoggingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable experiment tracking."""
-
-    loggers: Sequence[LoggerConfig] = [
-        WandbLoggerConfig(),
-        CSVLoggerConfig(),
-        TensorboardLoggerConfig(),
-    ]
-    """Loggers to use for experiment tracking."""
-
-    log_lr: bool | Literal["step", "epoch"] = True
-    """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
-    log_epoch: bool = True
-    """If enabled, will log the fractional epoch number to the logger."""
-
-    actsave_logged_metrics: bool = False
-    """If enabled, will automatically save logged metrics using ActSave (if nshutils is installed)."""
-
-    @property
-    def wandb(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, WandbLoggerConfig)
-            ),
-            None,
-        )
-
-    @property
-    def csv(self):
-        return next(
-            (logger for logger in self.loggers if isinstance(logger, CSVLoggerConfig)),
-            None,
-        )
-
-    @property
-    def tensorboard(self):
-        return next(
-            (
-                logger
-                for logger in self.loggers
-                if isinstance(logger, TensorboardLoggerConfig)
-            ),
-            None,
-        )
-
-    def create_loggers(self, root_config: "BaseConfig"):
-        """
-        Constructs and returns a list of loggers based on the provided root configuration.
-
-        Args:
-            root_config (BaseConfig): The root configuration object.
-
-        Returns:
-            list[Logger]: A list of constructed loggers.
-        """
-        if not self.enabled:
-            return
-
-        for logger_config in sorted(
-            self.loggers,
-            key=lambda x: x.priority,
-            reverse=True,
-        ):
-            if not logger_config.enabled:
-                continue
-            if (logger := logger_config.create_logger(root_config)) is None:
-                continue
-            yield logger
-
-    @override
-    def create_callbacks(self, root_config):
-        if self.log_lr:
-            from lightning.pytorch.callbacks import LearningRateMonitor
-
-            logging_interval: str | None = None
-            if isinstance(self.log_lr, str):
-                logging_interval = self.log_lr
-
-            yield LearningRateMonitor(logging_interval=logging_interval)
-
-        if self.log_epoch:
-            from ..callbacks.log_epoch import LogEpochCallback
-
-            yield LogEpochCallback()
-
-        for logger in self.loggers:
-            if not logger or not isinstance(logger, CallbackConfigBase):
-                continue
-
-            yield from logger.create_callbacks(root_config)
-
-
-class GradientClippingConfig(C.Config):
-    enabled: bool = True
-    """Enable gradient clipping."""
-    value: int | float
-    """Value to use for gradient clipping."""
-    algorithm: Literal["value", "norm"] = "norm"
-    """Norm type to use for gradient clipping."""
-
-
-class OptimizationConfig(CallbackConfigBase):
-    log_grad_norm: bool | str | float = False
-    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
-    log_grad_norm_per_param: bool | str | float = False
-    """If enabled, will log the gradient norm for each model parameter to the logger."""
-
-    log_param_norm: bool | str | float = False
-    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
-    log_param_norm_per_param: bool | str | float = False
-    """If enabled, will log the parameter norm for each model parameter to the logger."""
-
-    gradient_clipping: GradientClippingConfig | None = None
-    """Gradient clipping configuration, or None to disable gradient clipping."""
-
-    @override
-    def create_callbacks(self, root_config):
-        from ..callbacks.norm_logging import NormLoggingConfig
-
-        yield from NormLoggingConfig(
-            log_grad_norm=self.log_grad_norm,
-            log_grad_norm_per_param=self.log_grad_norm_per_param,
-            log_param_norm=self.log_param_norm,
-            log_param_norm_per_param=self.log_param_norm_per_param,
-        ).create_callbacks(root_config)
-
-
-TPlugin = TypeVar(
-    "TPlugin",
-    Precision,
-    ClusterEnvironment,
-    CheckpointIO,
-    LayerSync,
-    infer_variance=True,
-)
-
-
-@runtime_checkable
-class PluginConfigProtocol(Protocol[TPlugin]):
-    def create_plugin(self) -> TPlugin: ...
-
-
-@runtime_checkable
-class AcceleratorConfigProtocol(Protocol):
-    def create_accelerator(self) -> Accelerator: ...
-
-
-@runtime_checkable
-class StrategyConfigProtocol(Protocol):
-    def create_strategy(self) -> Strategy: ...
-
-
-AcceleratorLiteral: TypeAlias = Literal[
-    "cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"
-]
-
-StrategyLiteral: TypeAlias = Literal[
-    "auto",
-    "ddp",
-    "ddp_find_unused_parameters_false",
-    "ddp_find_unused_parameters_true",
-    "ddp_spawn",
-    "ddp_spawn_find_unused_parameters_false",
-    "ddp_spawn_find_unused_parameters_true",
-    "ddp_fork",
-    "ddp_fork_find_unused_parameters_false",
-    "ddp_fork_find_unused_parameters_true",
-    "ddp_notebook",
-    "dp",
-    "deepspeed",
-    "deepspeed_stage_1",
-    "deepspeed_stage_1_offload",
-    "deepspeed_stage_2",
-    "deepspeed_stage_2_offload",
-    "deepspeed_stage_3",
-    "deepspeed_stage_3_offload",
-    "deepspeed_stage_3_offload_nvme",
-    "fsdp",
-    "fsdp_cpu_offload",
-    "single_xla",
-    "xla_fsdp",
-    "xla",
-    "single_tpu",
-]
-
-
-def _create_symlink_to_nshrunner(base_dir: Path):
-    # Resolve the current nshrunner session directory
-    if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
-        log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
-        return
-    session_dir = Path(session_dir)
-    if not session_dir.exists() or not session_dir.is_dir():
-        log.warning(
-            f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
-            "Skipping symlink creation."
-        )
-        return
-
-    # Create the symlink
-    symlink_path = base_dir / "nshrunner"
-    if symlink_path.exists():
-        # If it already points to the correct directory, we're done
-        if symlink_path.resolve() == session_dir.resolve():
-            return
-
-        # Otherwise, we should log a warning and remove the existing symlink
-        log.warning(
-            f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
-            "Removing the existing symlink."
-        )
-        symlink_path.unlink()
-
-    symlink_path.symlink_to(session_dir)
-
-
-class DirectoryConfig(C.Config):
-    project_root: Path | None = None
-    """
-    Root directory for this project.
-
-    This isn't specific to the run; it is the parent directory of all runs.
-    """
-
-    create_symlink_to_nshrunner_root: bool = True
-    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
-
-    log: Path | None = None
-    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
-
-    stdio: Path | None = None
-    """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
-
-    checkpoint: Path | None = None
-    """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
-
-    activation: Path | None = None
-    """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
-
-    profile: Path | None = None
-    """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
-
-    def resolve_run_root_directory(self, run_id: str) -> Path:
-        if (project_root_dir := self.project_root) is None:
-            project_root_dir = Path.cwd()
-
-        # The default base dir is $CWD/nshtrainer/{id}/
-        base_dir = project_root_dir / "nshtrainer"
-        base_dir.mkdir(exist_ok=True)
-
-        # Add a .gitignore file to the nshtrainer directory
-        # which will ignore all files except for the .gitignore file itself
-        gitignore_path = base_dir / ".gitignore"
-        if not gitignore_path.exists():
-            gitignore_path.touch()
-            gitignore_path.write_text("*\n")
-
-        base_dir = base_dir / run_id
-        base_dir.mkdir(exist_ok=True)
-
-        # Create a symlink to the root folder for the Runner
-        if self.create_symlink_to_nshrunner_root:
-            _create_symlink_to_nshrunner(base_dir)
-
-        return base_dir
-
-    def resolve_subdirectory(
-        self,
-        run_id: str,
-        # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
-        subdirectory: str,
-    ) -> Path:
-        # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
-        if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(
-                subdir, Path
-            ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
-            return subdir
-
-        dir = self.resolve_run_root_directory(run_id)
-        dir = dir / subdirectory
-        dir.mkdir(exist_ok=True)
-        return dir
-
-    def _resolve_log_directory_for_logger(
-        self,
-        run_id: str,
-        logger: LoggerConfig,
-    ) -> Path:
-        if (log_dir := logger.log_dir) is not None:
-            return log_dir
-
-        # Save to nshtrainer/{id}/log/{logger name}
-        log_dir = self.resolve_subdirectory(run_id, "log")
-        log_dir = log_dir / logger.name
-        log_dir.mkdir(exist_ok=True)
-
-        return log_dir
-
-
-class ReproducibilityConfig(C.Config):
-    deterministic: bool | Literal["warn"] | None = None
-    """
-    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-
-CheckpointCallbackConfig: TypeAlias = Annotated[
-    BestCheckpointCallbackConfig
-    | LastCheckpointCallbackConfig
-    | OnExceptionCheckpointCallbackConfig,
-    C.Field(discriminator="name"),
-]
-
-
-class CheckpointSavingConfig(CallbackConfigBase):
-    enabled: bool = True
-    """Enable checkpoint saving."""
-
-    checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        BestCheckpointCallbackConfig(),
-        LastCheckpointCallbackConfig(),
-        OnExceptionCheckpointCallbackConfig(),
-    ]
-    """Checkpoint callback configurations."""
-
-    def disable_(self):
-        self.enabled = False
-        return self
-
-    def should_save_checkpoints(self, root_config: "BaseConfig"):
-        if not self.enabled:
-            return False
-
-        if root_config.trainer.fast_dev_run:
-            return False
-
-        return True
-
-    @override
-    def create_callbacks(self, root_config: "BaseConfig"):
-        if not self.should_save_checkpoints(root_config):
-            return
-
-        for callback_config in self.checkpoint_callbacks:
-            yield from callback_config.create_callbacks(root_config)
-
-
-class LightningTrainerKwargs(TypedDict, total=False):
-    accelerator: str | Accelerator
-    """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
-    as well as custom accelerator instances."""
-
-    strategy: str | Strategy
-    """Supports different training strategies with aliases as well custom strategies.
-    Default: ``"auto"``.
-    """
-
-    devices: list[int] | str | int
-    """The devices to use. Can be set to a positive number (int or str), a sequence of device indices
-    (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
-    automatic selection based on the chosen accelerator. Default: ``"auto"``.
-    """
-
-    num_nodes: int
-    """Number of GPU nodes for distributed training.
-    Default: ``1``.
-    """
-
-    precision: _PRECISION_INPUT | None
-    """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
-    16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
-    Can be used on CPU, GPU, TPUs, HPUs or IPUs.
-    Default: ``'32-true'``.
-    """
-
-    logger: Logger | Iterable[Logger] | bool | None
-    """Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
-    the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
-    ``False`` will disable logging. If multiple loggers are provided, local files
-    (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
-    Default: ``True``.
-    """
-
-    callbacks: list[Callback] | Callback | None
-    """Add a callback or list of callbacks.
-    Default: ``None``.
-    """
-
-    fast_dev_run: int | bool
-    """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
-    of train, val and test to find any bugs (ie: a sort of unit test).
-    Default: ``False``.
-    """
-
-    max_epochs: int | None
-    """Stop training once this number of epochs is reached. Disabled by default (None).
-    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
-    To enable infinite training, set ``max_epochs = -1``.
-    """
-
-    min_epochs: int | None
-    """Force training for at least these many epochs. Disabled by default (None).
-    """
-
-    max_steps: int
-    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
-    and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
-    ``max_epochs`` to ``-1``.
-    """
-
-    min_steps: int | None
-    """Force training for at least these number of steps. Disabled by default (``None``).
-    """
-
-    max_time: str | timedelta | dict[str, int] | None
-    """Stop training after this amount of time has passed. Disabled by default (``None``).
-    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
-    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
-    :class:`datetime.timedelta`.
-    """
-
-    limit_train_batches: int | float | None
-    """How much of training dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    limit_val_batches: int | float | None
-    """How much of validation dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    limit_test_batches: int | float | None
-    """How much of test dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    limit_predict_batches: int | float | None
-    """How much of prediction dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    overfit_batches: int | float
-    """Overfit a fraction of training/validation data (float) or a set number of batches (int).
-    Default: ``0.0``.
-    """
-
-    val_check_interval: int | float | None
-    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
-    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
-    batches. An ``int`` value can only be higher than the number of training batches when
-    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
-    across epochs or during iteration-based training.
-    Default: ``1.0``.
-    """
-
-    check_val_every_n_epoch: int | None
-    """Perform a validation loop every after every `N` training epochs. If ``None``,
-    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
-    to be an integer value.
-    Default: ``1``.
-    """
-
-    num_sanity_val_steps: int | None
-    """Sanity check runs n validation batches before starting the training routine.
-    Set it to `-1` to run all batches in all validation dataloaders.
-    Default: ``2``.
-    """
-
-    log_every_n_steps: int | None
-    """How often to log within steps.
-    Default: ``50``.
-    """
-
-    enable_checkpointing: bool | None
-    """If ``True``, enable checkpointing.
-    It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
-    Default: ``True``.
-    """
-
-    enable_progress_bar: bool | None
-    """Whether to enable to progress bar by default.
-    Default: ``True``.
-    """
-
-    enable_model_summary: bool | None
-    """Whether to enable model summarization by default.
-    Default: ``True``.
-    """
-
-    accumulate_grad_batches: int
-    """Accumulates gradients over k batches before stepping the optimizer.
-    Default: 1.
-    """
-
-    gradient_clip_val: int | float | None
-    """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
-    gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
-    Default: ``None``.
-    """
-
-    gradient_clip_algorithm: str | None
-    """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
-    to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
-    be set to ``"norm"``.
-    """
-
-    deterministic: bool | Literal["warn"] | None
-    """If ``True``, sets whether PyTorch operations must use deterministic algorithms.
-    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
-    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
-    """
-
-    benchmark: bool | None
-    """The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
-    The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
-    (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
-    is set to ``True``, this will default to ``False``. Override to manually set a different value.
-    Default: ``None``.
-    """
-
-    inference_mode: bool
-    """Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
-    evaluation (``validate``/``test``/``predict``).
-    """
-
-    use_distributed_sampler: bool
-    """Whether to wrap the DataLoader's sampler with
-    :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
-    strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
-    ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
-    ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
-    sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
-    we don't do this automatically.
-    """
-
-    profiler: Profiler | str | None
-    """To profile individual steps during training and assist in identifying bottlenecks.
-    Default: ``None``.
-    """
-
-    detect_anomaly: bool
-    """Enable anomaly detection for the autograd engine.
-    Default: ``False``.
-    """
-
-    barebones: bool
-    """Whether to run in "barebones mode", where all features that may impact raw speed are
-    disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
-    runs. The following features are deactivated:
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
-    :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
-    :meth:`~lightning.pytorch.core.LightningModule.log`,
-    :meth:`~lightning.pytorch.core.LightningModule.log_dict`.
-    """
-
-    plugins: _PLUGIN_INPUT | list[_PLUGIN_INPUT] | None
-    """Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
-    Default: ``None``.
-    """
-
-    sync_batchnorm: bool
-    """Synchronize batch norm layers between process groups/whole world.
-    Default: ``False``.
-    """
-
-    reload_dataloaders_every_n_epochs: int
-    """Set to a positive integer to reload dataloaders every n epochs.
-    Default: ``0``.
-    """
-
-    default_root_dir: Path | None
-    """Default path for logs and weights when no logger/ckpt_callback passed.
-    Default: ``os.getcwd()``.
-    Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
-    """
-
-
-class SanityCheckingConfig(C.Config):
-    reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
-    """
-    If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
-    - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
-    - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
-    Valid values are: "disable", "warn", "error".
-    """
-
-
-class TrainerConfig(C.Config):
-    ckpt_path: Literal["none"] | str | Path | None = None
-    """Path to a checkpoint to load and resume training from. If ``"none"``, will not load a checkpoint."""
-
-    checkpoint_loading: CheckpointLoadingConfig | Literal["auto", "none"] = "auto"
-    """Checkpoint loading configuration options.
-    `"auto"` will automatically determine the best checkpoint loading strategy based on the provided.
-    `"none"` will disable checkpoint loading.
-    """
-
-    checkpoint_saving: CheckpointSavingConfig = CheckpointSavingConfig()
-    """Checkpoint saving configuration options."""
-
-    hf_hub: HuggingFaceHubConfig = HuggingFaceHubConfig()
-    """Hugging Face Hub configuration options."""
-
-    logging: LoggingConfig = LoggingConfig()
-    """Logging/experiment tracking (e.g., WandB) configuration options."""
-
-    optimizer: OptimizationConfig = OptimizationConfig()
-    """Optimization configuration options."""
-
-    reproducibility: ReproducibilityConfig = ReproducibilityConfig()
-    """Reproducibility configuration options."""
-
-    sanity_checking: SanityCheckingConfig = SanityCheckingConfig()
-    """Sanity checking configuration options."""
-
-    early_stopping: EarlyStoppingConfig | None = None
-    """Early stopping configuration options."""
-
-    profiler: ProfilerConfig | None = None
-    """
-    To profile individual steps during training and assist in identifying bottlenecks.
-    Default: ``None``.
-    """
-
-    callbacks: list[CallbackConfig] = []
-    """Callbacks to use during training."""
-
-    detect_anomaly: bool | None = None
-    """Enable anomaly detection for the autograd engine.
-    Default: ``False``.
-    """
-
-    plugins: list[PluginConfigProtocol] | None = None
-    """
-    Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
-    Default: ``None``.
-    """
-
-    auto_determine_num_nodes: bool = True
-    """
-    If enabled, will automatically determine the number of nodes for distributed training.
-
-    This will only work on:
-    - SLURM clusters
-    - LSF clusters
-    """
-
-    fast_dev_run: int | bool = False
-    """Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
-    of train, val and test to find any bugs (ie: a sort of unit test).
-    Default: ``False``.
-    """
-
-    precision: (
-        Literal[
-            "64-true",
-            "32-true",
-            "fp16-mixed",
-            "bf16-mixed",
-            "16-mixed-auto",
-        ]
-        | None
-    ) = None
-    """
-    Training precision. Can be one of:
-    - "64-true": Double precision (64-bit).
-    - "32-true": Full precision (32-bit).
-    - "fp16-mixed": Float16 mixed precision.
-    - "bf16-mixed": BFloat16 mixed precision.
-    - "16-mixed-auto": Automatic 16-bit: Uses bfloat16 if available, otherwise float16.
-    """
-
-    max_epochs: int | None = None
-    """Stop training once this number of epochs is reached. Disabled by default (None).
-    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
-    To enable infinite training, set ``max_epochs = -1``.
-    """
-
-    min_epochs: int | None = None
-    """Force training for at least these many epochs. Disabled by default (None).
-    """
-
-    max_steps: int = -1
-    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
-    and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
-    ``max_epochs`` to ``-1``.
-    """
-
-    min_steps: int | None = None
-    """Force training for at least these number of steps. Disabled by default (``None``).
-    """
-
-    max_time: str | timedelta | dict[str, int] | None = None
-    """Stop training after this amount of time has passed. Disabled by default (``None``).
-    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
-    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
-    :class:`datetime.timedelta`.
-    """
-
-    limit_train_batches: int | float | None = None
-    """How much of training dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    limit_val_batches: int | float | None = None
-    """How much of validation dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    limit_test_batches: int | float | None = None
-    """How much of test dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    limit_predict_batches: int | float | None = None
-    """How much of prediction dataset to check (float = fraction, int = num_batches).
-    Default: ``1.0``.
-    """
-
-    overfit_batches: int | float = 0.0
-    """Overfit a fraction of training/validation data (float) or a set number of batches (int).
-    Default: ``0.0``.
-    """
-
-    val_check_interval: int | float | None = None
-    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
-    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
-    batches. An ``int`` value can only be higher than the number of training batches when
-    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
-    across epochs or during iteration-based training.
-    Default: ``1.0``.
-    """
-
-    check_val_every_n_epoch: int | None = 1
-    """Perform a validation loop every after every `N` training epochs. If ``None``,
-    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
-    to be an integer value.
-    Default: ``1``.
-    """
-
-    num_sanity_val_steps: int | None = None
-    """Sanity check runs n validation batches before starting the training routine.
-    Set it to `-1` to run all batches in all validation dataloaders.
-    Default: ``2``.
-    """
-
-    log_every_n_steps: int | None = None
-    """How often to log within steps.
-    Default: ``50``.
-    """
-
-    inference_mode: bool = True
-    """Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during evaluation (``validate``/``test``/``predict``).
-    Default: ``True``.
-    """
-
-    use_distributed_sampler: bool | None = None
-    """Whether to wrap the DataLoader's sampler with
-    :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
-    strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
-    ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
-    ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
-    sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
-    we don't do this automatically.
-    Default: ``True``.
-    """
-
-    accelerator: AcceleratorConfigProtocol | AcceleratorLiteral | None = None
-    """Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")
-    as well as custom accelerator instances.
-    Default: ``"auto"``.
-    """
-
-    strategy: StrategyConfigProtocol | StrategyLiteral | None = None
-    """Supports different training strategies with aliases as well custom strategies.
-    Default: ``"auto"``.
-    """
-
-    devices: tuple[int, ...] | Sequence[int] | Literal["auto", "all"] | None = None
-    """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
-    automatic selection based on the chosen accelerator. Default: ``"auto"``.
-    """
-
-    auto_set_default_root_dir: bool = True
-    """If enabled, will automatically set the default root dir to [cwd/lightning_logs/<id>/]. There is basically no reason to disable this."""
-    supports_shared_parameters: bool = True
-    """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`"""
-    save_checkpoint_metadata: bool = True
-    """If enabled, will save additional metadata whenever a checkpoint is saved."""
-
-    lightning_kwargs: LightningTrainerKwargs = LightningTrainerKwargs()
-    """
-    Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
-
-    Please refer to the Lightning documentation for a list of valid keyword arguments.
-    """
-
-    additional_lightning_kwargs: dict[str, Any] = {}
-    """
-    Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.
-
-    This is essentially a non-type-checked version of `lightning_kwargs`.
-    """
-
-    set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None
-    """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
-
-
-PrimaryMetricConfig: TypeAlias = MetricConfig
-
-
 class BaseConfig(C.Config):
     id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
     """ID of the run."""
@@ -1060,7 +47,7 @@ class BaseConfig(C.Config):
     trainer: TrainerConfig = TrainerConfig()
     """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""

-    primary_metric:
+    primary_metric: MetricConfig | None = None
     """Primary metric configuration options. This is used in the following ways:
     - To determine the best model checkpoint to save with the ModelCheckpoint callback.
     - To monitor the primary metric during training and stop training based on the `early_stopping` configuration.
@@ -1216,9 +203,4 @@ class BaseConfig(C.Config):
         return cls.model_validate(hparams)

     def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
-        yield self.trainer.
-        yield self.trainer.checkpoint_saving
-        yield self.trainer.logging
-        yield self.trainer.optimizer
-        yield self.trainer.hf_hub
-        yield from self.trainer.callbacks
+        yield from self.trainer._nshtrainer_all_callback_configs()