fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
@@ -0,0 +1,574 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import torch
+ from torch.utils._python_dispatch import TorchDispatchMode
+ from torch.utils.flop_counter import flop_registry
+ from datetime import datetime
+ from pprint import pformat
+ from typing import Any, Protocol, TYPE_CHECKING, cast
+ from typing_extensions import override
+
+ import lightning as L
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.utils.logging import rank0_logger
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+ from fkat.pytorch.schedule import Schedule, Every
+
+ logger = rank0_logger(__name__)
+
+
+ class FlopRecipe(Protocol):
+     """``FlopRecipe`` provides an estimate of the floating-point operations (FLOPs) performed during a training batch."""
+
+     def get_batch_flop(self, pl_module: "L.LightningModule") -> int: ...
+
+
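Because `FlopRecipe` is a `Protocol`, any object with a matching `get_batch_flop` method satisfies it. A minimal sketch (not part of the package) of a custom recipe that returns a fixed, externally computed estimate:

    import lightning as L

    class ConstantFlopRecipe:
        """Hypothetical recipe returning a precomputed per-batch FLOP estimate."""

        def __init__(self, flops_per_batch: int) -> None:
            self.flops_per_batch = flops_per_batch

        def get_batch_flop(self, pl_module: "L.LightningModule") -> int:
            # Structural typing: this satisfies FlopRecipe without inheriting from it.
            return self.flops_per_batch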
+ class GPTModel(FlopRecipe):
+     """``GPTModel`` FLOP recipe provides an estimate of the floating-point operations (FLOPs)
+     performed during a training batch for MegatronGPTModels.
+     """
+
+     @staticmethod
+     def _calculate_flop(
+         batch_size: int,
+         seq_length: int,
+         num_layers: int,
+         kv_channels: int | None,
+         num_attention_heads: int,
+         num_query_groups: int,
+         hidden_size: int,
+         ffn_hidden_size: int,
+         group_query_attention: bool,
+         swiglu: bool,
+         padded_vocab_size: int,
+         num_moe_layers: int,
+         moe_ffn_hidden_size: int,
+         num_experts: int,
+         moe_router_topk: int,
+         consider_activation_recompute: bool = False,
+     ) -> int:
+         if kv_channels is None:
+             kv_channels = hidden_size // num_attention_heads
+
+         params = {
+             "batch_size": batch_size,
+             "seq_length": seq_length,
+             "num_layers": num_layers,
+             "kv_channels": kv_channels,
+             "num_attention_heads": num_attention_heads,
+             "num_query_groups": num_query_groups,
+             "hidden_size": hidden_size,
+             "ffn_hidden_size": ffn_hidden_size,
+             "group_query_attention": group_query_attention,
+             "swiglu": swiglu,
+             "padded_vocab_size": padded_vocab_size,
+             "num_moe_layers": num_moe_layers,
+             "moe_ffn_hidden_size": moe_ffn_hidden_size,
+             "num_experts": num_experts,
+             "moe_router_topk": moe_router_topk,
+             "consider_activation_recompute": consider_activation_recompute,
+         }
+
+         logger.debug("Called _calculate_flop with parameters:\n%s", pformat(params))
+
+         # Attention projection size.
+         query_projection_size = kv_channels * num_attention_heads
+         query_projection_to_hidden_size_ratio = query_projection_size / hidden_size
+         # Group Query Attention.
+         if not group_query_attention:
+             num_query_groups = num_attention_heads
+         # MoE.
+         num_experts_routed_to = 1 if num_experts is None else moe_router_topk
+         gated_linear_multiplier = 3 / 2 if swiglu else 1
+
+         # The 12x term below comes from the following factors; for more details, see
+         # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
+         # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
+         #   backward wgrad [weight gradient], backward dgrad [data gradient]).
+         #   When the HFU is considered, an additional forward pass needs to be factored in.
+         # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
+         #   architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
+         #   in the MLP layer).
+         # - 2x: A GEMM of an m*n tensor with an n*k tensor requires 2mnk floating-point operations.
+         compute_factor = 4 if consider_activation_recompute else 3
+         expansion_factor = compute_factor * 2 * 2
+         moe_ratio = num_moe_layers / num_layers
+         dense_ratio = 1.0 - moe_ratio
+         return int(
+             expansion_factor
+             * batch_size
+             * seq_length
+             * num_layers
+             * hidden_size
+             * hidden_size
+             * (
+                 # Attention.
+                 (
+                     (1 + (num_query_groups / num_attention_heads) + (seq_length / hidden_size))
+                     * query_projection_to_hidden_size_ratio
+                 )
+                 # MLP.
+                 # Interleave
+                 + dense_ratio * ((ffn_hidden_size / hidden_size) * gated_linear_multiplier)
+                 + moe_ratio * ((moe_ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier)
+             )
+             # Logits
+             + 6 * batch_size * seq_length * hidden_size * padded_vocab_size
+         )
+
+     def get_batch_flop(self, pl_module: "L.LightningModule") -> int:
+         """Resolves floating-point operations (FLOPs) for the given batch size for GPT models."""
+         if hasattr(pl_module, "cfg") and isinstance(pl_module.cfg, dict):
+             config = cast(dict[str, Any], pl_module.cfg)
+
+             # reject calculating FLOPs per batch if a key is missing from the module config
+             _mandatory_config_keys = [
+                 "global_batch_size",
+                 "encoder_seq_length",
+                 "num_layers",
+                 "hidden_size",
+                 "num_attention_heads",
+                 "ffn_hidden_size",
+             ]
+
+             for key in _mandatory_config_keys:
+                 if key not in config:
+                     raise KeyError(f"Key {key} not present in module config {config}")
+
+             return GPTModel._calculate_flop(
+                 batch_size=config.get("global_batch_size", 1),
+                 seq_length=config.get("encoder_seq_length", 1024),
+                 num_layers=config.get("num_layers", 24),
+                 kv_channels=config.get(
+                     "kv_channels", config.get("hidden_size", 4096) // config.get("num_attention_heads", 16)
+                 ),
+                 num_attention_heads=config.get("num_attention_heads", 16),
+                 num_query_groups=config.get("num_query_groups", config.get("num_attention_heads", 16)),
+                 # Default to num_attention_heads
+                 hidden_size=config.get("hidden_size", 4096),
+                 ffn_hidden_size=config.get("ffn_hidden_size", 4 * config.get("hidden_size", 4096)),
+                 # Typical FFN expansion
+                 group_query_attention=config.get("group_query_attention", False),
+                 swiglu=config.get("swiglu", False),
+                 padded_vocab_size=config.get("padded_vocab_size", config.get("vocab_size", 50257)),
+                 num_moe_layers=config.get("num_moe_layers", 0),
+                 moe_ffn_hidden_size=config.get("moe_ffn_hidden_size", 0),
+                 num_experts=config.get("num_experts", 0),
+                 moe_router_topk=config.get("moe_router_topk", 1),
+                 consider_activation_recompute=config.get("consider_activation_recompute", False),
+             )
+
+         raise TypeError(f"{self.__class__} does not support calculating flop for {pl_module.__class__}")
+
+
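For orientation, the estimator above follows the Megatron-LM dense-transformer accounting: roughly 12 · B · s · L · h² times the attention/MLP terms, plus 6 · B · s · h · V for the logits. A minimal sketch that exercises the internal helper directly with illustrative hyperparameters (not values from the package):

    from fkat.pytorch.callbacks.profiling.flops import GPTModel

    # Illustrative dense GPT configuration: no MoE layers, no activation recompute.
    flops = GPTModel._calculate_flop(
        batch_size=8,
        seq_length=2048,
        num_layers=24,
        kv_channels=None,  # falls back to hidden_size // num_attention_heads
        num_attention_heads=16,
        num_query_groups=16,
        hidden_size=2048,
        ffn_hidden_size=8192,
        group_query_attention=False,
        swiglu=False,
        padded_vocab_size=50304,
        num_moe_layers=0,  # dense_ratio == 1.0, so the MoE term drops out
        moe_ffn_hidden_size=0,
        num_experts=0,
        moe_router_topk=1,
    )
    # During training the recipe is invoked via get_batch_flop(pl_module),
    # which reads these values from pl_module.cfg instead.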
+ class Trace(TorchDispatchMode, FlopRecipe):
+     """``Trace`` FLOP recipe is a lightweight counter mode that only counts global FLOPs
+     without using ModuleTracker to track the module hierarchy.
+
+     Example usage
+
+     .. code-block:: python
+
+         with Trace() as flop_counter:
+             mod.sum().backward()
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+
+         self.flop_registry = {
+             **flop_registry,
+         }
+
+         self._reset_flops_count()
+
+     def _reset_flops_count(self) -> None:
+         self.flop_counts = {"Global": 0, "Tracked": 0, "Untracked": 0}
+
+     def get_batch_flop(self, pl_module: "L.LightningModule") -> int:
+         return self.flop_counts.get("Global", 0)
+
+     def get_tracked_operations_count(self) -> int | None:
+         return self.flop_counts.get("Tracked", 0)
+
+     def get_untracked_operations_count(self) -> int | None:
+         return self.flop_counts.get("Untracked", 0)
+
+     def _count_flops(self, func_packet: Any, out: Any, args: tuple[()], kwargs: Any) -> Any:
+         if func_packet in self.flop_registry:
+             flop_count_func = self.flop_registry[func_packet]
+             flop_count = flop_count_func(*args, **kwargs, out_val=out)
+             self.flop_counts["Global"] += flop_count
+             self.flop_counts["Tracked"] += 1
+         else:
+             self.flop_counts["Untracked"] += 1
+
+         return out
+
+     def __torch_dispatch__(self, func: Any, types: Any, args: tuple[()] = (), kwargs: Any = None) -> Any:
+         kwargs = kwargs if kwargs else {}
+         out = func(*args, **kwargs)
+         return self._count_flops(func._overloadpacket, out, args, kwargs)
+
+     def __enter__(self) -> TorchDispatchMode:
+         self._reset_flops_count()
+         super().__enter__()
+         return self
+
+     def __exit__(self, *args: tuple[()] | None) -> None:
+         super().__exit__(*args)
+
+
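The counter can also be used on its own outside the callback. A minimal sketch; the `None` argument mirrors how the callback itself calls `get_batch_flop`:

    import torch
    from fkat.pytorch.callbacks.profiling.flops import Trace

    a = torch.randn(64, 128)
    b = torch.randn(128, 256)
    with Trace() as counter:
        torch.mm(a, b)  # a dense matmul: 2 * 64 * 128 * 256 FLOPs
    print(counter.get_batch_flop(None))              # total FLOPs observed
    print(counter.get_tracked_operations_count())    # ops found in flop_registry
    print(counter.get_untracked_operations_count())  # ops with no FLOP formula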
+ class Accelerator:
+     def __init__(
+         self,
+         name: str,
+         fp32: int,
+         fp16: int | None = None,
+         bf16: int | None = None,
+         int8: int | None = None,
+         int4: int | None = None,
+     ) -> None:
+         self.name = name
+         self.fp32 = fp32
+         self.fp16 = fp16 or fp32 * 2
+         self.bf16 = bf16 or self.fp16
+         self.int8 = int8 or self.fp16 * 2
+         self.int4 = int4 or self.int8 * 2
+
+     def flops(self, dtype: torch.dtype) -> int:
+         if dtype == torch.float32 and self.fp32:
+             return self.fp32
+         if dtype == torch.float16 and self.fp16:
+             return self.fp16
+         if dtype == torch.bfloat16 and self.bf16:
+             return self.bf16
+         if dtype == torch.int8 and self.int8:
+             return self.int8
+         raise ValueError(f"No {dtype} flops details for {self.name}")
+
+
+ TFLOPS = 10**12
+
+ V100 = Accelerator("V100", fp32=130 * TFLOPS)
+ H100 = Accelerator("H100", fp32=989 * TFLOPS)
+ H200 = Accelerator("H200", fp32=989 * TFLOPS)
+ A100 = Accelerator("A100", fp32=312 * TFLOPS)
+ A10G = Accelerator("A10G", fp32=35 * TFLOPS)
+ A10 = Accelerator("A10", fp32=int(31.2 * TFLOPS))
+ L40S = Accelerator("L40S", fp32=int(91.6 * TFLOPS))
+
+
+ def get_flops(dtype: "torch.dtype", device: "torch.device") -> int:
+     gpu_name = ""
+     if device.type == "cuda":
+         gpu_name = torch.cuda.get_device_name(device)
+         if "V100" in gpu_name:
+             return V100.flops(dtype)
+         elif "H100" in gpu_name:
+             return H100.flops(dtype)
+         elif "H200" in gpu_name:
+             return H200.flops(dtype)
+         elif "A100" in gpu_name:
+             return A100.flops(dtype)
+         elif "A10G" in gpu_name:
+             return A10G.flops(dtype)
+         elif "A10" in gpu_name:
+             return A10.flops(dtype)
+         elif "L40S" in gpu_name:
+             return L40S.flops(dtype)
+
+     raise ValueError(f"No flops details for {device} with name {gpu_name}")
+
+
+ def dtype(trainer: L.Trainer) -> "torch.dtype":
+     if str(trainer.precision).startswith("32"):
+         return torch.float32
+     if str(trainer.precision).startswith("16"):
+         return torch.float16
+     if str(trainer.precision).startswith("bf16"):
+         return torch.bfloat16
+     raise ValueError(f"Can't infer dtype for {trainer.precision}")
+
+
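The peak-rate table and the traced per-batch FLOP count are what the `Flops` callback combines into MFU (see `_stop` below). A small worked sketch with illustrative numbers, using the module's own defaults (fp16 = 2 × fp32 and bf16 = fp16):

    import torch
    from fkat.pytorch.callbacks.profiling.flops import A100

    peak = A100.flops(torch.bfloat16)   # 2 * 312 TFLOPS = 624e12 under the defaults above

    # MFU as computed in Flops._stop (numbers are illustrative):
    total_flops = peak * 8              # e.g. one node with 8 GPUs
    batch_flops = 2.5e15                # traced FLOPs for one batch, summed across ranks
    max_batches_per_sec = total_flops / batch_flops
    actual_batches_per_sec = 1.4        # 1 / measured batch wall-clock time
    mfu = actual_batches_per_sec / max_batches_per_sec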
+ class Flops(L.Callback):
+     """
+     A PyTorch Lightning callback that measures and logs floating-point operations (FLOPs) and
+     Model FLOP Utilization (MFU) during training, validation, testing, and prediction.
+
+     This callback helps to monitor the computational efficiency of models by measuring:
+     - Total machine FLOPs available
+     - Per-batch FLOPs used by the model
+     - Model FLOP Utilization (MFU), i.e., how efficiently the model uses the available compute
+     - Batch throughput (batches per second)
+
+     It supports two methods for estimating FLOPs:
+     1. Tracing-based estimation via a `Trace` context manager.
+     2. Formula-based estimation using a predefined `GPTModel` FLOP calculator.
+
+     Metrics are logged periodically (or once) to the experiment logger (e.g., MLflow) and include:
+
+     * `mfu`: Model FLOP Utilization (traced)
+     * `actual_batches_per_sec`: Measured throughput
+     * `max_batches_per_sec`: Theoretical max throughput
+     * `batch_flops`: FLOPs used in the current batch
+     * `total_flops`: Total machine FLOPs available
+     * `batch_flops_from_formula`: FLOPs estimated via formula (if available)
+     * `mfu_from_formula`: MFU based on formula-based estimation
+     * `batch_flops_tracked_operations`: Number of operations tracked during tracing
+     * `batch_flops_untracked_operations`: Number of operations not accounted for by the tracer
+
+     Args:
+         schedule (Optional[Schedule]): Controls when logging occurs during training. Defaults to ``Every(n_batches=5)``.
+             - FLOPs are always calculated at least once at the beginning.
+
+     Example:
+         >>> trainer = L.Trainer(callbacks=[Flops(schedule=Every(n_batches=10))])
+     """
+
+     def __init__(self, schedule: Schedule | None = None, *args: Any, **kwargs: Any) -> None:
+         """Measures the total floating-point operations per second and MFU.
+
+         Args:
+             schedule (Optional[Schedule]): Controls when logging occurs during training. Defaults to ``Every(n_batches=5)``.
+                 - FLOPs are always calculated at least once at the beginning.
+
+         Returns:
+             None
+         """
+         self.schedule = schedule or Every(n_batches=5)
+         self.kwargs = kwargs
+         self.trace_flops_recipe: Trace | None = None
+         self.gpt_flops_recipe = GPTModel()  # @TODO(hanfange) - make it more configurable
+         self.total_flops = self.batch_idx = 0
+         self.mfu_from_formula: float | None = None
+         self.batch_flops_from_formula: int | None = None
+         self.batch_flops = self.mfu = torch.empty(0)
+         self.operations_tracked = self.operations_untracked = torch.empty(0)
+         self.start_time: datetime | None = None
+         self.is_first_batch = True
+
+         self._timer_active = False
+         self._cb_logger: LightningLogger | None = None
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         """Called when fit, validate, test, predict, or tune begins."""
+         self._cb_logger = CallbackLogger(trainer)
+         self.total_flops = get_flops(dtype(trainer), pl_module.device) * trainer.num_nodes * trainer.num_devices
+         self.batch_flops = torch.tensor(0, dtype=torch.int64, device=pl_module.device)
+         self.operations_tracked = torch.tensor(0, dtype=torch.int64, device=pl_module.device)
+         self.operations_untracked = torch.tensor(0, dtype=torch.int64, device=pl_module.device)
+
+     def _should_recalulate_batch_flops(self, trainer: "L.Trainer") -> bool:
+         """
+         _should_recalulate_batch_flops decides whether to recalculate the value of self.batch_flops.
+         Recalculating batch_flops requires entering the FLOP counter mode and might cause VRAM leakage.
+
+         Only calculate batch FLOPs for the first batch.
+
+         Returns: bool
+         """
+         return self.is_first_batch
+
+     def _should_report_batch_throughput_and_mfu(self, trainer: "L.Trainer") -> bool:
+         """
+         _should_report_batch_throughput_and_mfu decides whether to publish batch throughput and MFU metrics to MLflow.
+
+         Only report every n batches.
+
+         Returns: bool
+         """
+         return self.schedule.check(stage="train", batch_idx=self.batch_idx, step=trainer.global_step, trainer=trainer)
+
+     def _start(self, trainer: "L.Trainer", batch_idx: int) -> None:
+         self.batch_idx = batch_idx
+         if trainer.sanity_checking:
+             return
+
+         if self._should_recalulate_batch_flops(trainer):  # calculate number of flops
+             self.trace_flops_recipe = Trace()
+             self.trace_flops_recipe.__enter__()
+
+         if self._should_report_batch_throughput_and_mfu(trainer):  # report MFU every n batches
+             self.start_time = datetime.now()
+             self._timer_active = True
+
+     def _stop(self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any) -> None:
+         if trainer.sanity_checking:
+             return
+
+         if self._should_recalulate_batch_flops(trainer):  # only calculate number of flops once as it is a constant
+             self.trace_flops_recipe.__exit__(None, None, None)  # type: ignore[union-attr]
+             flop = self.trace_flops_recipe.get_batch_flop(None)  # type: ignore
+             self.batch_flops.fill_(flop)
+             self.operations_tracked.fill_(self.trace_flops_recipe.get_tracked_operations_count())  # type: ignore
+             self.operations_untracked.fill_(self.trace_flops_recipe.get_untracked_operations_count())  # type: ignore
+             trainer.strategy.reduce(self.batch_flops, reduce_op="sum")
+             trainer.strategy.reduce(self.operations_tracked, reduce_op="sum")
+             trainer.strategy.reduce(self.operations_untracked, reduce_op="sum")
+
+         # main node emits metrics for every n batches
+         if (
+             trainer.global_rank == 0 and self._should_report_batch_throughput_and_mfu(trainer) and self._timer_active
+         ):  # report MFU every n batches
+             assert self._cb_logger
+             assert self.start_time
+             self._timer_active = False
+
+             now = datetime.now()
+             actual_batches_per_sec = 1 / (now - self.start_time).total_seconds()
+
+             max_batches_per_sec = self.total_flops / self.batch_flops.item()
+             mfu = actual_batches_per_sec / max_batches_per_sec
+
+             metrics = {
+                 "mfu": mfu,
+                 "actual_batches_per_sec": actual_batches_per_sec,
+                 "max_batches_per_sec": max_batches_per_sec,
+                 "batch_flops": self.batch_flops.item(),
+                 "total_flops": self.total_flops,
+                 "batch_flops_tracked_operations": self.operations_tracked.item(),
+                 "batch_flops_untracked_operations": self.operations_untracked.item(),
+             }
+
+             # attempt to calculate batch flop using formula-based approach, rank-zero only
+             try:
+                 self.batch_flops_from_formula = self.gpt_flops_recipe.get_batch_flop(pl_module)
+                 max_batches_per_sec_from_formula = self.total_flops / self.batch_flops_from_formula
+                 self.mfu_from_formula = actual_batches_per_sec / max_batches_per_sec_from_formula
+
+                 # emit mfu calculated from formula, if applicable
+                 metrics.update(
+                     {
+                         "mfu_from_formula": self.mfu_from_formula,
+                         "batch_flops_from_formula": self.batch_flops_from_formula,
+                     }
+                 )
+             except Exception as e:
+                 logger.debug(f"Could not calculate FLOP using formula: {e}")
+
+             self._cb_logger.log_batch(metrics=metrics, timestamp=int(now.timestamp() * 1e3), step=trainer.global_step)
+
+         self.is_first_batch = False
+
+     @override
+     def on_train_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+     ) -> None:
+         """
+         Called at the beginning of each training batch.
+
+         This method initiates FLOP tracing and throughput timing for the current batch,
+         depending on the logging frequency.
+
+         If conditions are met, it:
+         - Begins FLOP tracing using a `Trace` context manager.
+         - Records the start time to later compute batch throughput.
+
+         Tracing and logging are skipped during sanity checks.
+
+         Args:
+             trainer (Trainer): The PyTorch Lightning trainer instance.
+             pl_module (LightningModule): The model being trained.
+             batch (Any): The current batch of data.
+             batch_idx (int): Index of the current batch.
+         """
+         self._start(trainer, batch_idx)
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         """
+         Called at the end of each training batch.
+
+         This method finalizes FLOP tracing and logs performance metrics if applicable.
+
+         If conditions are met, it:
+         - Ends FLOP tracing and calculates batch-level FLOPs.
+         - Aggregates tracked and untracked operations across devices.
+         - Computes Model FLOP Utilization (MFU) based on actual vs. theoretical throughput.
+         - Optionally estimates FLOPs using a formula-based approach (`GPTModel`).
+         - Logs performance metrics (e.g., MFU, throughput, FLOPs) to the experiment logger.
+
+         Logging is only performed on the global rank 0 process and is skipped during sanity checks.
+
+         Args:
+             trainer (Trainer): The PyTorch Lightning trainer instance.
+             pl_module (LightningModule): The model being trained.
+             outputs (STEP_OUTPUT): The outputs from the training step.
+             batch (Any): The current batch of data.
+             batch_idx (int): Index of the current batch.
+         """
+         self._stop(trainer, pl_module, batch)
+
+     @override
+     def on_validation_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         """Called when the validation batch begins."""
+         self._start(trainer, batch_idx)
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Called when the validation batch ends."""
+         self._stop(trainer, pl_module, batch)
+
+     @override
+     def on_predict_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         """Called when the predict batch begins."""
+         self._start(trainer, batch_idx)
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Called when the predict batch ends."""
+         self._stop(trainer, pl_module, batch)
+
+     @override
+     def on_test_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         """Called when the test batch begins."""
+         self._start(trainer, batch_idx)
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         """Called when the test batch ends."""
+         self._stop(trainer, pl_module, batch)
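
Taken together, a minimal sketch of wiring the callback into a Lightning trainer; the model and datamodule names are placeholders, not part of the package:

    import lightning as L
    from fkat.pytorch.callbacks.profiling.flops import Flops
    from fkat.pytorch.schedule import Every

    trainer = L.Trainer(
        accelerator="gpu",
        devices=8,
        precision="bf16-mixed",  # dtype() maps this to torch.bfloat16
        callbacks=[Flops(schedule=Every(n_batches=10))],  # log MFU every 10 batches
    )
    # trainer.fit(MyLightningModule(), datamodule=MyDataModule())  # hypothetical placeholders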