nshtrainer 0.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/trainer/signal_connector.py
@@ -0,0 +1,208 @@
+import logging
+import os
+import re
+import signal
+import subprocess
+import threading
+from collections import defaultdict
+from collections.abc import Callable
+from pathlib import Path
+from types import FrameType
+from typing import Any, TypeAlias
+
+import torch.utils.data
+from lightning.fabric.plugins.environments.lsf import LSFEnvironment
+from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
+from lightning.pytorch.trainer.connectors.signal_connector import _HandlersCompose
+from lightning.pytorch.trainer.connectors.signal_connector import (
+    _SignalConnector as _LightningSignalConnector,
+)
+from typing_extensions import override
+
+log = logging.getLogger(__name__)
+
+_SIGNUM = int | signal.Signals
+_HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
+
+
+class _SignalConnector(_LightningSignalConnector):
+    def _auto_requeue_signals(self) -> list[signal.Signals]:
+        from ..model.base import BaseConfig
+
+        if not isinstance(config := self.trainer.lightning_module.hparams, BaseConfig):
+            return []
+
+        signals = config.runner.submit._resolved_auto_requeue_signals()
+        signals_set = set(signals)
+        valid_signals: set[signal.Signals] = signal.valid_signals()
+        assert signals_set.issubset(
+            valid_signals
+        ), f"Invalid signal(s) found: {signals_set - valid_signals}"
+        return signals
+
+    def _compose_and_register(
+        self,
+        signum: _SIGNUM,
+        handlers: list[_HANDLER],
+        replace_existing: bool = False,
+    ):
+        if self._is_on_windows():
+            log.info(f"Signal {signum} has no handlers or is not supported on Windows.")
+            return
+
+        if self._has_already_handler(signum):
+            if not replace_existing:
+                log.info(
+                    f"Signal {signum} already has a handler. Adding ours to the existing one."
+                )
+                handlers.append(signal.getsignal(signum))
+            else:
+                log.info(f"Replacing existing handler for signal {signum} with ours.")
+
+        self._register_signal(signum, _HandlersCompose(handlers))
+        log.info(f"Registered {len(handlers)} handlers for signal {signum}.")
+
+    @override
+    def register_signal_handlers(self) -> None:
+        if not (auto_requeue_signals := self._auto_requeue_signals()):
+            log.info(
+                "No auto-requeue signals found. Reverting to default Lightning behavior."
+            )
+            return super().register_signal_handlers()
+
+        self.received_sigterm = False
+        self._original_handlers = self._get_current_signal_handlers()
+
+        signals = defaultdict[signal.Signals, list[_HANDLER]](lambda: [])
+        signals[signal.SIGTERM].append(self._sigterm_notifier_fn)
+
+        environment = self.trainer._accelerator_connector.cluster_environment
+        if isinstance(environment, SLURMEnvironment):
+            log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
+            for signal_handler in auto_requeue_signals:
+                signals[signal_handler].append(self._slurm_sigusr_handler_fn)
+
+        if isinstance(environment, LSFEnvironment):
+            # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
+            # We can also ask the job manager to send a warning signal some amount of time
+            # before the allocation expires by passing -wa 'signal' and -wt '[hour:]minute'
+            # to bsub. We can then have bash create a dump_and_stop file when it receives
+            # the signal, which will tell Castro to output a checkpoint file and exit
+            # cleanly after it finishes the current timestep. An important detail that I
+            # couldn't find documented anywhere is that the job manager sends the signal
+            # to all the processes in the job, not just the submission script, and we have
+            # to use a signal that is ignored by default so Castro doesn't immediately
+            # crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals
+            # that fit this requirement and of these, SIGURG is the least likely to be
+            # triggered by other events.
+
+            log.info("LSF auto-requeueing enabled. Setting signal handlers.")
+            for signal_handler in auto_requeue_signals:
+                signals[signal_handler].append(self._lsf_sigusr_handler_fn)
+
+        for signum, handlers in signals.items():
+            if not handlers:
+                continue
+
+            self._compose_and_register(signum, handlers)
+
+    def _should_ignore_signal_handler(self) -> str | None:
+        if threading.current_thread() is not threading.main_thread():
+            return "Not in main thread"
+
+        if torch.utils.data.get_worker_info() is not None:
+            return "Inside DataLoader worker process"
+
+        return None
+
+    @override
+    def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
+        if ignore_reason := self._should_ignore_signal_handler():
+            log.info(
+                f"Skipping SLURM auto-requeue signal handler. Reason: {ignore_reason}"
+            )
+            return
+
+        log.critical(f"Handling SLURM auto-requeue signal: {signum}")
+
+        # Finalize the loggers to make sure we get all the metrics
+        for logger in self.trainer.loggers:
+            logger.finalize("finished")
+
+        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(
+            self.trainer.default_root_dir
+        )
+        self.trainer.save_checkpoint(hpc_save_path)
+
+        if not self.trainer.is_global_zero:
+            return
+
+        # Find job id
+        array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
+        if array_job_id is not None:
+            array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
+            job_id = f"{array_job_id}_{array_task_id}"
+        else:
+            job_id = os.environ["SLURM_JOB_ID"]
+
+        assert re.match("[0-9_-]+", job_id)
+        cmd = ["scontrol", "requeue", job_id]
+
+        # Requeue job
+        log.info(f"Requeuing job {job_id}...")
+        try:
+            result = subprocess.call(cmd)
+        except FileNotFoundError:
+            # This can occur if a subprocess call to `scontrol` is run outside a shell context.
+            # Re-attempt the call (now with shell context). If any error is raised, propagate it
+            # to the user. When running a shell command, it should be passed as a single string.
+            result = subprocess.call(" ".join(cmd), shell=True)
+
+        # Print result text
+        if result == 0:
+            log.info(f"Requeued SLURM job: {job_id}")
+        else:
+            log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}")
+
+    def _lsf_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
+        if ignore_reason := self._should_ignore_signal_handler():
+            log.info(
+                f"Skipping LSF auto-requeue signal handler. Reason: {ignore_reason}"
+            )
+            return
+
+        log.critical(f"Handling LSF auto-requeue signal: {signum}")
+
+        # Finalize the loggers to make sure we get all the metrics
+        for logger in self.trainer.loggers:
+            logger.finalize("finished")
+
+        # Save checkpoint
+        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(
+            self.trainer.default_root_dir
+        )
+        self.trainer.save_checkpoint(hpc_save_path)
+        log.info(f"Saved checkpoint to {hpc_save_path}")
+
+        if not self.trainer.is_global_zero:
+            return
+
+        # Find job id
+        if (job_id := os.getenv("LSB_JOBID")) is None:
+            log.warning(
+                "LSB_JOBID environment variable not found. Unable to requeue job."
+            )
+            return
+
+        assert re.match("[0-9_-]+", job_id)
+
+        exe = "brequeue"
+        if (bin_dir := os.getenv("LSF_BINDIR")) is not None:
+            exe = str((Path(bin_dir) / exe).resolve().absolute())
+
+        log.info(f"Using LSF requeue executable: {exe}")
+        cmd = [exe, job_id]
+
+        # Requeue job
+        log.info(f"Requeuing job {job_id}...")
+        try:
+            result = subprocess.call(cmd)
+        except FileNotFoundError:
+            # Retry with shell context if the subprocess call fails
+            result = subprocess.call(" ".join(cmd), shell=True)
+
+        # Print result text
+        if result == 0:
+            log.info(f"Requeued LSF job: {job_id}")
+        else:
+            log.warning(f"Requeuing LSF job {job_id} failed with error code {result}")
nshtrainer/trainer/trainer.py
@@ -0,0 +1,340 @@
+import contextlib
+import logging
+import os
+from collections.abc import Sequence
+from pathlib import Path
+from typing import Any, cast
+
+import torch
+from lightning.fabric.plugins.environments.lsf import LSFEnvironment
+from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
+from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
+from lightning.pytorch import LightningModule
+from lightning.pytorch import Trainer as LightningTrainer
+from lightning.pytorch.profilers import Profiler
+from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
+from typing_extensions import Unpack, assert_never, override
+
+from ..actsave import ActSave
+from ..callbacks.base import resolve_all_callbacks
+from ..model.config import (
+    AcceleratorConfigProtocol,
+    BaseConfig,
+    BaseProfilerConfig,
+    LightningTrainerKwargs,
+    StrategyConfigProtocol,
+)
+from .signal_connector import _SignalConnector
+
+log = logging.getLogger(__name__)
+
+
+def _is_bf16_supported_no_emulation():
+    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
+    version = cast(Any, torch.version)
+
+    # Check for ROCm. If true, return True; no ROCM_VERSION check is required,
+    # since bfloat16 is supported on AMD GPU archs.
+    if version.hip:
+        return True
+
+    device = torch.cuda.current_device()
+
+    # Check the CUDA version and device compute capability.
+    # This is a fast way to check for it.
+    cuda_version = version.cuda
+    if (
+        cuda_version is not None
+        and int(cuda_version.split(".")[0]) >= 11
+        and torch.cuda.get_device_properties(device).major >= 8
+    ):
+        return True
+
+    return False
+
+
+class Trainer(LightningTrainer):
+    @classmethod
+    @contextlib.contextmanager
+    def context(cls, config: BaseConfig):
+        if (precision := config.trainer.set_float32_matmul_precision) is not None:
+            torch.set_float32_matmul_precision(precision)
+
+        yield
+
+    @classmethod
+    def _update_kwargs(
+        cls,
+        config: BaseConfig,
+        kwargs_ctor: LightningTrainerKwargs,
+    ):
+        kwargs: LightningTrainerKwargs = {
+            "deterministic": config.trainer.reproducibility.deterministic,
+            "fast_dev_run": config.trainer.fast_dev_run,
+            "max_epochs": config.trainer.max_epochs,
+            "min_epochs": config.trainer.min_epochs,
+            "max_steps": config.trainer.max_steps,
+            "min_steps": config.trainer.min_steps,
+            "max_time": config.trainer.max_time,
+            "limit_train_batches": config.trainer.limit_train_batches,
+            "limit_val_batches": config.trainer.limit_val_batches,
+            "limit_test_batches": config.trainer.limit_test_batches,
+            "limit_predict_batches": config.trainer.limit_predict_batches,
+            "overfit_batches": config.trainer.overfit_batches,
+            "val_check_interval": config.trainer.val_check_interval,
+            "num_sanity_val_steps": config.trainer.num_sanity_val_steps,
+            "log_every_n_steps": config.trainer.log_every_n_steps,
+            "inference_mode": config.trainer.inference_mode,
+            "callbacks": [],
+            "plugins": [],
+            "logger": [],
+            # Moved to `lightning_kwargs`:
+            # "enable_checkpointing": config.trainer.enable_checkpointing,
+            # "accelerator": config.trainer.accelerator,
+            # "strategy": config.trainer.strategy,
+            # "num_nodes": config.trainer.num_nodes,
+            # "precision": config.trainer.precision,
+            # "logger": config.trainer.logging.enabled,
+            # "log_every_n_steps": config.trainer.log_every_n_steps,
+            # "enable_progress_bar": config.trainer.enable_progress_bar,
+            # "enable_model_summary": config.trainer.enable_model_summary,
+            # "accumulate_grad_batches": config.trainer.accumulate_grad_batches,
+            # "benchmark": config.trainer.benchmark,
+            # "use_distributed_sampler": config.trainer.use_distributed_sampler,
+            # "detect_anomaly": config.trainer.detect_anomaly,
+            # "barebones": config.trainer.barebones,
+            # "plugins": config.trainer.plugins,
+            # "sync_batchnorm": config.trainer.sync_batchnorm,
+            # "reload_dataloaders_every_n_epochs": config.trainer.reload_dataloaders_every_n_epochs,
+        }
+
+        def _update_key(key: str, new_value: Any):
+            # First, check to see if the key is already in the kwargs.
+            if key not in kwargs:
+                kwargs[key] = new_value
+                return
+
+            # If the key is already in the kwargs, then we check the type:
+            # - If the existing value is a sequence, then we extend the sequence.
+            # - Otherwise, we just update the value but warn the user.
+
+            match existing_value := kwargs[key]:
+                case Sequence() as existing_value:
+                    # Make sure the new value is a sequence too
+                    if not isinstance(new_value, Sequence):
+                        new_value = [new_value]
+                    kwargs[key] = [*existing_value, *new_value]
+                case _:
+                    log.warning(
+                        f"Trainer.__init__: Overwriting existing value {existing_value=} with {new_value=} for key {key=}."
+                    )
+                    kwargs[key] = new_value
+
+        def _update_kwargs(**update: Unpack[LightningTrainerKwargs]):
+            for key, value in update.items():
+                _update_key(key, value)
+
+        # Set `default_root_dir` if `auto_set_default_root_dir` is enabled.
+        if config.trainer.auto_set_default_root_dir:
+            if kwargs.get("default_root_dir"):
+                raise ValueError(
+                    "You have set `config.trainer.default_root_dir`, "
+                    "but we are trying to set it automatically. "
+                    "Please use `config.directory.base` rather than `config.trainer.default_root_dir`. "
+                    "If you want to set it manually, please set `config.trainer.auto_set_default_root_dir=False`."
+                )

+            _update_kwargs(
+                default_root_dir=config.directory.resolve_run_root_directory(config.id)
+            )
+
+        if (devices_input := config.trainer.devices) is not None:
+            match devices_input:
+                case "all":
+                    devices = -1
+                case "auto":
+                    devices = "auto"
+                case Sequence():
+                    devices = list(devices_input)
+                case _:
+                    raise ValueError(f"Invalid value for devices={devices_input}.")
+
+            _update_kwargs(devices=devices)
+
+        if (
+            use_distributed_sampler := config.trainer.use_distributed_sampler
+        ) is not None:
+            _update_kwargs(use_distributed_sampler=use_distributed_sampler)
+
+        if (accelerator := config.trainer.accelerator) is not None:
+            if isinstance(accelerator, AcceleratorConfigProtocol):
+                accelerator = accelerator.construct_accelerator()
+            _update_kwargs(accelerator=accelerator)
+
+        if (strategy := config.trainer.strategy) is not None:
+            if isinstance(strategy, StrategyConfigProtocol):
+                strategy = strategy.construct_strategy()
+            _update_kwargs(strategy=strategy)
+
+        if (precision := config.trainer.precision) is not None:
+            resolved_precision: _PRECISION_INPUT
+            match precision:
+                case "64-true" | "32-true" | "bf16-mixed":
+                    resolved_precision = precision
+                case "fp16-mixed":
+                    resolved_precision = "16-mixed"
+                case "16-mixed-auto":
+                    try:
+                        resolved_precision = (
+                            "bf16-mixed"
+                            if _is_bf16_supported_no_emulation()
+                            else "16-mixed"
+                        )
+                    except BaseException:
+                        resolved_precision = "16-mixed"
+                        log.warning(
+                            "Failed to detect bfloat16 support. Falling back to 16-mixed."
+                        )
+
+                    log.critical(
+                        f"Auto-resolving {precision=} to {resolved_precision=}."
+                    )
+                case _:
+                    assert_never(precision)
+
+            _update_kwargs(precision=resolved_precision)
+
+        if (detect_anomaly := config.trainer.detect_anomaly) is not None:
+            _update_kwargs(detect_anomaly=detect_anomaly)
+
+        if (
+            grad_clip_config := config.trainer.optimizer.gradient_clipping
+        ) is not None and grad_clip_config.enabled:
+            # kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm
+            # kwargs["gradient_clip_val"] = grad_clip_config.value
+            _update_kwargs(
+                gradient_clip_algorithm=grad_clip_config.algorithm,
+                gradient_clip_val=grad_clip_config.value,
+            )
+
+        if profiler := config.trainer.profiler:
+            # If the profiler is a `BaseProfilerConfig` instance, then we instantiate it.
+            if isinstance(profiler, BaseProfilerConfig):
+                profiler = profiler.construct_profiler(config)
+                # Make sure that the constructed profiler is an instance of `Profiler`.
+                if not isinstance(profiler, Profiler):
+                    raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.")
+
+            # Otherwise, if the profiler is a string (e.g., "simple", "advanced", "pytorch"),
+            # then we just pass it through.
+            # kwargs["profiler"] = profiler
+            _update_kwargs(profiler=profiler)
+
+        if callbacks := resolve_all_callbacks(config):
+            _update_kwargs(callbacks=callbacks)
+
+        if plugin_configs := config.trainer.plugins:
+            _update_kwargs(
+                plugins=[
+                    plugin_config.construct_plugin() for plugin_config in plugin_configs
+                ]
+            )
+
+        if not config.trainer.logging.enabled:
+            log.critical(f"Disabling logger because {config.trainer.logging.enabled=}.")
+            kwargs["logger"] = False
+        else:
+            _update_kwargs(logger=config.trainer.logging.construct_loggers(config))
+
+        if config.trainer.auto_determine_num_nodes:
+            # When num_nodes is auto, we need to detect the number of nodes.
+            if SLURMEnvironment.detect():
+                if (num_nodes := os.environ.get("SLURM_NNODES")) is not None:
+                    num_nodes = int(num_nodes)
+                    log.critical(f"SLURM detected with {num_nodes=}.")
+                    _update_kwargs(num_nodes=num_nodes)
+                else:
+                    log.critical(
+                        "SLURM detected, but SLURM_NNODES not found. "
+                        "We'll continue without setting num_nodes, but this may cause issues."
+                    )
+
+            elif LSFEnvironment.detect():
+                num_nodes = LSFEnvironment().world_size()
+                log.critical(f"LSF detected with {num_nodes=}.")
+                _update_kwargs(num_nodes=num_nodes)
+            else:
+                log.info(
+                    "config.trainer.auto_determine_num_nodes ignored because no SLURM or LSF detected."
+                )
+
+        # Update the kwargs with the additional trainer kwargs
+        _update_kwargs(**cast(Any, config.trainer.additional_lightning_kwargs))
+        _update_kwargs(**config.trainer.lightning_kwargs)
+        _update_kwargs(**kwargs_ctor)
+
+        return kwargs
+
+    @override
+    def __init__(
+        self,
+        config: BaseConfig,
+        /,
+        **kwargs: Unpack[LightningTrainerKwargs],
+    ):
+        self._ll_config = config
+        kwargs = self._update_kwargs(config, kwargs)
+        log.critical(f"LightningTrainer.__init__ with {kwargs=}.")
+
+        super().__init__(**kwargs)
+
+        # Replace the signal connector with our own.
+        self._signal_connector = _SignalConnector(self)
+
+        # Print out the log dir, so that we can easily find it in the logs.
+        if log_dir := self.log_dir:
+            log_dir = str(Path(log_dir).resolve())
+        log.critical(f"LightningTrainer log directory: {log_dir}.")
+
+        # Checkpoint loading
+        if (
+            ckpt_loading := self._ll_config.trainer.checkpoint_loading
+        ) and ckpt_loading.path:
+            self.ckpt_path = ckpt_loading.path
+
+    @contextlib.contextmanager
+    def _actsave_context(self, model: LightningModule):
+        hparams = cast(BaseConfig, model.hparams)
+        if not (actsave_config := hparams.trainer.actsave):
+            yield
+            return
+
+        # Enter the ActSave context
+        with ActSave.enabled(actsave_config.resolve_save_dir(hparams)):
+            yield
+
+    @override
+    def _run(
+        self, model: LightningModule, ckpt_path: str | Path | None = None
+    ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None:
+        """
+        Two things are done here:
+        1. Lightning doesn't support gradient clipping with manual optimization.
+           We patch the `Trainer._run` method to throw if gradient clipping is enabled
+           and `model.automatic_optimization` is False.
+
+        2. We set up the ActSave context here.
+        """
+
+        if not model.automatic_optimization and (
+            self.gradient_clip_val is not None
+            or self.gradient_clip_algorithm is not None
+        ):
+            raise ValueError(
+                "Automatic gradient clipping is not supported with manual optimization. "
+                f"Please set {model.__class__.__name__}.automatic_optimization to True "
+                "or disable automatic gradient clipping."
+            )
+
+        with self._actsave_context(model):
+            return super()._run(model, ckpt_path)
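
`Trainer._update_kwargs` merges several sources of trainer kwargs (the base config, `additional_lightning_kwargs`, `lightning_kwargs`, and the constructor kwargs) using the rule described in the `_update_key` comments: sequence-valued keys such as `callbacks`, `plugins`, and `logger` are extended, while scalar keys are overwritten with a warning. Below is a minimal standalone sketch of that merge rule; it is a simplified illustration, not the package's code, and the helper name `merge_kwargs` is made up.

# Hypothetical, simplified version of the merge rule used by Trainer._update_kwargs:
# sequences are extended, everything else is overwritten (with a warning).
import logging
from collections.abc import Sequence
from typing import Any

log = logging.getLogger(__name__)


def merge_kwargs(kwargs: dict[str, Any], **update: Any) -> dict[str, Any]:
    for key, new_value in update.items():
        if key not in kwargs:
            kwargs[key] = new_value
            continue
        existing = kwargs[key]
        if isinstance(existing, Sequence) and not isinstance(existing, str):
            # Extend list-like values instead of replacing them.
            if not isinstance(new_value, Sequence) or isinstance(new_value, str):
                new_value = [new_value]
            kwargs[key] = [*existing, *new_value]
        else:
            log.warning(f"Overwriting {key}={existing!r} with {new_value!r}.")
            kwargs[key] = new_value
    return kwargs


kwargs = {"callbacks": ["early_stopping"], "max_epochs": 10}
merge_kwargs(kwargs, callbacks=["ema"], max_epochs=100)
print(kwargs)  # {'callbacks': ['early_stopping', 'ema'], 'max_epochs': 100}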
nshtrainer/typecheck.py
@@ -0,0 +1,144 @@
+import os
+from collections.abc import Sequence
+from logging import getLogger
+from typing import Any
+
+import numpy as np
+import torch
+from jaxtyping import BFloat16 as BFloat16
+from jaxtyping import Bool as Bool
+from jaxtyping import Complex as Complex
+from jaxtyping import Complex64 as Complex64
+from jaxtyping import Complex128 as Complex128
+from jaxtyping import Float as Float
+from jaxtyping import Float16 as Float16
+from jaxtyping import Float32 as Float32
+from jaxtyping import Float64 as Float64
+from jaxtyping import Inexact as Inexact
+from jaxtyping import Int as Int
+from jaxtyping import Int4 as Int4
+from jaxtyping import Int8 as Int8
+from jaxtyping import Int16 as Int16
+from jaxtyping import Int32 as Int32
+from jaxtyping import Int64 as Int64
+from jaxtyping import Integer as Integer
+from jaxtyping import Key as Key
+from jaxtyping import Num as Num
+from jaxtyping import Real as Real
+from jaxtyping import Shaped as Shaped
+from jaxtyping import UInt as UInt
+from jaxtyping import UInt4 as UInt4
+from jaxtyping import UInt8 as UInt8
+from jaxtyping import UInt16 as UInt16
+from jaxtyping import UInt32 as UInt32
+from jaxtyping import UInt64 as UInt64
+from jaxtyping._storage import get_shape_memo, shape_str
+from torch import Tensor as Tensor
+from torch.nn.parameter import Parameter as Parameter
+from typing_extensions import TypeVar
+
+log = getLogger(__name__)
+
+DISABLE_ENV_KEY = "LL_DISABLE_TYPECHECKING"
+
+
+def typecheck_modules(modules: Sequence[str]):
+    """
+    Typecheck the given modules using `jaxtyping`.
+
+    Args:
+        modules: Modules to typecheck.
+    """
+    # If the `DISABLE_ENV_KEY` environment variable is set to a truthy value,
+    # skip typechecking.
+    if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
+        log.critical(
+            f"Type checking is disabled due to the environment variable {DISABLE_ENV_KEY}."
+        )
+        return
+
+    # Install the jaxtyping import hook for these modules.
+    from jaxtyping import install_import_hook
+
+    install_import_hook(modules, "beartype.beartype")
+
+    log.critical(f"Type checking the following modules: {modules}")
+
+
+def typecheck_this_module(additional_modules: Sequence[str] = ()):
+    """
+    Typecheck the calling module and any additional modules using `jaxtyping`.
+
+    Args:
+        additional_modules: Additional modules to typecheck.
+    """
+    # Here, we can just use beartype's internal implementation behind
+    # `beartype_this_package`.
+    from beartype._util.func.utilfuncframe import get_frame, get_frame_package_name
+
+    # Get the calling module's name.
+    assert get_frame is not None, "get_frame is None"
+    frame = get_frame(1)
+    assert frame is not None, "frame is None"
+    calling_module_name = get_frame_package_name(frame)
+
+    # Typecheck the calling module + any additional modules.
+    typecheck_modules((calling_module_name, *additional_modules))
+
+
+def _make_error_str(input: Any, t: Any) -> str:
+    error_components: list[str] = []
+    error_components.append("Type checking error:")
+    if hasattr(t, "__instancecheck_str__"):
+        error_components.append(t.__instancecheck_str__(input))
+    if torch.is_tensor(input):
+        try:
+            from lovely_tensors import lovely
+
+            error_components.append(repr(lovely(input)))
+        except BaseException:
+            error_components.append(repr(input.shape))
+    error_components.append(shape_str(get_shape_memo()))
+
+    return "\n".join(error_components)
+
+
+T = TypeVar("T", torch.Tensor, np.ndarray, infer_variance=True)
+
+"""
+Patch to jaxtyping:
+
+In `jaxtyping._import_hook`, we add:
+
+    def _has_isinstance_or_tassert(func_def):
+        for node in ast.walk(func_def):
+            if isinstance(node, ast.Call):
+                if isinstance(node.func, ast.Name) and node.func.id == "isinstance":
+                    return True
+                elif isinstance(node.func, ast.Name) and node.func.id == "tassert":
+                    return True
+        return False
+
+and we check this when adding the decorators.
+"""


+def tassert(t: Any, input: T | tuple[T, ...]):
+    """
+    Typecheck the input against the given type.
+
+    Args:
+        t: Type to check against.
+        input: Input to check.
+    """
+    # Ignore typechecking if the environment variable is set.
+    if DISABLE_ENV_KEY is not None and bool(int(os.environ.get(DISABLE_ENV_KEY, "0"))):
+        return
+
+    if isinstance(input, tuple):
+        for i in input:
+            assert isinstance(i, t), _make_error_str(i, t)
+        return
+    else:
+        assert isinstance(input, t), _make_error_str(input, t)
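
The typecheck module re-exports jaxtyping's shape annotations and exposes `tassert` for ad-hoc runtime shape/dtype checks, with the `LL_DISABLE_TYPECHECKING` environment variable turning everything into a no-op. A hypothetical usage sketch follows (not shipped with the package; the function `scale` and its shapes are invented for illustration).

# Hypothetical usage of nshtrainer.typecheck: annotate tensors with jaxtyping
# shape types and assert them at runtime with tassert.
import torch

from nshtrainer.typecheck import Float, tassert


def scale(x: Float[torch.Tensor, "batch dim"], factor: float):
    out = x * factor
    # Raises AssertionError (with a descriptive message) if the shape or dtype drifts.
    tassert(Float[torch.Tensor, "batch dim"], out)
    return out


scale(torch.randn(4, 8), 2.0)

# Setting LL_DISABLE_TYPECHECKING=1 in the environment turns these checks
# (and the import-hook-based checking) into no-ops.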