nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0

nshtrainer/callbacks/interval.py
@@ -0,0 +1,322 @@
+from collections.abc import Callable
+from typing import Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import override
+
+Split = Literal["train", "val", "test", "predict"]
+
+
+def _check_step(step: int, interval: int, skip_first: bool = False):
+    if step % interval != 0:
+        return False
+    if skip_first and step == 0:
+        return False
+    return True
+
+
+class StepIntervalCallback(Callback):
+    def __init__(
+        self,
+        function: Callable[[Trainer, LightningModule], None],
+        *,
+        interval: int,
+        skip_first: bool = False,
+        splits: list[Split] = ["train", "val", "test", "predict"],
+    ):
+        super().__init__()
+
+        self.function = function
+        self.interval = interval
+        self.skip_first = skip_first
+        self.splits = set(splits)
+
+    @override
+    def on_train_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if (
+            not _check_step(
+                trainer.global_step,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "train" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+    @override
+    def on_validation_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if (
+            not _check_step(
+                trainer.global_step,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "val" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+    @override
+    def on_test_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if (
+            not _check_step(
+                trainer.global_step,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "test" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+    @override
+    def on_predict_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if (
+            not _check_step(
+                trainer.global_step,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "predict" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+
+class EpochIntervalCallback(Callback):
+    def __init__(
+        self,
+        function: Callable[[Trainer, LightningModule], None],
+        *,
+        interval: int,
+        skip_first: bool = False,
+        splits: list[Split] = ["train", "val", "test", "predict"],
+    ):
+        super().__init__()
+
+        self.function = function
+        self.interval = interval
+        self.skip_first = skip_first
+        self.splits = set(splits)
+
+    @override
+    def on_train_epoch_start(self, trainer, pl_module):
+        if (
+            not _check_step(
+                trainer.current_epoch,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "train" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+    @override
+    def on_validation_epoch_start(self, trainer, pl_module):
+        if (
+            not _check_step(
+                trainer.current_epoch,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "val" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+    @override
+    def on_test_epoch_start(self, trainer, pl_module):
+        if (
+            not _check_step(
+                trainer.current_epoch,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "test" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+    @override
+    def on_predict_epoch_start(self, trainer, pl_module):
+        if (
+            not _check_step(
+                trainer.current_epoch,
+                self.interval,
+                skip_first=self.skip_first,
+            )
+            or "predict" not in self.splits
+        ):
+            return
+        self.function(trainer, pl_module)
+
+
+class IntervalCallback(Callback):
+    def __init__(
+        self,
+        function: Callable[[Trainer, LightningModule], None],
+        *,
+        step_interval: int | None = None,
+        epoch_interval: int | None = None,
+        skip_first: bool = False,
+        splits: list[Split] = ["train", "val", "test", "predict"],
+    ):
+        super().__init__()
+
+        self.callback = None
+
+        if step_interval is not None:
+            self.callback = StepIntervalCallback(
+                function,
+                interval=step_interval,
+                splits=splits,
+                skip_first=skip_first,
+            )
+        elif epoch_interval is not None:
+            self.callback = EpochIntervalCallback(
+                function,
+                interval=epoch_interval,
+                splits=splits,
+                skip_first=skip_first,
+            )
+        else:
+            raise ValueError("Either step_interval or epoch_interval must be specified")
+
+    @override
+    def on_train_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if not isinstance(self.callback, StepIntervalCallback):
+            return
+
+        self.callback.on_train_batch_start(
+            trainer,
+            pl_module,
+            batch,
+            batch_idx,
+            dataloader_idx=1,
+        )
+
+    @override
+    def on_validation_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if not isinstance(self.callback, StepIntervalCallback):
+            return
+
+        self.callback.on_validation_batch_start(
+            trainer,
+            pl_module,
+            batch,
+            batch_idx,
+            dataloader_idx=1,
+        )
+
+    @override
+    def on_test_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if not isinstance(self.callback, StepIntervalCallback):
+            return
+
+        self.callback.on_test_batch_start(
+            trainer,
+            pl_module,
+            batch,
+            batch_idx,
+            dataloader_idx=1,
+        )
+
+    @override
+    def on_predict_batch_start(
+        self,
+        trainer,
+        pl_module,
+        batch,
+        batch_idx,
+        dataloader_idx=1,
+    ):
+        if not isinstance(self.callback, StepIntervalCallback):
+            return
+
+        self.callback.on_predict_batch_start(
+            trainer,
+            pl_module,
+            batch,
+            batch_idx,
+            dataloader_idx=1,
+        )
+
+    @override
+    def on_train_epoch_start(self, trainer, pl_module):
+        if not isinstance(self.callback, EpochIntervalCallback):
+            return
+
+        self.callback.on_train_epoch_start(trainer, pl_module)
+
+    @override
+    def on_validation_epoch_start(self, trainer, pl_module):
+        if not isinstance(self.callback, EpochIntervalCallback):
+            return
+
+        self.callback.on_validation_epoch_start(trainer, pl_module)
+
+    @override
+    def on_test_epoch_start(self, trainer, pl_module):
+        if not isinstance(self.callback, EpochIntervalCallback):
+            return
+
+        self.callback.on_test_epoch_start(trainer, pl_module)
+
+    @override
+    def on_predict_epoch_start(self, trainer, pl_module):
+        if not isinstance(self.callback, EpochIntervalCallback):
+            return
+
+        self.callback.on_predict_epoch_start(trainer, pl_module)
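
For orientation: IntervalCallback above dispatches to StepIntervalCallback or EpochIntervalCallback depending on which interval is given, and both gate an arbitrary function(trainer, pl_module) by split and interval. A minimal usage sketch, not part of the package diff; the import path is inferred from the file listing above, and the hook body, model, and datamodule are hypothetical placeholders.

# Usage sketch (not part of the diff).
from lightning.pytorch import Trainer

from nshtrainer.callbacks.interval import IntervalCallback


def log_samples(trainer, pl_module):
    # Arbitrary periodic work, run on train/val batches every 100 global steps
    # (step 0 is skipped because skip_first=True below).
    print(f"step={trainer.global_step}, epoch={trainer.current_epoch}")


callback = IntervalCallback(
    log_samples,
    step_interval=100,
    skip_first=True,
    splits=["train", "val"],
)
trainer = Trainer(callbacks=[callback])
# trainer.fit(model, datamodule)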

nshtrainer/callbacks/latest_epoch_checkpoint.py
@@ -0,0 +1,45 @@
+import logging
+from pathlib import Path
+
+from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+log = logging.getLogger(__name__)
+
+
+class LatestEpochCheckpoint(Checkpoint):
+    DEFAULT_FILENAME = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+
+    def __init__(
+        self,
+        dirpath: _PATH,
+        filename: str | None = None,
+        save_weights_only: bool = False,
+    ):
+        super().__init__()
+
+        self._dirpath = Path(dirpath)
+        self._filename = filename or self.DEFAULT_FILENAME
+        self._save_weights_only = save_weights_only
+
+        # Also, we hold a reference to the last checkpoint path
+        # to be able to remove it when a new checkpoint is saved.
+        self._last_ckpt_path: Path | None = None
+
+    def _ckpt_path(self, trainer: Trainer):
+        return self._dirpath / self._filename.format(
+            epoch=trainer.current_epoch, step=trainer.global_step
+        )
+
+    @override
+    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+        # Remove the last checkpoint if it exists
+        if self._last_ckpt_path is not None:
+            trainer.strategy.remove_checkpoint(self._last_ckpt_path)
+
+        # Save the new checkpoint
+        filepath = self._ckpt_path(trainer)
+        trainer.save_checkpoint(filepath, self._save_weights_only)
+        self._last_ckpt_path = filepath
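
LatestEpochCheckpoint above keeps a single rolling checkpoint: at each train-epoch end it deletes the previously written file and saves a new one named from the current epoch and step. A usage sketch under the same inferred-import-path assumption; "checkpoints/" is a placeholder directory.

# Usage sketch (not part of the diff).
from lightning.pytorch import Trainer

from nshtrainer.callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint

callback = LatestEpochCheckpoint("checkpoints/")
trainer = Trainer(callbacks=[callback])
# After each training epoch, the previous file is removed and a new
# checkpoints/latest_epochXX_stepYYYY.ckpt is written in its place.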

nshtrainer/callbacks/log_epoch.py
@@ -0,0 +1,35 @@
+import logging
+import math
+from typing import Any
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import override
+
+log = logging.getLogger(__name__)
+
+
+class LogEpochCallback(Callback):
+    def __init__(self, metric_name: str = "computed_epoch"):
+        super().__init__()
+
+        self.metric_name = metric_name
+
+    @override
+    def on_train_batch_start(
+        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
+    ):
+        if trainer.logger is None:
+            return
+
+        # If trainer.num_training_batches is not set or is nan/inf, we cannot calculate the epoch
+        if (
+            not trainer.num_training_batches
+            or math.isnan(trainer.num_training_batches)
+            or math.isinf(trainer.num_training_batches)
+        ):
+            log.warning("Trainer has no valid num_training_batches. Cannot log epoch.")
+            return
+
+        epoch = pl_module.global_step / trainer.num_training_batches
+        pl_module.log(self.metric_name, epoch, on_step=True, on_epoch=False)
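
LogEpochCallback above logs a fractional epoch counter derived from the global step, which can serve as a common x-axis when runs differ in batches per epoch. A usage sketch (not part of the diff):

# Usage sketch (not part of the diff).
from lightning.pytorch import Trainer

from nshtrainer.callbacks.log_epoch import LogEpochCallback

trainer = Trainer(callbacks=[LogEpochCallback()])
# With 100 training batches per epoch, global_step 150 logs computed_epoch=1.5.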

nshtrainer/callbacks/norm_logging.py
@@ -0,0 +1,187 @@
+from logging import getLogger
+from typing import Literal, cast
+
+import torch
+import torch.nn as nn
+from lightning.pytorch import Callback, LightningModule, Trainer
+from torch.optim import Optimizer
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+log = getLogger(__name__)
+
+
+def grad_norm(
+    module: nn.Module,
+    norm_type: float | int | str,
+    group_separator: str = "/",
+    grad: bool = True,
+) -> dict[str, torch.Tensor | float]:
+    """Compute each parameter's gradient's norm and their overall norm.
+
+    The overall norm is computed over all gradients together, as if they
+    were concatenated into a single vector.
+
+    Args:
+        module: :class:`torch.nn.Module` to inspect.
+        norm_type: The type of the used p-norm, cast to float if necessary.
+            Can be ``'inf'`` for infinity norm.
+        group_separator: The separator string used by the logger to group
+            the gradients norms in their own subfolder instead of the logs one.
+
+    Return:
+        norms: The dictionary of p-norms of each parameter's gradient and
+            a special entry for the total p-norm of the gradients viewed
+            as a single vector.
+    """
+    norm_type = float(norm_type)
+    if norm_type <= 0:
+        raise ValueError(
+            f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}"
+        )
+
+    if grad:
+        norms = {
+            f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type)
+            for name, p in module.named_parameters()
+            if p.grad is not None
+        }
+        if norms:
+            total_norm = torch.tensor(list(norms.values())).norm(norm_type)
+            norms[f"grad_{norm_type}_norm_total"] = total_norm
+    else:
+        norms = {
+            f"param_{norm_type}_norm{group_separator}{name}": p.data.norm(norm_type)
+            for name, p in module.named_parameters()
+            if p.grad is not None
+        }
+        if norms:
+            total_norm = torch.tensor(list(norms.values())).norm(norm_type)
+            norms[f"param_{norm_type}_norm_total"] = total_norm
+
+    return norms
+
+
+def _to_norm_type(log_grad_norm_per_param: float | str | Literal[True]):
+    norm_type = 2.0
+    if log_grad_norm_per_param is not True:
+        norm_type = log_grad_norm_per_param
+    return norm_type
+
+
+def compute_norm(
+    pl_module: LightningModule,
+    optimizer: Optimizer | None = None,
+    p: float | str = 2.0,
+    *,
+    grad: bool,
+) -> torch.Tensor:
+    if optimizer is not None:
+        tensors = [
+            cast(torch.Tensor, p.grad if grad else p)
+            for group in optimizer.param_groups
+            for p in group["params"]
+            if p.grad is not None
+        ]
+    else:
+        tensors = [
+            p.grad if grad else p for p in pl_module.parameters() if p.grad is not None
+        ]
+
+    if not tensors:
+        return torch.tensor(0.0, device=pl_module.device)
+
+    return torch.norm(torch.stack([torch.norm(g, p=p) for g in tensors]), p=p)
+
+
+class NormLoggingCallback(Callback):
+    def __init__(self, config: "NormLoggingConfig"):
+        super().__init__()
+
+        self.config = config
+
+    def _perform_norm_logging(
+        self,
+        pl_module: LightningModule,
+        optimizer: Optimizer,
+        prefix: str,
+    ):
+        # Gradient norm logging
+        if log_grad_norm := self.config.log_grad_norm:
+            norm = compute_norm(
+                pl_module,
+                optimizer,
+                _to_norm_type(log_grad_norm),
+                grad=True,
+            )
+            pl_module.log(f"{prefix}grad_norm", norm, on_step=True, on_epoch=False)
+        if log_grad_norm_per_param := self.config.log_grad_norm_per_param:
+            norm_type = _to_norm_type(log_grad_norm_per_param)
+            pl_module.log_dict(
+                {
+                    f"{prefix}{k}": v
+                    for k, v in grad_norm(pl_module, norm_type, grad=True).items()
+                }
+            )
+
+        # Parameter norm logging
+        if log_param_norm := self.config.log_param_norm:
+            norm = compute_norm(
+                pl_module,
+                optimizer,
+                _to_norm_type(log_param_norm),
+                grad=False,
+            )
+            pl_module.log(f"{prefix}param_norm", norm, on_step=True, on_epoch=False)
+        if log_param_norm_per_param := self.config.log_param_norm_per_param:
+            norm_type = _to_norm_type(log_param_norm_per_param)
+            pl_module.log_dict(
+                {
+                    f"{prefix}{k}": v
+                    for k, v in grad_norm(pl_module, norm_type, grad=False).items()
+                }
+            )
+
+    @override
+    def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+        if len(trainer.optimizers) == 1:
+            optimizer = trainer.optimizers[0]
+            self._perform_norm_logging(pl_module, optimizer, prefix="train/")
+        else:
+            for i, optimizer in enumerate(trainer.optimizers):
+                self._perform_norm_logging(
+                    pl_module, optimizer, prefix=f"train/optimizer_{i}/"
+                )
+
+
+class NormLoggingConfig(CallbackConfigBase):
+    name: Literal["norm_logging"] = "norm_logging"
+
+    log_grad_norm: bool | str | float = False
+    """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+    log_grad_norm_per_param: bool | str | float = False
+    """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+    log_param_norm: bool | str | float = False
+    """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+    log_param_norm_per_param: bool | str | float = False
+    """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+    def __bool__(self):
+        return any(
+            v
+            for v in (
+                self.log_grad_norm,
+                self.log_grad_norm_per_param,
+                self.log_param_norm,
+                self.log_param_norm_per_param,
+            )
+        )
+
+    @override
+    def construct_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield NormLoggingCallback(self)
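
NormLoggingConfig above is the config-driven entry point: its construct_callbacks yields a NormLoggingCallback only when at least one norm option is enabled, and each option accepts True (meaning the 2-norm) or an explicit norm order. A usage sketch that constructs the callback directly; keyword construction of the config is an assumption, since CallbackConfigBase lives in callbacks/base.py and is not shown in this hunk.

# Usage sketch (not part of the diff). Assumes NormLoggingConfig accepts its
# fields as keyword arguments (CallbackConfigBase is not shown here).
from lightning.pytorch import Trainer

from nshtrainer.callbacks.norm_logging import NormLoggingCallback, NormLoggingConfig

config = NormLoggingConfig(log_grad_norm=True, log_param_norm_per_param=2.0)
trainer = Trainer(callbacks=[NormLoggingCallback(config)])
# After each backward pass this logs train/grad_norm plus a
# train/param_2.0_norm/<parameter name> entry for every parameter with a gradient.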

nshtrainer/callbacks/on_exception_checkpoint.py
@@ -0,0 +1,44 @@
+import datetime
+import logging
+import os
+from typing import Any
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
+from typing_extensions import override
+
+log = logging.getLogger(__name__)
+
+
+class OnExceptionCheckpoint(_OnExceptionCheckpoint):
+    @property
+    @override
+    def ckpt_path(self) -> str:
+        ckpt_path = super().ckpt_path
+
+        # Remve the extension and add the current timestamp
+        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+        ckpt_path, ext = os.path.splitext(ckpt_path)
+        return f"{ckpt_path}_{timestamp}{ext}"
+
+    @override
+    def on_exception(self, trainer: Trainer, *_: Any, **__: Any) -> None:
+        # We override this to checkpoint the model manually,
+        # without calling the dist barrier.
+
+        # trainer.save_checkpoint(self.ckpt_path)
+
+        if trainer.model is None:
+            raise AttributeError(
+                "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
+                " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
+            )
+        checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
+        trainer.strategy.save_checkpoint(
+            checkpoint, self.ckpt_path, storage_options=None
+        )
+        # self.strategy.barrier("Trainer.save_checkpoint") # <-- This is disabled
+
+    @override
+    def teardown(self, trainer: Trainer, *_: Any, **__: Any) -> None:
+        trainer.strategy.remove_checkpoint(self.ckpt_path)
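
This OnExceptionCheckpoint subclass timestamps the checkpoint filename and saves the state manually so that the usual distributed barrier in Trainer.save_checkpoint is skipped on exception. A usage sketch; the constructor arguments come from Lightning's base OnExceptionCheckpoint (a directory plus an optional filename), which this subclass does not override, and "checkpoints/" is a placeholder directory.

# Usage sketch (not part of the diff).
from lightning.pytorch import Trainer

from nshtrainer.callbacks.on_exception_checkpoint import OnExceptionCheckpoint

trainer = Trainer(callbacks=[OnExceptionCheckpoint("checkpoints/")])
# If fit() raises, a timestamped file such as
# checkpoints/on_exception_2024-01-01_12-00-00.ckpt is written before shutdown.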