nshtrainer-0.44.1-py3-none-any.whl → nshtrainer-1.0.0b10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
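
Migration note drawn only from the file list above: 1.0.0b10 removes the nshtrainer.ll compatibility shims, the runner (nshtrainer/runner.py), nshtrainer/model/config.py, and the throughput-monitor callback pair, while adding an ActSave logger, a debug mixin, and new config/trainer modules. A hedged pre-upgrade check, with module names copied from the list (assumes only that nshtrainer itself is importable):

    import importlib.util

    # Modules deleted between 0.44.1 and 1.0.0b10, per the file list above.
    REMOVED = [
        "nshtrainer.ll",
        "nshtrainer.runner",
        "nshtrainer.model.config",
        "nshtrainer.callbacks.throughput_monitor",
    ]

    for name in REMOVED:
        status = "missing" if importlib.util.find_spec(name) is None else "present"
        print(f"{name}: {status} in the installed nshtrainer")

Under 0.44.1 each module reports present; under 1.0.0b10 each reports missing, so any project that still imports one of them needs updating before the upgrade.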
nshtrainer/callbacks/_throughput_monitor_callback.py (deleted)
@@ -1,551 +0,0 @@
- # type: ignore
- # Copyright The Lightning AI team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import annotations
-
- import time
- from collections import deque
- from typing import (
-     TYPE_CHECKING,
-     Any,
-     Callable,
-     Deque,
-     Dict,
-     List,
-     Optional,
-     TypeVar,
-     Union,
- )
-
- import torch
- from lightning.fabric.plugins import Precision as FabricPrecision
- from lightning.fabric.utilities.throughput import (
-     _plugin_to_compute_dtype as fabric_plugin_to_compute_dtype,
- )
- from lightning.fabric.utilities.throughput import get_available_flops
- from lightning.pytorch.callbacks import Callback
- from lightning.pytorch.plugins import (
-     BitsandbytesPrecision,
-     DeepSpeedPrecision,
-     DoublePrecision,
-     FSDPPrecision,
-     HalfPrecision,
-     MixedPrecision,
-     Precision,
-     TransformerEnginePrecision,
-     XLAPrecision,
- )
- from lightning.pytorch.trainer.states import RunningStage, TrainerFn
- from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
- from typing_extensions import override
-
- if TYPE_CHECKING:
-     from lightning.pytorch import LightningModule, Trainer
-
-
- _THROUGHPUT_METRICS = Dict[str, Union[int, float]]
-
- T = TypeVar("T", bound=float)
-
-
- class _MonotonicWindow(List[T]):
-     """Custom fixed size list that only supports right-append and ensures that all values increase monotonically."""
-
-     def __init__(self, maxlen: int) -> None:
-         super().__init__()
-         self.maxlen = maxlen
-
-     @property
-     def last(self) -> Optional[T]:
-         if len(self) > 0:
-             return self[-1]
-         return None
-
-     @override
-     def append(self, x: T) -> None:
-         last = self.last
-         if last is not None and last >= x:
-             rank_zero_warn(
-                 f"Expected the value to increase, last: {last}, current: {x}"
-             )
-         list.append(self, x)
-         # truncate excess
-         if len(self) > self.maxlen:
-             del self[0]
-
-     @override
-     def __setitem__(self, key: Any, value: Any) -> None:
-         # assigning is not implemented since we don't use it. it could be by checking all previous values
-         raise NotImplementedError("__setitem__ is not supported")
-
-
- class Throughput:
-     """Computes throughput.
-
-     +------------------------+--------------------------------------------------------------------------------------+
-     | Key                    | Value                                                                                |
-     +========================+======================================================================================+
-     | batches_per_sec        | Rolling average (over ``window_size`` most recent updates) of the number of batches  |
-     |                        | processed per second                                                                 |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | samples_per_sec        | Rolling average (over ``window_size`` most recent updates) of the number of samples  |
-     |                        | processed per second                                                                 |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | items_per_sec          | Rolling average (over ``window_size`` most recent updates) of the number of items    |
-     |                        | processed per second                                                                 |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | flops_per_sec          | Rolling average (over ``window_size`` most recent updates) of the number of flops    |
-     |                        | processed per second                                                                 |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | device/batches_per_sec | batches_per_sec divided by world size                                                |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | device/samples_per_sec | samples_per_sec divided by world size                                                |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | device/items_per_sec   | items_per_sec divided by world size. This may include padding depending on the data  |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | device/flops_per_sec   | flops_per_sec divided by world size                                                  |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | device/mfu             | device/flops_per_sec divided by available_flops                                      |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | time                   | Total elapsed time                                                                   |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | batches                | Total batches seen                                                                   |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | samples                | Total samples seen                                                                   |
-     +------------------------+--------------------------------------------------------------------------------------+
-     | lengths                | Total items seen                                                                     |
-     +------------------------+--------------------------------------------------------------------------------------+
-
-     Example::
-
-         throughput = Throughput()
-         t0 = time()
-         for i in range(1000):
-             do_work()
-             if torch.cuda.is_available(): torch.cuda.synchronize()  # required or else time() won't be correct
-             throughput.update(time=time() - t0, batches=i, samples=i)
-             if i % 10 == 0:
-                 print(throughput.compute())
-
-     Notes:
-         - The implementation assumes that all devices' FLOPs are the same, as it normalizes by the world size and
-           only takes a single ``available_flops`` value.
-         - items_per_sec, flops_per_sec and MFU do not account for padding if present. We suggest using
-           samples_per_sec or batches_per_sec to measure throughput under this circumstance.
-
-     Args:
-         available_flops: Number of theoretical flops available for a single device.
-         world_size: Number of devices available across hosts. Global metrics are not included if the world size is 1.
-         window_size: Number of batches to use for a rolling average.
-         separator: Key separator to use when creating per-device and global metrics.
-
-     """
-
-     def __init__(
-         self,
-         available_flops: Optional[float] = None,
-         world_size: int = 1,
-         window_size: int = 100,
-         separator: str = "/",
-     ) -> None:
-         self.available_flops = available_flops
-         self.separator = separator
-         assert world_size > 0
-         self.world_size = world_size
-
-         # throughput is computed over a window of values. at least 2 is enforced since it looks at the difference
-         # between the first and last elements
-         assert window_size > 1
-         # custom class instead of `deque(maxlen=)` because it's easy for users to mess up their timer/counters and log
-         # values that do not increase monotonically. this class will raise an error if that happens.
-         self._time: _MonotonicWindow[float] = _MonotonicWindow(maxlen=window_size)
-         self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
-         self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
-         self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
-         self._flops: Deque[int] = deque(maxlen=window_size)
-
-     def update(
-         self,
-         *,
-         time: float,
-         batches: int,
-         samples: int,
-         lengths: Optional[int] = None,
-         flops: Optional[int] = None,
-     ) -> None:
-         """Update throughput metrics.
-
-         Args:
-             time: Total elapsed time in seconds. It should monotonically increase by the iteration time with each
-                 call.
-             batches: Total batches seen per device. It should monotonically increase with each call.
-             samples: Total samples seen per device. It should monotonically increase by the batch size with each call.
-             lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with
-                 each call.
-             flops: Flops elapsed per device since last ``update()`` call. You can easily compute this by using
-                 :func:`measure_flops` and multiplying it by the number of batches that have been processed.
-                 The value might be different in each device if the batch size is not the same.
-
-         """
-         self._time.append(time)
-         if samples < batches:
-             raise ValueError(
-                 f"Expected samples ({samples}) to be greater than or equal to batches ({batches})"
-             )
-         self._batches.append(batches)
-         self._samples.append(samples)
-         if lengths is not None:
-             if lengths < samples:
-                 raise ValueError(
-                     f"Expected lengths ({lengths}) to be greater than or equal to samples ({samples})"
-                 )
-             self._lengths.append(lengths)
-             if len(self._samples) != len(self._lengths):
-                 raise RuntimeError(
-                     f"If lengths are passed ({len(self._lengths)}), there needs to be the same number of samples"
-                     f" ({len(self._samples)})"
-                 )
-         if flops is not None:
-             # sum of flops across ranks
-             self._flops.append(flops * self.world_size)
-
-     def compute(self) -> _THROUGHPUT_METRICS:
-         """Compute throughput metrics."""
-         metrics = {
-             "time": self._time[-1],
-             "batches": self._batches[-1],
-             "samples": self._samples[-1],
-         }
-         if self._lengths:
-             metrics["lengths"] = self._lengths[-1]
-
-         add_global_metrics = self.world_size > 1
-         # a different but valid design choice would be to still compute all these metrics even if the window of values
-         # has not been filled
-         if len(self._time) == self._time.maxlen:
-             elapsed_time = self._time[-1] - self._time[0]
-             elapsed_batches = self._batches[-1] - self._batches[0]
-             elapsed_samples = self._samples[-1] - self._samples[0]
-             # we are safe from ZeroDivisionError thanks to `_MonotonicWindow`
-             dev_samples_per_sec = elapsed_samples / elapsed_time
-             dev_batches_per_sec = elapsed_batches / elapsed_time
-             metrics.update(
-                 {
-                     f"device{self.separator}batches_per_sec": dev_batches_per_sec,
-                     f"device{self.separator}samples_per_sec": dev_samples_per_sec,
-                 }
-             )
-             if add_global_metrics:
-                 batches_per_sec = dev_batches_per_sec * self.world_size
-                 metrics.update(
-                     {
-                         "batches_per_sec": batches_per_sec,
-                         "samples_per_sec": dev_samples_per_sec * self.world_size,
-                     }
-                 )
-
-             if len(self._lengths) == self._lengths.maxlen:
-                 elapsed_lengths = self._lengths[-1] - self._lengths[0]
-                 dev_items_per_sec = elapsed_lengths / elapsed_time
-                 metrics[f"device{self.separator}items_per_sec"] = dev_items_per_sec
-                 if add_global_metrics:
-                     items_per_sec = dev_items_per_sec * self.world_size
-                     metrics["items_per_sec"] = items_per_sec
-
-         if len(self._flops) == self._flops.maxlen:
-             elapsed_flops = sum(self._flops) - self._flops[0]
-             elapsed_time = self._time[-1] - self._time[0]
-             flops_per_sec = elapsed_flops / elapsed_time
-             dev_flops_per_sec = flops_per_sec / self.world_size
-             if add_global_metrics:
-                 metrics["flops_per_sec"] = flops_per_sec
-             metrics[f"device{self.separator}flops_per_sec"] = dev_flops_per_sec
-             if self.available_flops:
-                 metrics[f"device{self.separator}mfu"] = (
-                     dev_flops_per_sec / self.available_flops
-                 )
-
-         return metrics
-
-     def reset(self) -> None:
-         self._time.clear()
-         self._batches.clear()
-         self._samples.clear()
-         self._lengths.clear()
-         self._flops.clear()
-
-
- class ThroughputMonitor(Callback):
-     r"""Computes and logs throughput with the :class:`~lightning.fabric.utilities.throughput.Throughput`
-
-     Example::
-
-         class MyModel(LightningModule):
-             def setup(self, stage):
-                 with torch.device("meta"):
-                     model = MyModel()
-
-                     def sample_forward():
-                         batch = torch.randn(..., device="meta")
-                         return model(batch)
-
-                     self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)
-
-
-         logger = ...
-         throughput = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))
-         trainer = Trainer(max_steps=1000, log_every_n_steps=10, callbacks=throughput, logger=logger)
-         model = MyModel()
-         trainer.fit(model)
-
-     Notes:
-         - It assumes that the batch size is the same during all iterations.
-         - It will try to access a ``flops_per_batch`` attribute on your ``LightningModule`` on every iteration.
-           We suggest using the :func:`~lightning.fabric.utilities.throughput.measure_flops` function for this.
-           You might want to compute it differently each time based on your setup.
-
-     Args:
-         batch_size_fn: A function to compute the number of samples given a batch.
-         length_fn: A function to compute the number of items in a sample given a batch.
-         \**kwargs: See available parameters in
-             :class:`~lightning.fabric.utilities.throughput.Throughput`
-
-     """
-
-     def __init__(
-         self,
-         batch_size_fn: Callable[[Any], int],
-         length_fn: Optional[Callable[[Any], int | None]] = None,
-         **kwargs: Any,
-     ) -> None:
-         super().__init__()
-         self.kwargs = kwargs
-         self.batch_size_fn = batch_size_fn
-         self.length_fn = length_fn
-         self.available_flops: Optional[int] = None
-         self._throughputs: Dict[RunningStage, Throughput] = {}
-         self._t0s: Dict[RunningStage, float] = {}
-         self._samples: Dict[RunningStage, int] = {}
-         self._lengths: Dict[RunningStage, int] = {}
-
-     @override
-     def setup(
-         self, trainer: "Trainer", pl_module: "LightningModule", stage: str
-     ) -> None:
-         dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
-         self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)
-
-         if stage == TrainerFn.FITTING and trainer.enable_validation:
-             # `fit` includes validation inside
-             throughput = Throughput(
-                 available_flops=self.available_flops,
-                 world_size=trainer.world_size,
-                 **self.kwargs,
-             )
-             self._throughputs[RunningStage.VALIDATING] = throughput
-
-         throughput = Throughput(
-             available_flops=self.available_flops,
-             world_size=trainer.world_size,
-             **self.kwargs,
-         )
-         stage = trainer.state.stage
-         assert stage is not None
-         self._throughputs[stage] = throughput
-
-     def _start(self, trainer: "Trainer") -> None:
-         stage = trainer.state.stage
-         assert stage is not None
-         self._throughputs[stage].reset()
-         self._samples[stage] = 0
-         self._lengths[stage] = 0
-         self._t0s[stage] = time.perf_counter()
-
-     def _update(
-         self,
-         trainer: "Trainer",
-         pl_module: "LightningModule",
-         batch: Any,
-         iter_num: int,
-     ) -> None:
-         stage = trainer.state.stage
-         assert stage is not None
-         throughput = self._throughputs[stage]
-
-         if trainer.strategy.root_device.type == "cuda":
-             # required or else perf_counter() won't be correct
-             torch.cuda.synchronize()
-
-         elapsed = time.perf_counter() - self._t0s[stage]
-         if self.length_fn is not None:
-             with torch.inference_mode():
-                 if (length := self.length_fn(batch)) is not None:
-                     self._lengths[stage] += length
-
-         if hasattr(pl_module, "flops_per_batch"):
-             flops_per_batch = pl_module.flops_per_batch
-         else:
-             rank_zero_warn(
-                 "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"
-                 f" in {type(pl_module).__name__} to compute the FLOPs."
-             )
-             flops_per_batch = None
-
-         with torch.inference_mode():
-             self._samples[stage] += self.batch_size_fn(batch)
-
-         throughput.update(
-             time=elapsed,
-             batches=iter_num,
-             samples=self._samples[stage],
-             lengths=None if self.length_fn is None else self._lengths[stage],
-             flops=flops_per_batch,
-         )
-
-     def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
-         if not trainer._logger_connector.should_update_logs:
-             return
-         stage = trainer.state.stage
-         assert stage is not None
-         throughput = self._throughputs[stage]
-         metrics = throughput.compute()
-         # prefix with the stage to avoid collisions
-         metrics = {
-             f"{stage.value}{throughput.separator}{k}": v for k, v in metrics.items()
-         }
-         trainer._logger_connector.log_metrics(metrics, step=iter_num)  # type: ignore[arg-type]
-
-     @override
-     @rank_zero_only
-     def on_train_start(self, trainer: "Trainer", *_: Any) -> None:
-         self._start(trainer)
-
-     @override
-     @rank_zero_only
-     def on_train_batch_end(
-         self,
-         trainer: "Trainer",
-         pl_module: "LightningModule",
-         outputs: Any,
-         batch: Any,
-         *_: Any,
-     ) -> None:
-         self._update(trainer, pl_module, batch, trainer.fit_loop.total_batch_idx + 1)
-         # log only when gradient accumulation is over. this ensures that we only measure when the effective batch has
-         # finished and the `optimizer.step()` time is included
-         if not trainer.fit_loop._should_accumulate():
-             self._compute(trainer)
-
-     @override
-     @rank_zero_only
-     def on_validation_start(self, trainer: "Trainer", *_: Any) -> None:
-         if trainer.sanity_checking:
-             return
-         self._start(trainer)
-
-     @override
-     @rank_zero_only
-     def on_validation_batch_end(
-         self,
-         trainer: "Trainer",
-         pl_module: "LightningModule",
-         outputs: Any,
-         batch: Any,
-         *_: Any,
-         **__: Any,
-     ) -> None:
-         if trainer.sanity_checking:
-             return
-         iter_num = trainer._evaluation_loop.batch_progress.total.ready
-         self._update(trainer, pl_module, batch, iter_num)
-         self._compute(trainer, iter_num)
-
-     @override
-     @rank_zero_only
-     def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
-         if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
-             return
-         # add the validation time to the training time before continuing to avoid sinking the training throughput
-         training_finished = self._t0s[RunningStage.TRAINING] + sum(
-             self._throughputs[RunningStage.TRAINING]._time
-         )
-         time_between_train_and_val = (
-             self._t0s[RunningStage.VALIDATING] - training_finished
-         )
-         val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
-         self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
-
-     @override
-     @rank_zero_only
-     def on_test_start(self, trainer: "Trainer", *_: Any) -> None:
-         self._start(trainer)
-
-     @override
-     @rank_zero_only
-     def on_test_batch_end(
-         self,
-         trainer: "Trainer",
-         pl_module: "LightningModule",
-         outputs: Any,
-         batch: Any,
-         *_: Any,
-         **__: Any,
-     ) -> None:
-         iter_num = trainer._evaluation_loop.batch_progress.total.ready
-         self._update(trainer, pl_module, batch, iter_num)
-         self._compute(trainer, iter_num)
-
-     @override
-     @rank_zero_only
-     def on_predict_start(self, trainer: "Trainer", *_: Any) -> None:
-         self._start(trainer)
-
-     @override
-     @rank_zero_only
-     def on_predict_batch_end(
-         self,
-         trainer: "Trainer",
-         pl_module: "LightningModule",
-         outputs: Any,
-         batch: Any,
-         *_: Any,
-         **__: Any,
-     ) -> None:
-         iter_num = trainer.predict_loop.batch_progress.total.ready
-         self._update(trainer, pl_module, batch, iter_num)
-         self._compute(trainer, iter_num)
-
-
- def _plugin_to_compute_dtype(plugin: Union[FabricPrecision, Precision]) -> torch.dtype:
-     # TODO: integrate this into the precision plugins
-     if not isinstance(plugin, Precision):
-         return fabric_plugin_to_compute_dtype(plugin)
-     if isinstance(plugin, BitsandbytesPrecision):
-         return plugin.dtype
-     if isinstance(plugin, HalfPrecision):
-         return plugin._desired_input_dtype
-     if isinstance(plugin, MixedPrecision):
-         return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
-     if isinstance(plugin, DoublePrecision):
-         return torch.double
-     if isinstance(plugin, (XLAPrecision, DeepSpeedPrecision)):
-         return plugin._desired_dtype
-     if isinstance(plugin, TransformerEnginePrecision):
-         return torch.int8
-     if isinstance(plugin, FSDPPrecision):
-         return plugin.mixed_precision_config.reduce_dtype or torch.float32
-     if isinstance(plugin, Precision):
-         return torch.float32
-     raise NotImplementedError(plugin)
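
For reference, the deleted Throughput tracker could also be driven by hand, outside the callback. A minimal sketch against the 0.44.1 private module (the matmul stand-in for a training step and the fixed batch size of 32 are illustrative assumptions):

    import time

    import torch

    # Private module; exists only in nshtrainer <= 0.44.x.
    from nshtrainer.callbacks._throughput_monitor_callback import Throughput

    throughput = Throughput(window_size=10)
    t0 = time.perf_counter()
    for i in range(1, 101):
        torch.randn(512, 512) @ torch.randn(512, 512)  # stand-in for one training step
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # flush queued GPU work so the timestamp is honest
        # cumulative counters: time, batches, and samples must all increase monotonically
        throughput.update(time=time.perf_counter() - t0, batches=i, samples=i * 32)
        if i % 10 == 0:
            print(throughput.compute())  # rates appear once the 10-entry window is full

With world_size=1 only the device/* rates and the running totals are reported; the global *_per_sec keys are added only when world_size > 1.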
nshtrainer/callbacks/throughput_monitor.py (deleted)
@@ -1,58 +0,0 @@
- from __future__ import annotations
-
- import logging
- from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
-
- from typing_extensions import NotRequired, override
-
- from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
- from .base import CallbackConfigBase
-
- log = logging.getLogger(__name__)
-
-
- class ThroughputMonitorBatchStats(TypedDict):
-     batch_size: int
-     length: NotRequired[int | None]
-
-
- @runtime_checkable
- class SupportsThroughputMonitorModuleProtocol(Protocol):
-     def throughput_monitor_batch_stats(
-         self, batch: Any
-     ) -> ThroughputMonitorBatchStats: ...
-
-
- class ThroughputMonitor(_ThroughputMonitor):
-     def __init__(self, window_size: int = 100) -> None:
-         super().__init__(cast(Any, None), cast(Any, None), window_size=window_size)
-
-     @override
-     def setup(self, trainer, pl_module, stage):
-         if not isinstance(pl_module, SupportsThroughputMonitorModuleProtocol):
-             raise RuntimeError(
-                 "The model does not implement `throughput_monitor_batch_stats`. "
-                 "Please either implement this method, or do not use the `ThroughputMonitor` callback."
-             )
-
-         def batch_size_fn(batch):
-             return pl_module.throughput_monitor_batch_stats(batch)["batch_size"]
-
-         def length_fn(batch):
-             return pl_module.throughput_monitor_batch_stats(batch).get("length")
-
-         self.batch_size_fn = batch_size_fn
-         self.length_fn = length_fn
-
-         return super().setup(trainer, pl_module, stage)
-
-
- class ThroughputMonitorConfig(CallbackConfigBase):
-     name: Literal["throughput_monitor"] = "throughput_monitor"
-
-     window_size: int = 100
-     """Number of batches to use for a rolling average."""
-
-     @override
-     def create_callbacks(self, root_config):
-         yield ThroughputMonitor(window_size=self.window_size)
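
The wrapper above captures the contract the 0.44.1 callback imposed on the model: it refused to run unless the LightningModule implemented throughput_monitor_batch_stats. A minimal conforming module (the tensor-shaped batch is an assumption):

    from typing import Any

    import torch
    from lightning.pytorch import LightningModule


    class MyModule(LightningModule):
        def throughput_monitor_batch_stats(self, batch: Any):
            # batch_size is required; length (e.g. a token count) is optional
            # and feeds the items_per_sec family of metrics.
            return {"batch_size": batch.size(0), "length": int(batch.numel())}

The callback was then attached by listing ThroughputMonitorConfig (or the ThroughputMonitor callback directly) in the trainer configuration; both are removed in 1.0.0b10.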
nshtrainer/config/model/__init__.py (deleted)
@@ -1,41 +0,0 @@
- from __future__ import annotations
-
- __codegen__ = True
-
- from typing import TYPE_CHECKING
-
- # Config/alias imports
-
- if TYPE_CHECKING:
-     from nshtrainer.model import BaseConfig as BaseConfig
-     from nshtrainer.model import DirectoryConfig as DirectoryConfig
-     from nshtrainer.model import MetricConfig as MetricConfig
-     from nshtrainer.model import TrainerConfig as TrainerConfig
-     from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
-     from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
- else:
-
-     def __getattr__(name):
-         import importlib
-
-         if name in globals():
-             return globals()[name]
-         if name == "MetricConfig":
-             return importlib.import_module("nshtrainer.model").MetricConfig
-         if name == "TrainerConfig":
-             return importlib.import_module("nshtrainer.model").TrainerConfig
-         if name == "BaseConfig":
-             return importlib.import_module("nshtrainer.model").BaseConfig
-         if name == "EnvironmentConfig":
-             return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
-         if name == "DirectoryConfig":
-             return importlib.import_module("nshtrainer.model").DirectoryConfig
-         if name == "CallbackConfigBase":
-             return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
-         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
- # Submodule exports
- from . import base as base
- from . import config as config
- from . import mixins as mixins
nshtrainer/config/model/base/__init__.py (deleted)
@@ -1,25 +0,0 @@
- from __future__ import annotations
-
- __codegen__ = True
-
- from typing import TYPE_CHECKING
-
- # Config/alias imports
-
- if TYPE_CHECKING:
-     from nshtrainer.model.base import BaseConfig as BaseConfig
-     from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
- else:
-
-     def __getattr__(name):
-         import importlib
-
-         if name in globals():
-             return globals()[name]
-         if name == "BaseConfig":
-             return importlib.import_module("nshtrainer.model.base").BaseConfig
-         if name == "EnvironmentConfig":
-             return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
-         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
- # Submodule exports
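
Both deleted config shims are instances of the same codegen pattern: a module-level __getattr__ (PEP 562) that gives type checkers eager, statically resolvable imports while deferring the real import to the first runtime attribute access. The pattern in isolation, reduced to a single re-export:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen only by type checkers; no runtime import cost
        from nshtrainer.model.base import BaseConfig as BaseConfig
    else:

        def __getattr__(name):  # PEP 562: called when a module attribute is missing
            import importlib

            if name == "BaseConfig":
                return importlib.import_module("nshtrainer.model.base").BaseConfig
            raise AttributeError(f"module '{__name__}' has no attribute '{name}'")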