nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,549 @@
+ # type: ignore
+ # Copyright The Lightning AI team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import time
+ from collections import deque
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Callable,
+     Deque,
+     Dict,
+     List,
+     Optional,
+     TypeVar,
+     Union,
+ )
+
+ import torch
+ from lightning.fabric.plugins import Precision as FabricPrecision
+ from lightning.fabric.utilities.throughput import (
+     _plugin_to_compute_dtype as fabric_plugin_to_compute_dtype,
+ )
+ from lightning.fabric.utilities.throughput import get_available_flops
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.plugins import (
+     BitsandbytesPrecision,
+     DeepSpeedPrecision,
+     DoublePrecision,
+     FSDPPrecision,
+     HalfPrecision,
+     MixedPrecision,
+     Precision,
+     TransformerEnginePrecision,
+     XLAPrecision,
+ )
+ from lightning.pytorch.trainer.states import RunningStage, TrainerFn
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
+ from typing_extensions import override
+
+ if TYPE_CHECKING:
+     from lightning.pytorch import LightningModule, Trainer
+
+
+ _THROUGHPUT_METRICS = Dict[str, Union[int, float]]
+
+ T = TypeVar("T", bound=float)
+
+
+ class _MonotonicWindow(List[T]):
+     """Custom fixed size list that only supports right-append and ensures that all values increase monotonically."""
+
+     def __init__(self, maxlen: int) -> None:
+         super().__init__()
+         self.maxlen = maxlen
+
+     @property
+     def last(self) -> Optional[T]:
+         if len(self) > 0:
+             return self[-1]
+         return None
+
+     @override
+     def append(self, x: T) -> None:
+         last = self.last
+         if last is not None and last >= x:
+             rank_zero_warn(
+                 f"Expected the value to increase, last: {last}, current: {x}"
+             )
+         list.append(self, x)
+         # truncate excess
+         if len(self) > self.maxlen:
+             del self[0]
+
+     @override
+     def __setitem__(self, key: Any, value: Any) -> None:
+         # assigning is not implemented since we don't use it. it could be by checking all previous values
+         raise NotImplementedError("__setitem__ is not supported")
+
+
+ class Throughput:
+     """Computes throughput.
+
+     +------------------------+-------------------------------------------------------------------------------------+
+     | Key | Value |
+     +========================+=====================================================================================+
+     | batches_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of batches |
+     | | processed per second |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | samples_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of samples |
+     | | processed per second |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | items_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of items |
+     | | processed per second |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | flops_per_sec | Rolling average (over ``window_size`` most recent updates) of the number of flops |
+     | | processed per second |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | device/batches_per_sec | batches_per_sec divided by world size |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | device/samples_per_sec | samples_per_sec divided by world size |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | device/items_per_sec | items_per_sec divided by world size. This may include padding depending on the data |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | device/flops_per_sec | flops_per_sec divided by world size |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | device/mfu | device/flops_per_sec divided by available_flops |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | time | Total elapsed time |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | batches | Total batches seen |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | samples | Total samples seen |
+     +--------------------------+-----------------------------------------------------------------------------------+
+     | lengths | Total items seen |
+     +--------------------------+-----------------------------------------------------------------------------------+
+
+     Example::
+
+         throughput = Throughput()
+         t0 = time()
+         for i in range(1000):
+             do_work()
+             if torch.cuda.is_available(): torch.cuda.synchronize()  # required or else time() won't be correct
+             throughput.update(time=time() - t0, samples=i)
+             if i % 10 == 0:
+                 print(throughput.compute())
+
+     Notes:
+         - The implementation assumes that all devices have the same FLOPs, since it normalizes by the world size and
+           only takes a single ``available_flops`` value.
+         - items_per_sec, flops_per_sec and MFU do not account for padding if present. We suggest using
+           samples_per_sec or batches_per_sec to measure throughput under this circumstance.
+
+     Args:
+         available_flops: Number of theoretical flops available for a single device.
+         world_size: Number of devices available across hosts. Global metrics are not included if the world size is 1.
+         window_size: Number of batches to use for a rolling average.
+         separator: Key separator to use when creating per-device and global metrics.
+
+     """
+
+     def __init__(
+         self,
+         available_flops: Optional[float] = None,
+         world_size: int = 1,
+         window_size: int = 100,
+         separator: str = "/",
+     ) -> None:
+         self.available_flops = available_flops
+         self.separator = separator
+         assert world_size > 0
+         self.world_size = world_size
+
+         # throughput is computed over a window of values. at least 2 is enforced since it looks at the difference
+         # between the first and last elements
+         assert window_size > 1
+         # custom class instead of `deque(maxlen=)` because it's easy for users to mess up their timer/counters and log
+         # values that do not increase monotonically. this class will warn if that happens.
+         self._time: _MonotonicWindow[float] = _MonotonicWindow(maxlen=window_size)
+         self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+         self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+         self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
+         self._flops: Deque[int] = deque(maxlen=window_size)
+
+     def update(
+         self,
+         *,
+         time: float,
+         batches: int,
+         samples: int,
+         lengths: Optional[int] = None,
+         flops: Optional[int] = None,
+     ) -> None:
+         """Update throughput metrics.
+
+         Args:
+             time: Total elapsed time in seconds. It should monotonically increase by the iteration time with each
+                 call.
+             batches: Total batches seen per device. It should monotonically increase with each call.
+             samples: Total samples seen per device. It should monotonically increase by the batch size with each call.
+             lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with
+                 each call.
+             flops: Flops elapsed per device since the last ``update()`` call. You can easily compute this by using
+                 :func:`measure_flops` and multiplying it by the number of batches that have been processed.
+                 The value might be different on each device if the batch size is not the same.
+
+         """
+         self._time.append(time)
+         if samples < batches:
+             raise ValueError(
+                 f"Expected samples ({samples}) to be greater than or equal to batches ({batches})"
+             )
+         self._batches.append(batches)
+         self._samples.append(samples)
+         if lengths is not None:
+             if lengths < samples:
+                 raise ValueError(
+                     f"Expected lengths ({lengths}) to be greater than or equal to samples ({samples})"
+                 )
+             self._lengths.append(lengths)
+             if len(self._samples) != len(self._lengths):
+                 raise RuntimeError(
+                     f"If lengths are passed ({len(self._lengths)}), there needs to be the same number of samples"
+                     f" ({len(self._samples)})"
+                 )
+         if flops is not None:
+             # sum of flops across ranks
+             self._flops.append(flops * self.world_size)
+
+     def compute(self) -> _THROUGHPUT_METRICS:
+         """Compute throughput metrics."""
+         metrics = {
+             "time": self._time[-1],
+             "batches": self._batches[-1],
+             "samples": self._samples[-1],
+         }
+         if self._lengths:
+             metrics["lengths"] = self._lengths[-1]
+
+         add_global_metrics = self.world_size > 1
+         # a different but valid design choice would be to still compute all these metrics even if the window of values
+         # has not been filled
+         if len(self._time) == self._time.maxlen:
+             elapsed_time = self._time[-1] - self._time[0]
+             elapsed_batches = self._batches[-1] - self._batches[0]
+             elapsed_samples = self._samples[-1] - self._samples[0]
+             # we are safe from ZeroDivisionError thanks to `_MonotonicWindow`
+             dev_samples_per_sec = elapsed_samples / elapsed_time
+             dev_batches_per_sec = elapsed_batches / elapsed_time
+             metrics.update(
+                 {
+                     f"device{self.separator}batches_per_sec": dev_batches_per_sec,
+                     f"device{self.separator}samples_per_sec": dev_samples_per_sec,
+                 }
+             )
+             if add_global_metrics:
+                 metrics.update(
+                     {
+                         "batches_per_sec": dev_batches_per_sec * self.world_size,
+                         "samples_per_sec": dev_samples_per_sec * self.world_size,
+                     }
+                 )
+
+             if len(self._lengths) == self._lengths.maxlen:
+                 elapsed_lengths = self._lengths[-1] - self._lengths[0]
+                 dev_items_per_sec = elapsed_lengths / elapsed_time
+                 metrics[f"device{self.separator}items_per_sec"] = dev_items_per_sec
+                 if add_global_metrics:
+                     items_per_sec = dev_items_per_sec * self.world_size
+                     metrics["items_per_sec"] = items_per_sec
+
+         if len(self._flops) == self._flops.maxlen:
+             elapsed_flops = sum(self._flops) - self._flops[0]
+             elapsed_time = self._time[-1] - self._time[0]
+             flops_per_sec = elapsed_flops / elapsed_time
+             dev_flops_per_sec = flops_per_sec / self.world_size
+             if add_global_metrics:
+                 metrics["flops_per_sec"] = flops_per_sec
+             metrics[f"device{self.separator}flops_per_sec"] = dev_flops_per_sec
+             if self.available_flops:
+                 metrics[f"device{self.separator}mfu"] = (
+                     dev_flops_per_sec / self.available_flops
+                 )
+
+         return metrics
+
+     def reset(self) -> None:
+         self._time.clear()
+         self._batches.clear()
+         self._samples.clear()
+         self._lengths.clear()
+         self._flops.clear()
+
+
+ class ThroughputMonitor(Callback):
+     r"""Computes and logs throughput with the :class:`~lightning.fabric.utilities.throughput.Throughput`
+
+     Example::
+
+         class MyModel(LightningModule):
+             def setup(self, stage):
+                 with torch.device("meta"):
+                     model = MyModel()
+
+                     def sample_forward():
+                         batch = torch.randn(..., device="meta")
+                         return model(batch)
+
+                     self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)
+
+
+         logger = ...
+         throughput = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))
+         trainer = Trainer(max_steps=1000, log_every_n_steps=10, callbacks=throughput, logger=logger)
+         model = MyModel()
+         trainer.fit(model)
+
+     Notes:
+         - It assumes that the batch size is the same during all iterations.
+         - It will try to access a ``flops_per_batch`` attribute on your ``LightningModule`` on every iteration.
+           We suggest using the :func:`~lightning.fabric.utilities.throughput.measure_flops` function for this.
+           You might want to compute it differently each time based on your setup.
+
+     Args:
+         batch_size_fn: A function to compute the number of samples given a batch.
+         length_fn: A function to compute the number of items in a sample given a batch.
+         \**kwargs: See available parameters in
+             :class:`~lightning.fabric.utilities.throughput.Throughput`
+
+     """
+
+     def __init__(
+         self,
+         batch_size_fn: Callable[[Any], int],
+         length_fn: Optional[Callable[[Any], int | None]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__()
+         self.kwargs = kwargs
+         self.batch_size_fn = batch_size_fn
+         self.length_fn = length_fn
+         self.available_flops: Optional[int] = None
+         self._throughputs: Dict[RunningStage, Throughput] = {}
+         self._t0s: Dict[RunningStage, float] = {}
+         self._samples: Dict[RunningStage, int] = {}
+         self._lengths: Dict[RunningStage, int] = {}
+
+     @override
+     def setup(
+         self, trainer: "Trainer", pl_module: "LightningModule", stage: str
+     ) -> None:
+         dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
+         self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)
+
+         if stage == TrainerFn.FITTING and trainer.enable_validation:
+             # `fit` includes validation inside
+             throughput = Throughput(
+                 available_flops=self.available_flops,
+                 world_size=trainer.world_size,
+                 **self.kwargs,
+             )
+             self._throughputs[RunningStage.VALIDATING] = throughput
+
+         throughput = Throughput(
+             available_flops=self.available_flops,
+             world_size=trainer.world_size,
+             **self.kwargs,
+         )
+         stage = trainer.state.stage
+         assert stage is not None
+         self._throughputs[stage] = throughput
+
+     def _start(self, trainer: "Trainer") -> None:
+         stage = trainer.state.stage
+         assert stage is not None
+         self._throughputs[stage].reset()
+         self._samples[stage] = 0
+         self._lengths[stage] = 0
+         self._t0s[stage] = time.perf_counter()
+
+     def _update(
+         self,
+         trainer: "Trainer",
+         pl_module: "LightningModule",
+         batch: Any,
+         iter_num: int,
+     ) -> None:
+         stage = trainer.state.stage
+         assert stage is not None
+         throughput = self._throughputs[stage]
+
+         if trainer.strategy.root_device.type == "cuda":
+             # required or else perf_counter() won't be correct
+             torch.cuda.synchronize()
+
+         elapsed = time.perf_counter() - self._t0s[stage]
+         if self.length_fn is not None:
+             with torch.inference_mode():
+                 if (length := self.length_fn(batch)) is not None:
+                     self._lengths[stage] += length
+
+         if hasattr(pl_module, "flops_per_batch"):
+             flops_per_batch = pl_module.flops_per_batch
+         else:
+             rank_zero_warn(
+                 "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"
+                 f" in {type(pl_module).__name__} to compute the FLOPs."
+             )
+             flops_per_batch = None
+
+         with torch.inference_mode():
+             self._samples[stage] += self.batch_size_fn(batch)
+
+         throughput.update(
+             time=elapsed,
+             batches=iter_num,
+             samples=self._samples[stage],
+             lengths=None if self.length_fn is None else self._lengths[stage],
+             flops=flops_per_batch,
+         )
+
+     def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
+         if not trainer._logger_connector.should_update_logs:
+             return
+         stage = trainer.state.stage
+         assert stage is not None
+         throughput = self._throughputs[stage]
+         metrics = throughput.compute()
+         # prefix with the stage to avoid collisions
+         metrics = {
+             f"{stage.value}{throughput.separator}{k}": v for k, v in metrics.items()
+         }
+         trainer._logger_connector.log_metrics(metrics, step=iter_num)  # type: ignore[arg-type]
+
+     @override
+     @rank_zero_only
+     def on_train_start(self, trainer: "Trainer", *_: Any) -> None:
+         self._start(trainer)
+
+     @override
+     @rank_zero_only
+     def on_train_batch_end(
+         self,
+         trainer: "Trainer",
+         pl_module: "LightningModule",
+         outputs: Any,
+         batch: Any,
+         *_: Any,
+     ) -> None:
+         self._update(trainer, pl_module, batch, trainer.fit_loop.total_batch_idx + 1)
+         # log only when gradient accumulation is over. this ensures that we only measure when the effective batch has
+         # finished and the `optimizer.step()` time is included
+         if not trainer.fit_loop._should_accumulate():
+             self._compute(trainer)
+
+     @override
+     @rank_zero_only
+     def on_validation_start(self, trainer: "Trainer", *_: Any) -> None:
+         if trainer.sanity_checking:
+             return
+         self._start(trainer)
+
+     @override
+     @rank_zero_only
+     def on_validation_batch_end(
+         self,
+         trainer: "Trainer",
+         pl_module: "LightningModule",
+         outputs: Any,
+         batch: Any,
+         *_: Any,
+         **__: Any,
+     ) -> None:
+         if trainer.sanity_checking:
+             return
+         iter_num = trainer._evaluation_loop.batch_progress.total.ready
+         self._update(trainer, pl_module, batch, iter_num)
+         self._compute(trainer, iter_num)
+
+     @override
+     @rank_zero_only
+     def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
+         if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
+             return
+         # add the validation time to the training time before continuing to avoid sinking the training throughput
+         training_finished = self._t0s[RunningStage.TRAINING] + sum(
+             self._throughputs[RunningStage.TRAINING]._time
+         )
+         time_between_train_and_val = (
+             self._t0s[RunningStage.VALIDATING] - training_finished
+         )
+         val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
+         self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
+
+     @override
+     @rank_zero_only
+     def on_test_start(self, trainer: "Trainer", *_: Any) -> None:
+         self._start(trainer)
+
+     @override
+     @rank_zero_only
+     def on_test_batch_end(
+         self,
+         trainer: "Trainer",
+         pl_module: "LightningModule",
+         outputs: Any,
+         batch: Any,
+         *_: Any,
+         **__: Any,
+     ) -> None:
+         iter_num = trainer._evaluation_loop.batch_progress.total.ready
+         self._update(trainer, pl_module, batch, iter_num)
+         self._compute(trainer, iter_num)
+
+     @override
+     @rank_zero_only
+     def on_predict_start(self, trainer: "Trainer", *_: Any) -> None:
+         self._start(trainer)
+
+     @override
+     @rank_zero_only
+     def on_predict_batch_end(
+         self,
+         trainer: "Trainer",
+         pl_module: "LightningModule",
+         outputs: Any,
+         batch: Any,
+         *_: Any,
+         **__: Any,
+     ) -> None:
+         iter_num = trainer.predict_loop.batch_progress.total.ready
+         self._update(trainer, pl_module, batch, iter_num)
+         self._compute(trainer, iter_num)
+
+
+ def _plugin_to_compute_dtype(plugin: Union[FabricPrecision, Precision]) -> torch.dtype:
+     # TODO: integrate this into the precision plugins
+     if not isinstance(plugin, Precision):
+         return fabric_plugin_to_compute_dtype(plugin)
+     if isinstance(plugin, BitsandbytesPrecision):
+         return plugin.dtype
+     if isinstance(plugin, HalfPrecision):
+         return plugin._desired_input_dtype
+     if isinstance(plugin, MixedPrecision):
+         return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
+     if isinstance(plugin, DoublePrecision):
+         return torch.double
+     if isinstance(plugin, (XLAPrecision, DeepSpeedPrecision)):
+         return plugin._desired_dtype
+     if isinstance(plugin, TransformerEnginePrecision):
+         return torch.int8
+     if isinstance(plugin, FSDPPrecision):
+         return plugin.mixed_precision_config.reduce_dtype or torch.float32
+     if isinstance(plugin, Precision):
+         return torch.float32
+     raise NotImplementedError(plugin)
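A minimal usage sketch of the callback above (illustrative only: the import path follows this wheel's file layout, while the model, dataset, and hyperparameter names are hypothetical). It wires ``ThroughputMonitor`` into a ``Trainer`` and defines the ``flops_per_batch`` attribute that ``_update`` reads on every logged iteration::

    import torch
    import torch.nn as nn
    from lightning.fabric.utilities.throughput import measure_flops
    from lightning.pytorch import LightningModule, Trainer
    from torch.utils.data import DataLoader, Dataset

    from nshtrainer.callbacks._throughput_monitor_callback import ThroughputMonitor


    class RandomVectors(Dataset):  # hypothetical toy dataset
        def __len__(self):
            return 64

        def __getitem__(self, index):
            return torch.randn(32)


    class LitModel(LightningModule):  # hypothetical module
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(32, 1)

        def setup(self, stage):
            # Count the FLOPs of one forward/backward pass on meta tensors and expose
            # them as `flops_per_batch`, which ThroughputMonitor._update looks up.
            with torch.device("meta"):
                net = nn.Linear(32, 1)
                self.flops_per_batch = measure_flops(
                    net,
                    lambda: net(torch.randn(8, 32, device="meta")),
                    loss_fn=torch.Tensor.sum,
                )

        def training_step(self, batch, batch_idx):
            return self.net(batch).pow(2).mean()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)


    monitor = ThroughputMonitor(
        batch_size_fn=lambda batch: batch.size(0),  # samples per batch
        window_size=10,  # forwarded to Throughput via **kwargs
    )
    trainer = Trainer(max_steps=100, log_every_n_steps=10, callbacks=[monitor])
    trainer.fit(LitModel(), DataLoader(RandomVectors(), batch_size=8))

Metrics are logged under a per-stage prefix (``_compute`` prepends ``stage.value`` and the separator), so the training rates appear as keys such as ``train/device/batches_per_sec`` and ``train/device/mfu``.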
@@ -0,0 +1,113 @@
+ from abc import ABC, abstractmethod
+ from collections import Counter
+ from collections.abc import Iterable
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, TypeAlias, TypedDict
+
+ from lightning.pytorch import Callback
+
+ from ..config import TypedConfig
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+
+ class CallbackMetadataDict(TypedDict, total=False):
+     ignore_if_exists: bool
+     """If `True`, the callback will not be added if another callback with the same class already exists."""
+
+     priority: int
+     """Priority of the callback. Callbacks with higher priority will be loaded first."""
+
+
+ class CallbackMetadataConfig(TypedConfig):
+     ignore_if_exists: bool = False
+     """If `True`, the callback will not be added if another callback with the same class already exists."""
+
+     priority: int = 0
+     """Priority of the callback. Callbacks with higher priority will be loaded first."""
+
+
+ @dataclass(frozen=True)
+ class CallbackWithMetadata:
+     callback: Callback
+     metadata: CallbackMetadataConfig
+
+
+ ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
+
+
+ class CallbackConfigBase(TypedConfig, ABC):
+     metadata: CallbackMetadataConfig = CallbackMetadataConfig()
+     """Metadata for the callback."""
+
+     def with_metadata(self, callback: Callback, **metadata: CallbackMetadataDict):
+         return CallbackWithMetadata(
+             callback=callback, metadata=self.metadata.model_copy(update=metadata)
+         )
+
+     @abstractmethod
+     def construct_callbacks(
+         self, root_config: "BaseConfig"
+     ) -> Iterable[Callback | CallbackWithMetadata]: ...
+
+
+ # region Config resolution helpers
+ def _construct_callbacks_with_metadata(
+     config: CallbackConfigBase, root_config: "BaseConfig"
+ ) -> Iterable[CallbackWithMetadata]:
+     for callback in config.construct_callbacks(root_config):
+         if isinstance(callback, CallbackWithMetadata):
+             yield callback
+             continue
+
+         callback = config.with_metadata(callback)
+         yield callback
+
+
+ def _filter_ignore_if_exists(callbacks: list[CallbackWithMetadata]):
+     # First, let's do a pass over all callbacks to hold the count of each callback class
+     callback_classes = Counter(callback.callback.__class__ for callback in callbacks)
+
+     # Drop duplicates that opted in via `ignore_if_exists`
+     callbacks_filtered: list[CallbackWithMetadata] = []
+     for callback in callbacks:
+         # If `ignore_if_exists` is `True` and there is already a callback of the same class, skip this callback
+         if (
+             callback.metadata.ignore_if_exists
+             and callback_classes[callback.callback.__class__] > 1
+         ):
+             continue
+
+         callbacks_filtered.append(callback)
+
+     return callbacks_filtered
+
+
+ def _process_and_filter_callbacks(
+     callbacks: Iterable[CallbackWithMetadata],
+ ) -> list[Callback]:
+     callbacks = list(callbacks)
+
+     # Sort by priority (higher priority first)
+     callbacks.sort(key=lambda callback: callback.metadata.priority, reverse=True)
+
+     # Process `ignore_if_exists`
+     callbacks = _filter_ignore_if_exists(callbacks)
+
+     return [callback.callback for callback in callbacks]
+
+
+ def resolve_all_callbacks(root_config: "BaseConfig"):
+     callback_configs = [
+         config for config in root_config.ll_all_callback_configs() if config is not None
+     ]
+     callbacks = _process_and_filter_callbacks(
+         callback
+         for callback_config in callback_configs
+         for callback in _construct_callbacks_with_metadata(callback_config, root_config)
+     )
+     return callbacks
+
+
+ # endregion
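For orientation, a sketch of how a concrete config could plug into the helpers above (illustrative: ``MyEarlyStoppingConfig`` and its fields are hypothetical, and it assumes ``TypedConfig`` supports the pydantic-style field declarations implied by ``model_copy`` above)::

    from collections.abc import Iterable

    from lightning.pytorch import Callback
    from lightning.pytorch.callbacks import EarlyStopping

    from nshtrainer.callbacks.base import CallbackConfigBase, CallbackWithMetadata


    class MyEarlyStoppingConfig(CallbackConfigBase):  # hypothetical config
        monitor: str = "val/loss"
        patience: int = 10

        def construct_callbacks(
            self, root_config
        ) -> Iterable[Callback | CallbackWithMetadata]:
            # `with_metadata` copies `self.metadata` with the given overrides, so
            # `resolve_all_callbacks` can sort by priority and drop this callback
            # when another instance of the same class is already present.
            yield self.with_metadata(
                EarlyStopping(monitor=self.monitor, patience=self.patience),
                priority=100,
                ignore_if_exists=True,
            )

Callbacks yielded without metadata are wrapped with the config's default ``metadata`` by ``_construct_callbacks_with_metadata``; ``resolve_all_callbacks`` then sorts by ``priority`` and applies ``ignore_if_exists``, returning plain ``Callback`` objects.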