nshtrainer 0.44.0 (py3-none-any.whl) → 1.0.0b9 (py3-none-any.whl)
This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +51 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/_base.py +2 -1
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b9.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.0.dist-info/RECORD +0 -162
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/WHEEL +0 -0
--- a/nshtrainer/callbacks/_throughput_monitor_callback.py
+++ /dev/null
@@ -1,551 +0,0 @@
-# type: ignore
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import annotations
-
-import time
-from collections import deque
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Deque,
-    Dict,
-    List,
-    Optional,
-    TypeVar,
-    Union,
-)
-
-import torch
-from lightning.fabric.plugins import Precision as FabricPrecision
-from lightning.fabric.utilities.throughput import (
-    _plugin_to_compute_dtype as fabric_plugin_to_compute_dtype,
-)
-from lightning.fabric.utilities.throughput import get_available_flops
-from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.plugins import (
-    BitsandbytesPrecision,
-    DeepSpeedPrecision,
-    DoublePrecision,
-    FSDPPrecision,
-    HalfPrecision,
-    MixedPrecision,
-    Precision,
-    TransformerEnginePrecision,
-    XLAPrecision,
-)
-from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn
-from typing_extensions import override
-
-if TYPE_CHECKING:
-    from lightning.pytorch import LightningModule, Trainer
-
-
-_THROUGHPUT_METRICS = Dict[str, Union[int, float]]
-
-T = TypeVar("T", bound=float)
-
-
-class _MonotonicWindow(List[T]):
-    """Custom fixed size list that only supports right-append and ensures that all values increase monotonically."""
-
-    def __init__(self, maxlen: int) -> None:
-        super().__init__()
-        self.maxlen = maxlen
-
-    @property
-    def last(self) -> Optional[T]:
-        if len(self) > 0:
-            return self[-1]
-        return None
-
-    @override
-    def append(self, x: T) -> None:
-        last = self.last
-        if last is not None and last >= x:
-            rank_zero_warn(
-                f"Expected the value to increase, last: {last}, current: {x}"
-            )
-        list.append(self, x)
-        # truncate excess
-        if len(self) > self.maxlen:
-            del self[0]
-
-    @override
-    def __setitem__(self, key: Any, value: Any) -> None:
-        # assigning is not implemented since we don't use it. it could be by checking all previous values
-        raise NotImplementedError("__setitem__ is not supported")
-
-
-class Throughput:
-    """Computes throughput.
-
-    +------------------------+--------------------------------------------------------------------------------------+
-    | Key                    | Value                                                                                |
-    +========================+======================================================================================+
-    | batches_per_sec        | Rolling average (over ``window_size`` most recent updates) of the number of batches |
-    |                        | processed per second                                                                 |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | samples_per_sec        | Rolling average (over ``window_size`` most recent updates) of the number of samples |
-    |                        | processed per second                                                                 |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | items_per_sec          | Rolling average (over ``window_size`` most recent updates) of the number of items   |
-    |                        | processed per second                                                                 |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | flpps_per_sec          | Rolling average (over ``window_size`` most recent updates) of the number of flops   |
-    |                        | processed per second                                                                 |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | device/batches_per_sec | batches_per_sec divided by world size                                                |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | device/samples_per_sec | samples_per_sec divided by world size                                                |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | device/items_per_sec   | items_per_sec divided by world size. This may include padding depending on the data |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | device/flops_per_sec   | flops_per_sec divided by world size.                                                 |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | device/mfu             | device/flops_per_sec divided by world size.                                          |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | time                   | Total elapsed time                                                                   |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | batches                | Total batches seen                                                                   |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | samples                | Total samples seen                                                                   |
-    +------------------------+--------------------------------------------------------------------------------------+
-    | lengths                | Total items seen                                                                     |
-    +------------------------+--------------------------------------------------------------------------------------+
-
-    Example::
-
-        throughput = Throughput()
-        t0 = time()
-        for i in range(1000):
-            do_work()
-            if torch.cuda.is_available(): torch.cuda.synchronize()  # required or else time() won't be correct
-            throughput.update(time=time() - t0, samples=i)
-            if i % 10 == 0:
-                print(throughput.compute())
-
-    Notes:
-        - The implementation assumes that devices FLOPs are all the same as it normalizes by the world size and only
-          takes a single ``available_flops`` value.
-        - items_per_sec, flops_per_sec and MFU do not account for padding if present. We suggest using
-          samples_per_sec or batches_per_sec to measure throughput under this circumstance.
-
-    Args:
-        available_flops: Number of theoretical flops available for a single device.
-        world_size: Number of devices available across hosts. Global metrics are not included if the world size is 1.
-        window_size: Number of batches to use for a rolling average.
-        separator: Key separator to use when creating per-device and global metrics.
-
-    """
-
-    def __init__(
-        self,
-        available_flops: Optional[float] = None,
-        world_size: int = 1,
-        window_size: int = 100,
-        separator: str = "/",
-    ) -> None:
-        self.available_flops = available_flops
-        self.separator = separator
-        assert world_size > 0
-        self.world_size = world_size
-
-        # throughput is computed over a window of values. at least 2 is enforced since it looks at the difference
-        # between the first and last elements
-        assert window_size > 1
-        # custom class instead of `deque(maxlen=)` because it's easy for users to mess up their timer/counters and log
-        # values that do not increase monotonically. this class will raise an error if that happens.
-        self._time: _MonotonicWindow[float] = _MonotonicWindow(maxlen=window_size)
-        self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
-        self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
-        self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size)
-        self._flops: Deque[int] = deque(maxlen=window_size)
-
-    def update(
-        self,
-        *,
-        time: float,
-        batches: int,
-        samples: int,
-        lengths: Optional[int] = None,
-        flops: Optional[int] = None,
-    ) -> None:
-        """Update throughput metrics.
-
-        Args:
-            time: Total elapsed time in seconds. It should monotonically increase by the iteration time with each
-                call.
-            batches: Total batches seen per device. It should monotonically increase with each call.
-            samples: Total samples seen per device. It should monotonically increase by the batch size with each call.
-            lengths: Total length of the samples seen. It should monotonically increase by the lengths of a batch with
-                each call.
-            flops: Flops elapased per device since last ``update()`` call. You can easily compute this by using
-                :func:`measure_flops` and multiplying it by the number of batches that have been processed.
-                The value might be different in each device if the batch size is not the same.
-
-        """
-        self._time.append(time)
-        if samples < batches:
-            raise ValueError(
-                f"Expected samples ({samples}) to be greater or equal than batches ({batches})"
-            )
-        self._batches.append(batches)
-        self._samples.append(samples)
-        if lengths is not None:
-            if lengths < samples:
-                raise ValueError(
-                    f"Expected lengths ({lengths}) to be greater or equal than samples ({samples})"
-                )
-            self._lengths.append(lengths)
-            if len(self._samples) != len(self._lengths):
-                raise RuntimeError(
-                    f"If lengths are passed ({len(self._lengths)}), there needs to be the same number of samples"
-                    f" ({len(self._samples)})"
-                )
-        if flops is not None:
-            # sum of flops across ranks
-            self._flops.append(flops * self.world_size)
-
-    def compute(self) -> _THROUGHPUT_METRICS:
-        """Compute throughput metrics."""
-        metrics = {
-            "time": self._time[-1],
-            "batches": self._batches[-1],
-            "samples": self._samples[-1],
-        }
-        if self._lengths:
-            metrics["lengths"] = self._lengths[-1]
-
-        add_global_metrics = self.world_size > 1
-        # a different but valid design choice would be to still compute all these metrics even if the window of values
-        # has not been filled
-        if len(self._time) == self._time.maxlen:
-            elapsed_time = self._time[-1] - self._time[0]
-            elapsed_batches = self._batches[-1] - self._batches[0]
-            elapsed_samples = self._samples[-1] - self._samples[0]
-            # we are safe from ZeroDivisionError thanks to `_MonotonicWindow`
-            dev_samples_per_sec = elapsed_samples / elapsed_time
-            dev_batches_per_sec = elapsed_batches / elapsed_time
-            metrics.update(
-                {
-                    f"device{self.separator}batches_per_sec": elapsed_batches
-                    / elapsed_time,
-                    f"device{self.separator}samples_per_sec": dev_samples_per_sec,
-                }
-            )
-            if add_global_metrics:
-                samples_per_sec = dev_batches_per_sec * self.world_size
-                metrics.update(
-                    {
-                        "batches_per_sec": samples_per_sec,
-                        "samples_per_sec": dev_samples_per_sec * self.world_size,
-                    }
-                )
-
-            if len(self._lengths) == self._lengths.maxlen:
-                elapsed_lengths = self._lengths[-1] - self._lengths[0]
-                dev_items_per_sec = elapsed_lengths / elapsed_time
-                metrics[f"device{self.separator}items_per_sec"] = dev_items_per_sec
-                if add_global_metrics:
-                    items_per_sec = dev_items_per_sec * self.world_size
-                    metrics["items_per_sec"] = items_per_sec
-
-        if len(self._flops) == self._flops.maxlen:
-            elapsed_flops = sum(self._flops) - self._flops[0]
-            elapsed_time = self._time[-1] - self._time[0]
-            flops_per_sec = elapsed_flops / elapsed_time
-            dev_flops_per_sec = flops_per_sec / self.world_size
-            if add_global_metrics:
-                metrics["flops_per_sec"] = flops_per_sec
-            metrics[f"device{self.separator}flops_per_sec"] = dev_flops_per_sec
-            if self.available_flops:
-                metrics[f"device{self.separator}mfu"] = (
-                    dev_flops_per_sec / self.available_flops
-                )
-
-        return metrics
-
-    def reset(self) -> None:
-        self._time.clear()
-        self._batches.clear()
-        self._samples.clear()
-        self._lengths.clear()
-        self._flops.clear()
-
-
-class ThroughputMonitor(Callback):
-    r"""Computes and logs throughput with the :class:`~lightning.fabric.utilities.throughput.Throughput`
-
-    Example::
-
-        class MyModel(LightningModule):
-            def setup(self, stage):
-                with torch.device("meta"):
-                    model = MyModel()
-
-                    def sample_forward():
-                        batch = torch.randn(..., device="meta")
-                        return model(batch)
-
-                    self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)
-
-
-        logger = ...
-        throughput = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))
-        trainer = Trainer(max_steps=1000, log_every_n_steps=10, callbacks=throughput, logger=logger)
-        model = MyModel()
-        trainer.fit(model)
-
-    Notes:
-        - It assumes that the batch size is the same during all iterations.
-        - It will try to access a ``flops_per_batch`` attribute on your ``LightningModule`` on every iteration.
-          We suggest using the :func:`~lightning.fabric.utilities.throughput.measure_flops` function for this.
-          You might want to compute it differently each time based on your setup.
-
-    Args:
-        batch_size_fn: A function to compute the number of samples given a batch.
-        length_fn: A function to compute the number of items in a sample given a batch.
-        \**kwargs: See available parameters in
-            :class:`~lightning.fabric.utilities.throughput.Throughput`
-
-    """
-
-    def __init__(
-        self,
-        batch_size_fn: Callable[[Any], int],
-        length_fn: Optional[Callable[[Any], int | None]] = None,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__()
-        self.kwargs = kwargs
-        self.batch_size_fn = batch_size_fn
-        self.length_fn = length_fn
-        self.available_flops: Optional[int] = None
-        self._throughputs: Dict[RunningStage, Throughput] = {}
-        self._t0s: Dict[RunningStage, float] = {}
-        self._samples: Dict[RunningStage, int] = {}
-        self._lengths: Dict[RunningStage, int] = {}
-
-    @override
-    def setup(
-        self, trainer: "Trainer", pl_module: "LightningModule", stage: str
-    ) -> None:
-        dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
-        self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)
-
-        if stage == TrainerFn.FITTING and trainer.enable_validation:
-            # `fit` includes validation inside
-            throughput = Throughput(
-                available_flops=self.available_flops,
-                world_size=trainer.world_size,
-                **self.kwargs,
-            )
-            self._throughputs[RunningStage.VALIDATING] = throughput
-
-        throughput = Throughput(
-            available_flops=self.available_flops,
-            world_size=trainer.world_size,
-            **self.kwargs,
-        )
-        stage = trainer.state.stage
-        assert stage is not None
-        self._throughputs[stage] = throughput
-
-    def _start(self, trainer: "Trainer") -> None:
-        stage = trainer.state.stage
-        assert stage is not None
-        self._throughputs[stage].reset()
-        self._samples[stage] = 0
-        self._lengths[stage] = 0
-        self._t0s[stage] = time.perf_counter()
-
-    def _update(
-        self,
-        trainer: "Trainer",
-        pl_module: "LightningModule",
-        batch: Any,
-        iter_num: int,
-    ) -> None:
-        stage = trainer.state.stage
-        assert stage is not None
-        throughput = self._throughputs[stage]
-
-        if trainer.strategy.root_device.type == "cuda":
-            # required or else perf_counter() won't be correct
-            torch.cuda.synchronize()
-
-        elapsed = time.perf_counter() - self._t0s[stage]
-        if self.length_fn is not None:
-            with torch.inference_mode():
-                if (length := self.length_fn(batch)) is not None:
-                    self._lengths[stage] += length
-
-        if hasattr(pl_module, "flops_per_batch"):
-            flops_per_batch = pl_module.flops_per_batch
-        else:
-            rank_zero_warn(
-                "When using the `ThroughputMonitor`, you need to define a `flops_per_batch` attribute or property"
-                f" in {type(pl_module).__name__} to compute the FLOPs."
-            )
-            flops_per_batch = None
-
-        with torch.inference_mode():
-            self._samples[stage] += self.batch_size_fn(batch)
-
-        throughput.update(
-            time=elapsed,
-            batches=iter_num,
-            samples=self._samples[stage],
-            lengths=None if self.length_fn is None else self._lengths[stage],
-            flops=flops_per_batch,
-        )
-
-    def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
-        if not trainer._logger_connector.should_update_logs:
-            return
-        stage = trainer.state.stage
-        assert stage is not None
-        throughput = self._throughputs[stage]
-        metrics = throughput.compute()
-        # prefix with the stage to avoid collisions
-        metrics = {
-            f"{stage.value}{throughput.separator}{k}": v for k, v in metrics.items()
-        }
-        trainer._logger_connector.log_metrics(metrics, step=iter_num)  # type: ignore[arg-type]
-
-    @override
-    @rank_zero_only
-    def on_train_start(self, trainer: "Trainer", *_: Any) -> None:
-        self._start(trainer)
-
-    @override
-    @rank_zero_only
-    def on_train_batch_end(
-        self,
-        trainer: "Trainer",
-        pl_module: "LightningModule",
-        outputs: Any,
-        batch: Any,
-        *_: Any,
-    ) -> None:
-        self._update(trainer, pl_module, batch, trainer.fit_loop.total_batch_idx + 1)
-        # log only when gradient accumulation is over. this ensures that we only measure when the effective batch has
-        # finished and the `optimizer.step()` time is included
-        if not trainer.fit_loop._should_accumulate():
-            self._compute(trainer)
-
-    @override
-    @rank_zero_only
-    def on_validation_start(self, trainer: "Trainer", *_: Any) -> None:
-        if trainer.sanity_checking:
-            return
-        self._start(trainer)
-
-    @override
-    @rank_zero_only
-    def on_validation_batch_end(
-        self,
-        trainer: "Trainer",
-        pl_module: "LightningModule",
-        outputs: Any,
-        batch: Any,
-        *_: Any,
-        **__: Any,
-    ) -> None:
-        if trainer.sanity_checking:
-            return
-        iter_num = trainer._evaluation_loop.batch_progress.total.ready
-        self._update(trainer, pl_module, batch, iter_num)
-        self._compute(trainer, iter_num)
-
-    @override
-    @rank_zero_only
-    def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
-        if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
-            return
-        # add the validation time to the training time before continuing to avoid sinking the training throughput
-        training_finished = self._t0s[RunningStage.TRAINING] + sum(
-            self._throughputs[RunningStage.TRAINING]._time
-        )
-        time_between_train_and_val = (
-            self._t0s[RunningStage.VALIDATING] - training_finished
-        )
-        val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
-        self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
-
-    @override
-    @rank_zero_only
-    def on_test_start(self, trainer: "Trainer", *_: Any) -> None:
-        self._start(trainer)
-
-    @override
-    @rank_zero_only
-    def on_test_batch_end(
-        self,
-        trainer: "Trainer",
-        pl_module: "LightningModule",
-        outputs: Any,
-        batch: Any,
-        *_: Any,
-        **__: Any,
-    ) -> None:
-        iter_num = trainer._evaluation_loop.batch_progress.total.ready
-        self._update(trainer, pl_module, batch, iter_num)
-        self._compute(trainer, iter_num)
-
-    @override
-    @rank_zero_only
-    def on_predict_start(self, trainer: "Trainer", *_: Any) -> None:
-        self._start(trainer)
-
-    @override
-    @rank_zero_only
-    def on_predict_batch_end(
-        self,
-        trainer: "Trainer",
-        pl_module: "LightningModule",
-        outputs: Any,
-        batch: Any,
-        *_: Any,
-        **__: Any,
-    ) -> None:
-        iter_num = trainer.predict_loop.batch_progress.total.ready
-        self._update(trainer, pl_module, batch, iter_num)
-        self._compute(trainer, iter_num)
-
-
-def _plugin_to_compute_dtype(plugin: Union[FabricPrecision, Precision]) -> torch.dtype:
-    # TODO: integrate this into the precision plugins
-    if not isinstance(plugin, Precision):
-        return fabric_plugin_to_compute_dtype(plugin)
-    if isinstance(plugin, BitsandbytesPrecision):
-        return plugin.dtype
-    if isinstance(plugin, HalfPrecision):
-        return plugin._desired_input_dtype
-    if isinstance(plugin, MixedPrecision):
-        return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
-    if isinstance(plugin, DoublePrecision):
-        return torch.double
-    if isinstance(plugin, (XLAPrecision, DeepSpeedPrecision)):
-        return plugin._desired_dtype
-    if isinstance(plugin, TransformerEnginePrecision):
-        return torch.int8
-    if isinstance(plugin, FSDPPrecision):
-        return plugin.mixed_precision_config.reduce_dtype or torch.float32
-    if isinstance(plugin, Precision):
-        return torch.float32
-    raise NotImplementedError(plugin)
--- a/nshtrainer/callbacks/throughput_monitor.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from __future__ import annotations
-
-import logging
-from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
-
-from typing_extensions import NotRequired, override
-
-from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
-from .base import CallbackConfigBase
-
-log = logging.getLogger(__name__)
-
-
-class ThroughputMonitorBatchStats(TypedDict):
-    batch_size: int
-    length: NotRequired[int | None]
-
-
-@runtime_checkable
-class SupportsThroughputMonitorModuleProtocol(Protocol):
-    def throughput_monitor_batch_stats(
-        self, batch: Any
-    ) -> ThroughputMonitorBatchStats: ...
-
-
-class ThroughputMonitor(_ThroughputMonitor):
-    def __init__(self, window_size: int = 100) -> None:
-        super().__init__(cast(Any, None), cast(Any, None), window_size=window_size)
-
-    @override
-    def setup(self, trainer, pl_module, stage):
-        if not isinstance(pl_module, SupportsThroughputMonitorModuleProtocol):
-            raise RuntimeError(
-                "The model does not implement `throughput_monitor_batch_stats`. "
-                "Please either implement this method, or do not use the `ThroughputMonitor` callback."
-            )
-
-        def batch_size_fn(batch):
-            return pl_module.throughput_monitor_batch_stats(batch)["batch_size"]
-
-        def length_fn(batch):
-            return pl_module.throughput_monitor_batch_stats(batch).get("length")
-
-        self.batch_size_fn = batch_size_fn
-        self.length_fn = length_fn
-
-        return super().setup(trainer, pl_module, stage)
-
-
-class ThroughputMonitorConfig(CallbackConfigBase):
-    name: Literal["throughput_monitor"] = "throughput_monitor"
-
-    window_size: int = 100
-    """Number of batches to use for a rolling average."""
-
-    @override
-    def create_callbacks(self, root_config):
-        yield ThroughputMonitor(window_size=self.window_size)
--- a/nshtrainer/config/model/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from typing import TYPE_CHECKING
-
-# Config/alias imports
-
-if TYPE_CHECKING:
-    from nshtrainer.model import BaseConfig as BaseConfig
-    from nshtrainer.model import DirectoryConfig as DirectoryConfig
-    from nshtrainer.model import MetricConfig as MetricConfig
-    from nshtrainer.model import TrainerConfig as TrainerConfig
-    from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
-    from nshtrainer.model.config import CallbackConfigBase as CallbackConfigBase
-else:
-
-    def __getattr__(name):
-        import importlib
-
-        if name in globals():
-            return globals()[name]
-        if name == "MetricConfig":
-            return importlib.import_module("nshtrainer.model").MetricConfig
-        if name == "TrainerConfig":
-            return importlib.import_module("nshtrainer.model").TrainerConfig
-        if name == "BaseConfig":
-            return importlib.import_module("nshtrainer.model").BaseConfig
-        if name == "EnvironmentConfig":
-            return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
-        if name == "DirectoryConfig":
-            return importlib.import_module("nshtrainer.model").DirectoryConfig
-        if name == "CallbackConfigBase":
-            return importlib.import_module("nshtrainer.model.config").CallbackConfigBase
-        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
-# Submodule exports
-from . import base as base
-from . import config as config
-from . import mixins as mixins
--- a/nshtrainer/config/model/base/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from typing import TYPE_CHECKING
-
-# Config/alias imports
-
-if TYPE_CHECKING:
-    from nshtrainer.model.base import BaseConfig as BaseConfig
-    from nshtrainer.model.base import EnvironmentConfig as EnvironmentConfig
-else:
-
-    def __getattr__(name):
-        import importlib
-
-        if name in globals():
-            return globals()[name]
-        if name == "BaseConfig":
-            return importlib.import_module("nshtrainer.model.base").BaseConfig
-        if name == "EnvironmentConfig":
-            return importlib.import_module("nshtrainer.model.base").EnvironmentConfig
-        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-# Submodule exports