nshtrainer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/trainer/signal_connector.py
ADDED

@@ -0,0 +1,208 @@

import logging
import os
import re
import signal
import subprocess
import threading
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from types import FrameType
from typing import Any, TypeAlias

import torch.utils.data
from lightning.fabric.plugins.environments.lsf import LSFEnvironment
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompose
from lightning.pytorch.trainer.connectors.signal_connector import (
    _SignalConnector as _LightningSignalConnector,
)
from typing_extensions import override

log = logging.getLogger(__name__)

_SIGNUM = int | signal.Signals
_HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None


class _SignalConnector(_LightningSignalConnector):
    def _auto_requeue_signals(self) -> list[signal.Signals]:
        from ..model.base import BaseConfig

        if not isinstance(config := self.trainer.lightning_module.hparams, BaseConfig):
            return []

        signals = config.runner.submit._resolved_auto_requeue_signals()
        signals_set = set(signals)
        valid_signals: set[signal.Signals] = signal.valid_signals()
        assert signals_set.issubset(
            valid_signals
        ), f"Invalid signal(s) found: {signals_set - valid_signals}"
        return signals

    def _compose_and_register(
        self,
        signum: _SIGNUM,
        handlers: list[_HANDLER],
        replace_existing: bool = False,
    ):
        if self._is_on_windows():
            log.info(f"Signal {signum} has no handlers or is not supported on Windows.")
            return

        if self._has_already_handler(signum):
            if not replace_existing:
                log.info(
                    f"Signal {signum} already has a handler. Adding ours to the existing one."
                )
                handlers.append(signal.getsignal(signum))
            else:
                log.info(f"Replacing existing handler for signal {signum} with ours.")

        self._register_signal(signum, _HandlersCompose(handlers))
        log.info(f"Registered {len(handlers)} handlers for signal {signum}.")

    @override
    def register_signal_handlers(self) -> None:
        if not (auto_requeue_signals := self._auto_requeue_signals()):
            log.info(
                "No auto-requeue signals found. Reverting to default Lightning behavior."
            )
            return super().register_signal_handlers()

        self.received_sigterm = False
        self._original_handlers = self._get_current_signal_handlers()

        signals = defaultdict[signal.Signals, list[_HANDLER]](lambda: [])
        signals[signal.SIGTERM].append(self._sigterm_notifier_fn)

        environment = self.trainer._accelerator_connector.cluster_environment
        if isinstance(environment, SLURMEnvironment):
            log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
            for signal_handler in auto_requeue_signals:
                signals[signal_handler].append(self._slurm_sigusr_handler_fn)

        if isinstance(environment, LSFEnvironment):
            # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
            # We can also ask the job manager to send a warning signal some amount of time before the allocation expires by passing -wa 'signal' and -wt '[hour:]minute' to bsub. We can then have bash create a dump_and_stop file when it receives the signal, which will tell Castro to output a checkpoint file and exit cleanly after it finishes the current timestep. An important detail that I couldn't find documented anywhere is that the job manager sends the signal to all the processes in the job, not just the submission script, and we have to use a signal that is ignored by default so Castro doesn't immediately crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals that fit this requirement and of these, SIGURG is the least likely to be triggered by other events.

            log.info("LSF auto-requeueing enabled. Setting signal handlers.")
            for signal_handler in auto_requeue_signals:
                signals[signal_handler].append(self._lsf_sigusr_handler_fn)

        for signum, handlers in signals.items():
            if not handlers:
                continue

            self._compose_and_register(signum, handlers)

    def _should_ignore_signal_handler(self) -> str | None:
        if threading.current_thread() is not threading.main_thread():
            return "Not in main thread"

        if torch.utils.data.get_worker_info() is not None:
            return "Inside DataLoader worker process"

        return None

    @override
    def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
        if ignore_reason := self._should_ignore_signal_handler():
            log.info(
                f"Skipping SLURM auto-requeue signal handler. Reason: {ignore_reason}"
            )
            return

        log.critical(f"Handling SLURM auto-requeue signal: {signum}")

        # save logger to make sure we get all the metrics
        for logger in self.trainer.loggers:
            logger.finalize("finished")

        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(
            self.trainer.default_root_dir
        )
        self.trainer.save_checkpoint(hpc_save_path)

        if not self.trainer.is_global_zero:
            return

        # find job id
        array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
        if array_job_id is not None:
            array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
            job_id = f"{array_job_id}_{array_task_id}"
        else:
            job_id = os.environ["SLURM_JOB_ID"]

        assert re.match("[0-9_-]+", job_id)
        cmd = ["scontrol", "requeue", job_id]

        # requeue job
        log.info(f"requeuing job {job_id}...")
        try:
            result = subprocess.call(cmd)
        except FileNotFoundError:
            # This can occur if a subprocess call to `scontrol` is run outside a shell context
            # Re-attempt call (now with shell context). If any error is raised, propagate to user.
            # When running a shell command, it should be passed as a single string.
            result = subprocess.call(" ".join(cmd), shell=True)

        # print result text
        if result == 0:
            log.info(f"Requeued SLURM job: {job_id}")
        else:
            log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}")

    def _lsf_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
        if ignore_reason := self._should_ignore_signal_handler():
            log.info(
                f"Skipping LSF auto-requeue signal handler. Reason: {ignore_reason}"
            )
            return

        log.critical(f"Handling LSF auto-requeue signal: {signum}")

        # Save logger to make sure we get all the metrics
        for logger in self.trainer.loggers:
            logger.finalize("finished")

        # Save checkpoint
        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(
            self.trainer.default_root_dir
        )
        self.trainer.save_checkpoint(hpc_save_path)
        log.info(f"Saved checkpoint to {hpc_save_path}")

        if not self.trainer.is_global_zero:
            return

        # Find job id
        if (job_id := os.getenv("LSB_JOBID")) is None:
            log.warning(
                "LSB_JOBID environment variable not found. Unable to requeue job."
            )
            return

        assert re.match("[0-9_-]+", job_id)

        exe = "brequeue"
        if (bin_dir := os.getenv("LSF_BINDIR")) is not None:
            exe = str((Path(bin_dir) / exe).resolve().absolute())

        log.info(f"Using LSF requeue executable: {exe}")
        cmd = [exe, job_id]

        # Requeue job
        log.info(f"Requeuing job {job_id}...")
        try:
            result = subprocess.call(cmd)
        except FileNotFoundError:
            # Retry with shell context if subprocess call fails
            result = subprocess.call(" ".join(cmd), shell=True)

        # Print result text
        if result == 0:
            log.info(f"Requeued LSF job: {job_id}")
        else:
            log.warning(f"Requeuing LSF job {job_id} failed with error code {result}")
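The composition pattern in `_compose_and_register` above (keep whatever handler was already installed and call it alongside the new one, as Lightning's `_HandlersCompose` does) can be illustrated with plain stdlib `signal` code. The sketch below is an editor-added illustration under assumptions, not nshtrainer API: the `compose_and_register` helper and the choice of SIGURG (the warning signal suggested in the LSF comment) are hypothetical, and it only runs on Unix.

import signal

def requeue_handler(signum, frame):
    # Placeholder for "save a checkpoint, then requeue the job".
    print(f"Received {signal.Signals(signum).name}; would checkpoint and requeue here.")

def compose_and_register(signum, handlers):
    # Call every handler in order, skipping SIG_DFL/SIG_IGN-style sentinels.
    def composed(sig, frame):
        for handler in handlers:
            if callable(handler):
                handler(sig, frame)
    signal.signal(signum, composed)

# Chain our handler with whatever was installed before, mirroring the connector.
existing = signal.getsignal(signal.SIGURG)
compose_and_register(signal.SIGURG, [requeue_handler, existing])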
nshtrainer/trainer/trainer.py
ADDED

@@ -0,0 +1,340 @@

import contextlib
import logging
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast

import torch
from lightning.fabric.plugins.environments.lsf import LSFEnvironment
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
from lightning.pytorch import LightningModule
from lightning.pytorch import Trainer as LightningTrainer
from lightning.pytorch.profilers import Profiler
from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
from typing_extensions import Unpack, assert_never, override

from ..actsave import ActSave
from ..callbacks.base import resolve_all_callbacks
from ..model.config import (
    AcceleratorConfigProtocol,
    BaseConfig,
    BaseProfilerConfig,
    LightningTrainerKwargs,
    StrategyConfigProtocol,
)
from .signal_connector import _SignalConnector

log = logging.getLogger(__name__)


def _is_bf16_supported_no_emulation():
    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
    version = cast(Any, torch.version)

    # Check for ROCm, if true return true, no ROCM_VERSION check required,
    # since it is supported on AMD GPU archs.
    if version.hip:
        return True

    device = torch.cuda.current_device()

    # Check for CUDA version and device compute capability.
    # This is a fast way to check for it.
    cuda_version = version.cuda
    if (
        cuda_version is not None
        and int(cuda_version.split(".")[0]) >= 11
        and torch.cuda.get_device_properties(device).major >= 8
    ):
        return True

    return False


class Trainer(LightningTrainer):
    @classmethod
    @contextlib.contextmanager
    def context(cls, config: BaseConfig):
        if (precision := config.trainer.set_float32_matmul_precision) is not None:
            torch.set_float32_matmul_precision(precision)

        yield

    @classmethod
    def _update_kwargs(
        cls,
        config: BaseConfig,
        kwargs_ctor: LightningTrainerKwargs,
    ):
        kwargs: LightningTrainerKwargs = {
            "deterministic": config.trainer.reproducibility.deterministic,
            "fast_dev_run": config.trainer.fast_dev_run,
            "max_epochs": config.trainer.max_epochs,
            "min_epochs": config.trainer.min_epochs,
            "max_steps": config.trainer.max_steps,
            "min_steps": config.trainer.min_steps,
            "max_time": config.trainer.max_time,
            "limit_train_batches": config.trainer.limit_train_batches,
            "limit_val_batches": config.trainer.limit_val_batches,
            "limit_test_batches": config.trainer.limit_test_batches,
            "limit_predict_batches": config.trainer.limit_predict_batches,
            "overfit_batches": config.trainer.overfit_batches,
            "val_check_interval": config.trainer.val_check_interval,
            "num_sanity_val_steps": config.trainer.num_sanity_val_steps,
            "log_every_n_steps": config.trainer.log_every_n_steps,
            "inference_mode": config.trainer.inference_mode,
            "callbacks": [],
            "plugins": [],
            "logger": [],
            # Moved to `lightning_kwargs`:
            # "enable_checkpointing": config.trainer.enable_checkpointing,
            # "accelerator": config.trainer.accelerator,
            # "strategy": config.trainer.strategy,
            # "num_nodes": config.trainer.num_nodes,
            # "precision": config.trainer.precision,
            # "logger": config.trainer.logging.enabled,
            # "log_every_n_steps": config.trainer.log_every_n_steps,
            # "enable_progress_bar": config.trainer.enable_progress_bar,
            # "enable_model_summary": config.trainer.enable_model_summary,
            # "accumulate_grad_batches": config.trainer.accumulate_grad_batches,
            # "benchmark": config.trainer.benchmark,
            # "use_distributed_sampler": config.trainer.use_distributed_sampler,
            # "detect_anomaly": config.trainer.detect_anomaly,
            # "barebones": config.trainer.barebones,
            # "plugins": config.trainer.plugins,
            # "sync_batchnorm": config.trainer.sync_batchnorm,
            # "reload_dataloaders_every_n_epochs": config.trainer.reload_dataloaders_every_n_epochs,
        }

        def _update_key(key: str, new_value: Any):
            # First, check to see if the key is already in the kwargs.
            if key not in kwargs:
                kwargs[key] = new_value
                return

            # If the key is already in the kwargs, then we check the type:
            # - If the type is a sequence, then we extend the sequence.
            # - Otherwise, we just update the value but warn the user.

            match existing_value := kwargs[key]:
                case Sequence() as existing_value:
                    # Make sure value is a sequence too
                    if not isinstance(new_value, Sequence):
                        new_value = [new_value]
                    kwargs[key] = [*existing_value, *new_value]
                case _:
                    log.warning(
                        f"Trainer.__init__: Overwriting existing value {existing_value=} with {new_value=} for key {key=}."
                    )
                    kwargs[key] = new_value

        def _update_kwargs(**update: Unpack[LightningTrainerKwargs]):
            for key, value in update.items():
                _update_key(key, value)

        # Set `default_root_dir` if `auto_set_default_root_dir` is enabled.
        if config.trainer.auto_set_default_root_dir:
            if kwargs.get("default_root_dir"):
                raise ValueError(
                    "You have set `config.trainer.default_root_dir`. "
                    "But we are trying to set it automatically. "
                    "Please use `config.directory.base` rather than `config.trainer.default_root_dir`. "
                    "If you want to set it manually, please set `config.trainer.auto_set_default_root_dir=False`."
                )

            _update_kwargs(
                default_root_dir=config.directory.resolve_run_root_directory(config.id)
            )

        if (devices_input := config.trainer.devices) is not None:
            match devices_input:
                case "all":
                    devices = -1
                case "auto":
                    devices = "auto"
                case Sequence():
                    devices = list(devices_input)
                case _:
                    raise ValueError(f"Invalid value for devices={devices_input}.")

            _update_kwargs(devices=devices)

        if (
            use_distributed_sampler := config.trainer.use_distributed_sampler
        ) is not None:
            _update_kwargs(use_distributed_sampler=use_distributed_sampler)

        if (accelerator := config.trainer.accelerator) is not None:
            if isinstance(accelerator, AcceleratorConfigProtocol):
                accelerator = accelerator.construct_accelerator()
            _update_kwargs(accelerator=accelerator)

        if (strategy := config.trainer.strategy) is not None:
            if isinstance(strategy, StrategyConfigProtocol):
                strategy = strategy.construct_strategy()
            _update_kwargs(strategy=strategy)

        if (precision := config.trainer.precision) is not None:
            resolved_precision: _PRECISION_INPUT
            match precision:
                case "64-true" | "32-true" | "bf16-mixed":
                    resolved_precision = precision
                case "fp16-mixed":
                    resolved_precision = "16-mixed"
                case "16-mixed-auto":
                    try:
                        resolved_precision = (
                            "bf16-mixed"
                            if _is_bf16_supported_no_emulation()
                            else "16-mixed"
                        )
                    except BaseException:
                        resolved_precision = "16-mixed"
                        log.warning(
                            "Failed to detect bfloat16 support. Falling back to 16-mixed."
                        )

                    log.critical(
                        f"Auto-resolving {precision=} to {resolved_precision=}."
                    )
                case _:
                    assert_never(precision)

            _update_kwargs(precision=resolved_precision)

        if (detect_anomaly := config.trainer.detect_anomaly) is not None:
            _update_kwargs(detect_anomaly=detect_anomaly)

        if (
            grad_clip_config := config.trainer.optimizer.gradient_clipping
        ) is not None and grad_clip_config.enabled:
            # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
            # kwargs["gradient_clip_val"] = grad_clip_config.value
            _update_kwargs(
                gradient_clip_algorithm=grad_clip_config.algorithm,
                gradient_clip_val=grad_clip_config.value,
            )

        if profiler := config.trainer.profiler:
            # If the profiler is a ProfilerConfig instance, then we instantiate it.
            if isinstance(profiler, BaseProfilerConfig):
                profiler = profiler.construct_profiler(config)
                # Make sure that the profiler is an instance of `Profiler`.
                if not isinstance(profiler, Profiler):
                    raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")

            # Otherwise, if the profiler is a string (e.g., "simple", "advanced", "pytorch"),
            # then we just pass it through.
            # kwargs["profiler"] = profiler
            _update_kwargs(profiler=profiler)

        if callbacks := resolve_all_callbacks(config):
            _update_kwargs(callbacks=callbacks)

        if plugin_configs := config.trainer.plugins:
            _update_kwargs(
                plugins=[
                    plugin_config.construct_plugin() for plugin_config in plugin_configs
                ]
            )

        if not config.trainer.logging.enabled:
            log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
            kwargs["logger"] = False
        else:
            _update_kwargs(logger=config.trainer.logging.construct_loggers(config))

        if config.trainer.auto_determine_num_nodes:
            # When num_nodes is auto, we need to detect the number of nodes.
            if SLURMEnvironment.detect():
                if (num_nodes := os.environ.get("SLURM_NNODES")) is not None:
                    num_nodes = int(num_nodes)
                    log.critical(f"SLURM detected with {num_nodes=}.")
                    _update_kwargs(num_nodes=num_nodes)
                else:
                    log.critical(
                        "SLURM detected, but SLURM_NNODES not found. "
                        "We'll continue without setting num_nodes, but this may cause issues."
                    )

            elif LSFEnvironment.detect():
                num_nodes = LSFEnvironment().world_size()
                log.critical(f"LSF detected with {num_nodes=}.")
                _update_kwargs(num_nodes=num_nodes)
            else:
                log.info(
                    "config.trainer.auto_determine_num_nodes ignored because no SLURM or LSF detected."
                )

        # Update the kwargs with the additional trainer kwargs
        _update_kwargs(**cast(Any, config.trainer.additional_lightning_kwargs))
        _update_kwargs(**config.trainer.lightning_kwargs)
        _update_kwargs(**kwargs_ctor)

        return kwargs

    @override
    def __init__(
        self,
        config: BaseConfig,
        /,
        **kwargs: Unpack[LightningTrainerKwargs],
    ):
        self._ll_config = config
        kwargs = self._update_kwargs(config, kwargs)
        log.critical(f"LightningTrainer.__init__ with {kwargs=}.")

        super().__init__(**kwargs)

        # Replace the signal connector with our own.
        self._signal_connector = _SignalConnector(self)

        # Print out the log dir, so that we can easily find it in the logs.
        if log_dir := self.log_dir:
            log_dir = str(Path(log_dir).resolve())
            log.critical(f"LightningTrainer log directory: {self.log_dir}.")

        # Checkpoint loading
        if (
            ckpt_loading := self._ll_config.trainer.checkpoint_loading
        ) and ckpt_loading.path:
            self.ckpt_path = ckpt_loading.path

    @contextlib.contextmanager
    def _actsave_context(self, model: LightningModule):
        hparams = cast(BaseConfig, model.hparams)
        if not (actsave_config := hparams.trainer.actsave):
            yield
            return

        # Enter actsave context
        with ActSave.enabled(actsave_config.resolve_save_dir(hparams)):
            yield

    @override
    def _run(
        self, model: LightningModule, ckpt_path: str | Path | None = None
    ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None:
        """
        Two things done here:
        1. Lightning doesn't support gradient clipping with manual optimization.
           We patch the `Trainer._run` method to throw if gradient clipping is enabled
           and `model.automatic_optimization` is False.

        2. We actually set up actsave here.
        """

        if not model.automatic_optimization and (
            self.gradient_clip_val is not None
            or self.gradient_clip_algorithm is not None
        ):
            raise ValueError(
                "Automatic gradient clipping is not supported with manual optimization. "
                f"Please set {model.__class__.__name__}.automatic_optimization to True "
                "or disable automatic gradient clipping. "
            )

        with self._actsave_context(model):
            return super()._run(model, ckpt_path)
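For orientation, here is a hypothetical usage sketch of this config-driven `Trainer` subclass. It is not taken from the package: the import path, the `run` helper, the `config`/`MyModule` objects, and the `enable_progress_bar` keyword are all assumptions for illustration; only the positional-only `config` argument and the `Trainer.context` context manager come from the code above.

from nshtrainer.trainer.trainer import Trainer

def run(config, module):
    # `Trainer.context` applies config-level global settings (e.g. the
    # float32 matmul precision) for the duration of the run.
    with Trainer.context(config):
        # The config is a positional-only first argument; any extra keyword
        # arguments are merged into the resolved Lightning Trainer kwargs.
        trainer = Trainer(config, enable_progress_bar=False)
        trainer.fit(module)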
nshtrainer/typecheck.py
ADDED

@@ -0,0 +1,144 @@

import os
from collections.abc import Sequence
from logging import getLogger
from typing import Any

import numpy as np
import torch
from jaxtyping import BFloat16 as BFloat16
from jaxtyping import Bool as Bool
from jaxtyping import Complex as Complex
from jaxtyping import Complex64 as Complex64
from jaxtyping import Complex128 as Complex128
from jaxtyping import Float as Float
from jaxtyping import Float16 as Float16
from jaxtyping import Float32 as Float32
from jaxtyping import Float64 as Float64
from jaxtyping import Inexact as Inexact
from jaxtyping import Int as Int
from jaxtyping import Int4 as Int4
from jaxtyping import Int8 as Int8
from jaxtyping import Int16 as Int16
from jaxtyping import Int32 as Int32
from jaxtyping import Int64 as Int64
from jaxtyping import Integer as Integer
from jaxtyping import Key as Key
from jaxtyping import Num as Num
from jaxtyping import Real as Real
from jaxtyping import Shaped as Shaped
from jaxtyping import UInt as UInt
from jaxtyping import UInt4 as UInt4
from jaxtyping import UInt8 as UInt8
from jaxtyping import UInt16 as UInt16
from jaxtyping import UInt32 as UInt32
from jaxtyping import UInt64 as UInt64
from jaxtyping._storage import get_shape_memo, shape_str
from torch import Tensor as Tensor
from torch.nn.parameter import Parameter as Parameter
from typing_extensions import TypeVar

log = getLogger(__name__)

DISABLE_ENV_KEY = "LL_DISABLE_TYPECHECKING"


def typecheck_modules(modules: Sequence[str]):
    """
    Typecheck the given modules using `jaxtyping`.

    Args:
        modules: Modules to typecheck.
    """
    # If `DISABLE_ENV_KEY` is set and the environment variable is set, skip
    # typechecking.
    if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
        log.critical(
            f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
        )
        return

    # Install the jaxtyping import hook for this module.
    from jaxtyping import install_import_hook

    install_import_hook(modules, "beartype.beartype")

    log.critical(f"Type checking the following modules: {modules}")


def typecheck_this_module(additional_modules: Sequence[str] = ()):
    """
    Typecheck the calling module and any additional modules using `jaxtyping`.

    Args:
        additional_modules: Additional modules to typecheck.
    """
    # Get the calling module's name.
    # Here, we can just use beartype's internal implementation behind
    # `beartype_this_package`.
    from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name

    # Get the calling module's name.
    assert get_frame is not None, "get_frame is None"
    frame = get_frame(1)
    assert frame is not None, "frame is None"
    calling_module_name = get_frame_package_name(frame)

    # Typecheck the calling module + any additional modules.
    typecheck_modules((calling_module_name, *additional_modules))


def _make_error_str(input: Any, t: Any) -> str:
    error_components: list[str] = []
    error_components.append("Type checking error:")
    if hasattr(t, "__instancecheck_str__"):
        error_components.append(t.__instancecheck_str__(input))
    if torch.is_tensor(input):
        try:
            from lovely_tensors import lovely

            error_components.append(repr(lovely(input)))
        except BaseException:
            error_components.append(repr(input.shape))
    error_components.append(shape_str(get_shape_memo()))

    return "\n".join(error_components)


T = TypeVar("T", torch.Tensor, np.ndarray, infer_variance=True)

"""
Patch to jaxtyping:

In `jaxtyping._import_hook`, we add:
    def _has_isinstance_or_tassert(func_def):
        for node in ast.walk(func_def):
            if isinstance(node, ast.Call):
                if isinstance(node.func, ast.Name) and node.func.id == "isinstance":
                    return True
                elif isinstance(node.func, ast.Name) and node.func.id == "tassert":
                    return True
        return False

and we check this when adding the decorators.
"""


def tassert(t: Any, input: T | tuple[T, ...]):
    """
    Typecheck the input against the given type.

    Args:
        t: Type to check against.
        input: Input to check.
    """

    # Ignore typechecking if the environment variable is set.
    if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
        return

    if isinstance(input, tuple):
        for i in input:
            assert isinstance(i, t), _make_error_str(i, t)
        return
    else:
        assert isinstance(input, t), _make_error_str(input, t)