fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
|
@@ -0,0 +1,574 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
import torch
|
|
4
|
+
from torch.utils._python_dispatch import TorchDispatchMode
|
|
5
|
+
from torch.utils.flop_counter import flop_registry
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from pprint import pformat
|
|
8
|
+
from typing import Any, Protocol, TYPE_CHECKING, cast
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
import lightning as L
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
15
|
+
|
|
16
|
+
from fkat.utils.logging import rank0_logger
|
|
17
|
+
from fkat.pytorch.loggers import LightningLogger
|
|
18
|
+
from fkat.pytorch.callbacks.loggers import CallbackLogger
|
|
19
|
+
from fkat.pytorch.schedule import Schedule, Every
|
|
20
|
+
|
|
21
|
+
# Module-level logger built with `rank0_logger` from fkat.utils.logging.
# NOTE(review): judging by the name it presumably only emits on global rank 0 — confirm in fkat.utils.logging.
logger = rank0_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FlopRecipe(Protocol):
    """Structural interface for estimators of the Floating Point Operations (Flop) spent on one training batch.

    Any object exposing a compatible ``get_batch_flop`` method satisfies this protocol.
    """

    def get_batch_flop(self, pl_module: "L.LightningModule") -> int:
        """Return the estimated Flop of a single batch for ``pl_module``."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GPTModel(FlopRecipe):
    """``GPTModel`` FLOP recipe provides a formula-based estimation of Floating Point Operations (Flop)
    during a training batch for MegatronGPTModels.

    Model hyper-parameters are read from ``pl_module.cfg`` (see :meth:`get_batch_flop`) and fed into a
    closed-form transformer FLOP formula (see :meth:`_calculate_flop`).
    """

    @staticmethod
    def _calculate_flop(
        batch_size: int,
        seq_length: int,
        num_layers: int,
        kv_channels: int | None,
        num_attention_heads: int,
        num_query_groups: int,
        hidden_size: int,
        ffn_hidden_size: int,
        group_query_attention: bool,
        swiglu: bool,
        padded_vocab_size: int,
        num_moe_layers: int,
        moe_ffn_hidden_size: int,
        num_experts: int | None,
        moe_router_topk: int,
        consider_activation_recompute: bool = False,
    ) -> int:
        """Estimate the Flop of one training batch of a GPT-style transformer.

        Args:
            batch_size: Number of sequences in the batch.
            seq_length: Number of tokens per sequence.
            num_layers: Total number of transformer layers (dense and MoE).
            kv_channels: Per-head projection size; ``None`` falls back to ``hidden_size // num_attention_heads``.
            num_attention_heads: Number of query attention heads.
            num_query_groups: Number of key/value groups. Only used when ``group_query_attention`` is true,
                otherwise it is replaced by ``num_attention_heads``.
            hidden_size: Model hidden size.
            ffn_hidden_size: Hidden size of the dense MLP.
            group_query_attention: Whether grouped-query attention is used.
            swiglu: Whether the MLP is gated (multiplies MLP GEMM cost by 3/2).
            padded_vocab_size: Vocabulary size of the output logits layer.
            num_moe_layers: How many of the ``num_layers`` layers are Mixture-of-Experts layers.
            moe_ffn_hidden_size: Hidden size of each expert MLP.
            num_experts: Number of experts. ``None`` counts one MLP per token; any other value (including 0)
                counts ``moe_router_topk`` expert MLPs per token in MoE layers.
            moe_router_topk: Number of experts each token is routed to.
            consider_activation_recompute: Count one extra forward pass for activation recomputation.

        Returns:
            int: Estimated Flop for the batch, truncated to an integer.
        """
        if kv_channels is None:
            kv_channels = hidden_size // num_attention_heads

        params = {
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_layers": num_layers,
            "kv_channels": kv_channels,
            "num_attention_heads": num_attention_heads,
            "num_query_groups": num_query_groups,
            "hidden_size": hidden_size,
            "ffn_hidden_size": ffn_hidden_size,
            "group_query_attention": group_query_attention,
            "swiglu": swiglu,
            "padded_vocab_size": padded_vocab_size,
            "num_moe_layers": num_moe_layers,
            "moe_ffn_hidden_size": moe_ffn_hidden_size,
            "num_experts": num_experts,
            "moe_router_topk": moe_router_topk,
            "consider_activation_recompute": consider_activation_recompute,
        }

        logger.debug("Called _calculate_flop with parameters:\n%s", pformat(params))

        # Attention projection size.
        query_projection_size = kv_channels * num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / hidden_size
        # Group Query Attention.
        if not group_query_attention:
            num_query_groups = num_attention_heads
        # MoE.
        num_experts_routed_to = 1 if num_experts is None else moe_router_topk
        gated_linear_multiplier = 3 / 2 if swiglu else 1

        # The 12x term below comes from the following factors; for more details, see
        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
        #       When the HFU is considered, additional forward pass needs to be factored in.
        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
        #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
        #       in MLP layer).
        # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
        compute_factor = 4 if consider_activation_recompute else 3
        expansion_factor = compute_factor * 2 * 2
        # Fraction of MoE layers. With no layers the transformer term below is multiplied by
        # num_layers == 0 anyway, so only guard the division instead of raising ZeroDivisionError.
        moe_ratio = num_moe_layers / num_layers if num_layers else 0.0
        dense_ratio = 1.0 - moe_ratio
        return int(
            expansion_factor
            * batch_size
            * seq_length
            * num_layers
            * hidden_size
            * hidden_size
            * (
                # Attention.
                (
                    (1 + (num_query_groups / num_attention_heads) + (seq_length / hidden_size))
                    * query_projection_to_hidden_size_ratio
                )
                # MLP.
                # Interleave
                + dense_ratio * ((ffn_hidden_size / hidden_size) * gated_linear_multiplier)
                + moe_ratio * ((moe_ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier)
            )
            +
            # Logits
            6 * batch_size * seq_length * hidden_size * padded_vocab_size
        )

    def get_batch_flop(self, pl_module: "L.LightningModule") -> int:
        """Resolves Floating Point Operations (Flop) for the given batch size for GPT models.

        Hyper-parameters are read from ``pl_module.cfg``. Any mapping is accepted: a plain ``dict``
        as well as mapping-like config objects such as OmegaConf ``DictConfig``.

        A key whose value is ``None`` is treated as missing. Mandatory keys then raise ``KeyError``;
        optional keys fall back to their defaults.

        Raises:
            KeyError: If a mandatory key is missing from (or ``None`` in) the module config.
            TypeError: If ``pl_module`` has no mapping-like ``cfg`` attribute.
        """
        from collections.abc import Mapping

        cfg = getattr(pl_module, "cfg", None)
        if not isinstance(cfg, Mapping):
            raise TypeError(f"{self.__class__} does not support calculating flop for {pl_module.__class__}")
        config = cast(Mapping[str, Any], cfg)

        # reject calculating Flop per batch if a mandatory key is missing from module config
        _mandatory_config_keys = [
            "global_batch_size",
            "encoder_seq_length",
            "num_layers",
            "hidden_size",
            "num_attention_heads",
            "ffn_hidden_size",
        ]
        for key in _mandatory_config_keys:
            if config.get(key) is None:
                raise KeyError(f"Key {key} not presented in module config {config}")

        def _get(key: str, default: Any) -> Any:
            # An explicit ``None`` (e.g. a null entry in a YAML config) means "use the default".
            value = config.get(key)
            return default if value is None else value

        hidden_size = config["hidden_size"]
        num_attention_heads = config["num_attention_heads"]
        return GPTModel._calculate_flop(
            batch_size=config["global_batch_size"],
            seq_length=config["encoder_seq_length"],
            num_layers=config["num_layers"],
            kv_channels=_get("kv_channels", hidden_size // num_attention_heads),
            num_attention_heads=num_attention_heads,
            # Default to num_attention_heads (i.e. no grouping).
            num_query_groups=_get("num_query_groups", num_attention_heads),
            hidden_size=hidden_size,
            ffn_hidden_size=config["ffn_hidden_size"],
            group_query_attention=_get("group_query_attention", False),
            swiglu=_get("swiglu", False),
            padded_vocab_size=_get("padded_vocab_size", _get("vocab_size", 50257)),
            num_moe_layers=_get("num_moe_layers", 0),
            moe_ffn_hidden_size=_get("moe_ffn_hidden_size", 0),
            num_experts=_get("num_experts", 0),
            moe_router_topk=_get("moe_router_topk", 1),
            consider_activation_recompute=_get("consider_activation_recompute", False),
        )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class Trace(TorchDispatchMode, FlopRecipe):
    """``Trace`` FLOP recipe is a lightweight counter mode that only counts global flops
    without using ModuleTracker to track module hierarchy.

    Every ATen operator dispatched while the mode is active is executed normally. If its overload
    packet has a formula in ``torch.utils.flop_counter.flop_registry``, its Flop are added to the
    global count and the op is counted as "tracked"; otherwise it is counted as "untracked".
    Counters are reset every time the mode is entered.

    Example usage

    .. code-block:: python

        with Trace() as flop_counter:
            mod.sum().backward()
        total_flop = flop_counter.get_batch_flop(mod)
    """

    def __init__(self) -> None:
        super().__init__()

        # Shallow copy of torch's registry: entries added to this dict do not touch the global one.
        self.flop_registry = {
            **flop_registry,
        }

        self._reset_flops_count()

    def _reset_flops_count(self) -> None:
        """Zero the Flop total and the tracked/untracked operation counters."""
        self.flop_counts = {"Global": 0, "Tracked": 0, "Untracked": 0}

    def get_batch_flop(self, pl_module: "L.LightningModule") -> int:
        """Return the Flop counted since the mode was last entered; ``pl_module`` is ignored."""
        return self.flop_counts.get("Global", 0)

    def get_tracked_operations_count(self) -> int:
        """Return how many dispatched operations had a Flop formula."""
        return self.flop_counts.get("Tracked", 0)

    def get_untracked_operations_count(self) -> int:
        """Return how many dispatched operations had no Flop formula (their cost is not counted)."""
        return self.flop_counts.get("Untracked", 0)

    def _count_flops(self, func_packet: Any, out: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        """Account for one executed operator and pass its output through unchanged."""
        if func_packet in self.flop_registry:
            flop_count_func = self.flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out_val=out)
            self.flop_counts["Global"] += flop_count
            self.flop_counts["Tracked"] += 1
        else:
            self.flop_counts["Untracked"] += 1

        return out

    def __torch_dispatch__(
        self, func: Any, types: Any, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
    ) -> Any:
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        # The registry is keyed by overload packet (e.g. aten.mm), not by specific overload.
        return self._count_flops(func._overloadpacket, out, args, kwargs)

    def __enter__(self) -> "Trace":
        self._reset_flops_count()
        super().__enter__()
        return self

    def __exit__(self, *args: Any) -> None:
        super().__exit__(*args)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class Accelerator:
|
|
228
|
+
def __init__(
|
|
229
|
+
self,
|
|
230
|
+
name: str,
|
|
231
|
+
fp32: int,
|
|
232
|
+
fp16: int | None = None,
|
|
233
|
+
bf16: int | None = None,
|
|
234
|
+
int8: int | None = None,
|
|
235
|
+
int4: int | None = None,
|
|
236
|
+
) -> None:
|
|
237
|
+
self.name = name
|
|
238
|
+
self.fp32 = fp32
|
|
239
|
+
self.fp16 = fp16 or fp32 * 2
|
|
240
|
+
self.bf16 = bf16 or self.fp16
|
|
241
|
+
self.int8 = int8 or self.fp16 * 2
|
|
242
|
+
self.int4 = int4 or self.int8 * 2
|
|
243
|
+
|
|
244
|
+
def flops(self, dtype: torch.dtype) -> int:
|
|
245
|
+
if dtype == torch.float32 and self.fp32:
|
|
246
|
+
return self.fp32
|
|
247
|
+
if dtype == torch.float16 and self.fp16:
|
|
248
|
+
return self.fp16
|
|
249
|
+
if dtype == torch.bfloat16 and self.bf16:
|
|
250
|
+
return self.bf16
|
|
251
|
+
if dtype == torch.int8 and self.int8:
|
|
252
|
+
return self.int8
|
|
253
|
+
raise ValueError(f"No {dtype} flops details for {self.name}")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# Multiplier turning a TFLOPS figure into raw operations per second.
TFLOPS = 10**12

# Known accelerators, keyed by the substring `get_flops` looks for in the CUDA device name.
# Only ``fp32`` is passed, so `Accelerator` derives fp16 = bf16 = 2 * fp32 and int8 = 4 * fp32 for every entry.
# NOTE(review): several of these values look like published 16-bit *tensor-core* peaks rather than FP32
# peaks (989 TFLOPS is the commonly quoted dense BF16 figure for H100/H200 SXM, 312 TFLOPS for A100).
# If so, the derived fp16/bf16 peaks are 2x too high and the reported MFU is 2x too low.
# V100 also has no native bf16 support. Confirm every figure against the vendor datasheets.
V100 = Accelerator("V100", fp32=130 * TFLOPS)
H100 = Accelerator("H100", fp32=989 * TFLOPS)
H200 = Accelerator("H200", fp32=989 * TFLOPS)
A100 = Accelerator("A100", fp32=312 * TFLOPS)
A10G = Accelerator("A10G", fp32=35 * TFLOPS)
A10 = Accelerator("A10", fp32=int(31.2 * TFLOPS))
L40S = Accelerator("L40S", fp32=int(91.6 * TFLOPS))
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def get_flops(dtype: "torch.dtype", device: "torch.device") -> int:
    """Look up the peak operations per second of ``device`` for ``dtype``.

    The CUDA device name is matched by substring against the known accelerators.

    Raises:
        ValueError: If ``device`` is not a CUDA device, its name matches no known accelerator,
            or the matched accelerator has no figure for ``dtype``.
    """
    gpu_name = torch.cuda.get_device_name(device) if device.type == "cuda" else ""
    if gpu_name:
        # Order matters: names containing "A100" or "A10G" also contain "A10",
        # so the longer names must be tried first.
        known = (
            ("V100", V100),
            ("H100", H100),
            ("H200", H200),
            ("A100", A100),
            ("A10G", A10G),
            ("A10", A10),
            ("L40S", L40S),
        )
        for marker, accelerator in known:
            if marker in gpu_name:
                return accelerator.flops(dtype)

    raise ValueError(f"No flops details for {device} with name {gpu_name}")
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def dtype(trainer: "L.Trainer") -> "torch.dtype":
    """Infer the floating-point ``torch.dtype`` from the trainer's precision setting.

    Matches on the prefix of ``str(trainer.precision)``. Lightning 2.x flags such as
    ``"bf16-mixed"``, ``"32-true"`` or ``"64-true"`` are accepted, as are bare values such as
    ``16`` or ``"bf16"``.

    Raises:
        ValueError: If no dtype corresponds to the precision (e.g. ``"transformer-engine"``).
    """
    precision = str(trainer.precision)
    # "bf16..." does not start with "16", so the prefixes below are unambiguous.
    for prefix, inferred in (
        ("32", torch.float32),
        ("16", torch.float16),
        ("bf16", torch.bfloat16),
        ("64", torch.float64),
    ):
        if precision.startswith(prefix):
            return inferred
    raise ValueError(f"Can't infer dtype for {trainer.precision}")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class Flops(L.Callback):
    """
    A PyTorch Lightning callback that measures and logs floating-point operations (FLOPs) and
    Model FLOP Utilization (MFU) during training, validation, testing, and prediction.

    This callback helps to monitor the computational efficiency of models by measuring:
    - Total machine FLOPs available
    - Per-batch FLOPs used by the model
    - Model FLOP Utilization (MFU), i.e., how efficiently the model uses the available compute
    - Batch throughput (batches per second)

    It supports two methods for estimating FLOPs:
    1. Tracing-based estimation via a `Trace` context manager. Only the first non-sanity-check batch
       is traced; its count is summed across ranks and reused for every later batch.
    2. Formula-based estimation using a predefined GPTModel FLOP calculator.

    Metrics are logged from global rank 0, on batches selected by ``schedule``, through
    ``CallbackLogger``, and include:

    * `mfu`: Model FLOP Utilization (traced)
    * `actual_batches_per_sec`: Measured throughput
    * `max_batches_per_sec`: Theoretical max throughput (omitted when the traced FLOP count is 0)
    * `batch_flops`: Traced FLOPs of one batch, summed across ranks
    * `total_flops`: Peak FLOPs/s of all devices (per-device peak x num_nodes x num_devices)
    * `batch_flops_tracked_operations`: Number of operations the tracer had a FLOP formula for
    * `batch_flops_untracked_operations`: Number of operations not accounted for by the tracer
    * `batch_flops_from_formula`: FLOPs estimated via formula (if available)
    * `mfu_from_formula`: MFU based on formula-based estimation (if available)

    Notes:
    * Batch time is the wall-clock time between the batch start and end hooks. No device
      synchronization is performed here.
    * If the schedule selects the traced batch, its throughput sample includes the tracing overhead.

    Args:
        schedule (Optional[Schedule]): Controls when logging occurs. Defaults to ``Every(n_batches=5)``.
            - FLOPs are always traced once, on the first batch.

    Example:
        >>> trainer = L.Trainer(callbacks=[Flops(schedule=Every(n_batches=10))])
    """

    def __init__(self, schedule: Schedule | None = None, *args: Any, **kwargs: Any) -> None:
        """Measures the total floating-point operations per second and MFU.

        Args:
            schedule (Optional[Schedule]): Controls when logging occurs. Defaults to ``Every(n_batches=5)``.
                - FLOPs are always traced once, on the first batch.
            *args: Not used.
            **kwargs: Stored on ``self.kwargs``; not otherwise used by this callback.

        Returns:
            None
        """
        if args or kwargs:
            # Extra arguments (e.g. ``log_every_n_batches``) have no effect; pass a ``schedule`` instead.
            logger.warning("Flops ignores unsupported arguments: args=%s, kwargs=%s", args, sorted(kwargs))
        self.schedule = schedule or Every(n_batches=5)
        self.kwargs = kwargs
        self.trace_flops_recipe: Trace | None = None
        self.gpt_flops_recipe = GPTModel()  # @TODO(hanfange) - make it more configurable
        self.total_flops = self.batch_idx = 0
        self.mfu_from_formula: float | None = None
        self.batch_flops_from_formula: int | None = None
        self.batch_flops = self.mfu = torch.empty(0)
        self.operations_tracked = self.operations_untracked = torch.empty(0)
        self.start_time: datetime | None = None
        self.is_first_batch = True

        # True between a `_start` that decided to time the batch and the matching `_stop`.
        self._timer_active = False
        # True while `self.trace_flops_recipe` is entered (pushed as a dispatch mode).
        self._trace_active = False
        self._cb_logger: LightningLogger | None = None

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        """Called when fit, validate, test, predict, or tune begins.

        Resolves the peak FLOPs/s of the whole job and allocates the int64 counters on the module's
        device so they can be reduced across ranks.
        """
        self._cb_logger = CallbackLogger(trainer)
        self.total_flops = get_flops(dtype(trainer), pl_module.device) * trainer.num_nodes * trainer.num_devices
        self.batch_flops = torch.tensor(0, dtype=torch.int64, device=pl_module.device)
        self.operations_tracked = torch.tensor(0, dtype=torch.int64, device=pl_module.device)
        self.operations_untracked = torch.tensor(0, dtype=torch.int64, device=pl_module.device)

    def _should_recalulate_batch_flops(self, trainer: "L.Trainer") -> bool:
        """
        _should_recalulate_batch_flops decides whether to recalculate the value of self.batch_flops.
        Recalculating batch_flops needs to enter flops counter mode and might cause potential vRAM leakage.

        Only calculate batch flops for the first batch.

        Returns: bool
        """
        return self.is_first_batch

    def _should_report_batch_throughput_and_mfu(self, trainer: "L.Trainer") -> bool:
        """
        _should_report_batch_throughput_and_mfu decides whether to publish batch throughput and mfu metrics.

        Delegates to ``self.schedule`` using the current batch index and global step.

        Returns: bool
        """
        return self.schedule.check(stage="train", batch_idx=self.batch_idx, step=trainer.global_step, trainer=trainer)

    def _exit_trace(self) -> None:
        """Leave the FLOP-counting dispatch mode if this callback entered it and has not left it yet."""
        if self._trace_active and self.trace_flops_recipe is not None:
            self._trace_active = False
            self.trace_flops_recipe.__exit__(None, None, None)

    def _start(self, trainer: "L.Trainer", batch_idx: int) -> None:
        """Begin tracing (first batch only) and, if scheduled, start timing the batch."""
        self.batch_idx = batch_idx
        if trainer.sanity_checking:
            return

        if self._should_recalulate_batch_flops(trainer):  # calculate number of flops
            # A previously traced batch may never have reached its end hook (e.g. the loop stopped
            # early). Leave that stale mode first so dispatch modes do not pile up.
            self._exit_trace()
            self.trace_flops_recipe = Trace()
            self.trace_flops_recipe.__enter__()
            self._trace_active = True

        # Decide once per batch; `_stop` acts on this decision instead of re-asking the schedule,
        # whose answer may change once `global_step` advances during the batch.
        self._timer_active = self._should_report_batch_throughput_and_mfu(trainer)
        if self._timer_active:  # report MFU every n batches
            self.start_time = datetime.now()

    def _stop(self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any) -> None:
        """Finish tracing (first batch only) and, if this batch was timed, report metrics from rank 0."""
        if trainer.sanity_checking:
            return

        if self._should_recalulate_batch_flops(trainer):  # only calculate number of flops once as it is a constant
            self._exit_trace()
            recipe = cast(Trace, self.trace_flops_recipe)
            self.batch_flops.fill_(recipe.get_batch_flop(pl_module))
            self.operations_tracked.fill_(recipe.get_tracked_operations_count())
            self.operations_untracked.fill_(recipe.get_untracked_operations_count())
            # Every rank reaches this point on its first batch, so these collectives line up.
            # `Strategy.reduce` returns the reduced tensor and is not guaranteed to work in place.
            self.batch_flops = trainer.strategy.reduce(self.batch_flops, reduce_op="sum")
            self.operations_tracked = trainer.strategy.reduce(self.operations_tracked, reduce_op="sum")
            self.operations_untracked = trainer.strategy.reduce(self.operations_untracked, reduce_op="sum")

        # Consume the timer on every rank so a stale start time is never reused by a later batch.
        timed = self._timer_active
        self._timer_active = False
        # main node emits metrics for every n batches
        if timed and trainer.global_rank == 0:
            self._report(trainer, pl_module)

        self.is_first_batch = False

    def _report(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Compute throughput and MFU for the batch that just finished and log them."""
        assert self._cb_logger
        assert self.start_time

        now = datetime.now()
        elapsed = (now - self.start_time).total_seconds()
        if elapsed <= 0:
            # Wall-clock resolution or adjustments can yield a non-positive interval; skip the sample.
            logger.debug("Skipping FLOP metrics: non-positive batch duration %s s", elapsed)
            return
        actual_batches_per_sec = 1 / elapsed

        batch_flops = self.batch_flops.item()
        # Same value as actual / (total_flops / batch_flops), but cannot divide by a zero FLOP count.
        mfu = actual_batches_per_sec * batch_flops / self.total_flops

        metrics = {
            "mfu": mfu,
            "actual_batches_per_sec": actual_batches_per_sec,
            "batch_flops": batch_flops,
            "total_flops": self.total_flops,
            "batch_flops_tracked_operations": self.operations_tracked.item(),
            "batch_flops_untracked_operations": self.operations_untracked.item(),
        }
        if batch_flops > 0:
            metrics["max_batches_per_sec"] = self.total_flops / batch_flops

        # attempt to calculate batch flop using formula-based approach, rank-zero only
        try:
            self.batch_flops_from_formula = self.gpt_flops_recipe.get_batch_flop(pl_module)
            max_batches_per_sec_from_formula = self.total_flops / self.batch_flops_from_formula
            self.mfu_from_formula = actual_batches_per_sec / max_batches_per_sec_from_formula

            # emit mfu calculated from formula, if applicable
            metrics.update(
                {
                    "mfu_from_formula": self.mfu_from_formula,
                    "batch_flops_from_formula": self.batch_flops_from_formula,
                }
            )
        except Exception as e:
            # Deliberately best effort: the formula only applies to modules with a suitable `cfg`.
            logger.debug("Could not calculate FLOP using formula: %s", e)

        self._cb_logger.log_batch(metrics=metrics, timestamp=int(now.timestamp() * 1e3), step=trainer.global_step)

    @override
    def on_train_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """
        Called at the beginning of each training batch.

        If conditions are met, it:
        - Begins FLOP tracing using a `Trace` context manager (first batch only).
        - Records the start time to later compute batch throughput (when the schedule fires).

        Tracing and timing are skipped during sanity checks.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The model being trained.
            batch (Any): The current batch of data.
            batch_idx (int): Index of the current batch.
        """
        self._start(trainer, batch_idx)

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
    ) -> None:
        """
        Called at the end of each training batch.

        If conditions are met, it:
        - Ends FLOP tracing and calculates batch-level FLOPs (first batch only).
        - Aggregates FLOPs and tracked/untracked operations across devices.
        - Computes Model FLOP Utilization (MFU) based on actual vs. theoretical throughput.
        - Optionally estimates FLOPs using a formula-based approach (`GPTModel`).
        - Logs performance metrics (e.g., MFU, throughput, FLOPs) to the experiment logger.

        Logging is only performed on the global rank 0 process and is skipped during sanity checks.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The model being trained.
            outputs (STEP_OUTPUT): The outputs from the training step.
            batch (Any): The current batch of data.
            batch_idx (int): Index of the current batch.
        """
        self._stop(trainer, pl_module, batch)

    @override
    def on_validation_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called when the validation batch begins."""
        self._start(trainer, batch_idx)

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the validation batch ends."""
        self._stop(trainer, pl_module, batch)

    @override
    def on_predict_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called when the predict batch begins."""
        self._start(trainer, batch_idx)

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch ends."""
        self._stop(trainer, pl_module, batch)

    @override
    def on_test_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called when the test batch begins."""
        self._start(trainer, batch_idx)

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch ends."""
        self._stop(trainer, pl_module, batch)

    @override
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        """Leave the FLOP-counting dispatch mode if an exception interrupted the traced batch."""
        self._exit_trace()
|