fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/debugging/introspection.py
@@ -0,0 +1,569 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import hashlib
+ import yaml
+ import logging
+ import tempfile
+ from importlib.metadata import distributions
+ from collections.abc import Callable, Hashable
+ from functools import partial
+ from typing import Any, TYPE_CHECKING
+ from typing_extensions import override
+
+ import numpy as np
+ import numpy.typing as ntp
+ import torch
+ import lightning as L
+ from lightning.pytorch.utilities.seed import _collect_rng_states  # type: ignore[attr-defined]
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+ PATH_PREFIX = "introspection"
+ DTYPES = {
+     "float32": "fp32",
+     "float64": "fp64",
+     "float16": "fp16",
+     "bfloat16": "bf16",
+     "int64": "i64",
+     "int32": "i32",
+     "int16": "i16",
+     "int8": "i8",
+     "uint8": "u8",
+     "bool": "b",
+     "complex64": "c64",
+     "complex128": "c128",
+     "qint8": "qi8",
+     "quint8": "qui8",
+     "qint32": "qi32",
+     "float8_e4m3fn": "fp8_e4m3fn",
+     "float8_e4m3fnuz": "fp8_e4m3fnuz",
+     "float8_e5m2": "fp8_e5m2",
+     "float8_e5m2fnuz": "fp8_e5m2fnuz",
+     "float8_e8m0fnu": "fp8_e8m0fnu",
+ }
+
+
+ def _get_dtype(dtype: torch.dtype) -> str:
+     for k, v in DTYPES.items():
+         if getattr(torch, k) == dtype:
+             return v
+     return str(dtype)
+
+
+ def _process_tensor_or_ndarray(
+     item: Any, tensor_stats: set[str], path: str, parent_dict: dict[str, Any] | None
+ ) -> bytes | None:
+     """Process tensor or ndarray items and calculate checksums.
+
+     Args:
+         item: The tensor or ndarray to process
+         tensor_stats (set[str]): Tensor stats to collect
+         path: Current path in the nested structure
+         parent_dict: Dictionary to store checksums in
+
+     Returns:
+         bytes | None: Digest of the checksum
+     """
+     if isinstance(item, torch.Tensor):
+         cks = tensor_checksum(item)
+     else:  # np.ndarray
+         cks = numpy_checksum(item)
+
+     if parent_dict is not None:
+         parent_dict[path] = _format(item, cks.hexdigest(), tensor_stats)
+     return cks.digest()
+
+
+ def _process_nested(sub_item: Any, tensor_stats: set[str], nested_cks: Any, fn: Callable[[Any], Any]) -> None:
+     nested: dict[str, Any] = {}
+     nested_key = None
+     if isinstance(sub_item, torch.Tensor | np.ndarray):
+         digest = _process_tensor_or_ndarray(sub_item, tensor_stats, (nested_key := "temp"), nested)
+     elif isinstance(sub_item, list | tuple | set | frozenset) or hasattr(sub_item, "items"):
+         digest = _process_collection(sub_item, tensor_stats, (nested_key := "temp"), nested)
+     if nested_key:
+         processed = None
+         if digest is not None:
+             nested_cks.update(digest)
+             processed = nested[nested_key]
+         fn(processed)
+     else:
+         fn(sub_item)
+         digest = str(sub_item).encode("utf-8")
+         nested_cks.update(digest)
+
+
+ def _ensure_hashable(item: Any) -> Hashable:
+     # Check containers first: a tuple passes isinstance(..., Hashable) even when
+     # its elements are unhashable, so it must be rebuilt element by element.
+     if isinstance(item, list | tuple):
+         return tuple(_ensure_hashable(i) for i in item)
+     if isinstance(item, set | frozenset):
+         return frozenset(_ensure_hashable(i) for i in item)
+     if isinstance(item, Hashable):
+         return item
+     return str(item)
+
+
+ def _process_collection(
+     item: Any,
+     tensor_stats: set[str],
+     path: str,
+     parent_dict: dict[str, Any] | None,
+ ) -> bytes | None:
+     """Process collection items (list, tuple, dict) and calculate checksums.
+
+     Maintains the nested structure of collections while replacing only tensor values with hashes.
+     Primitive values are kept as is.
+
+     Args:
+         item: The collection to process
+         tensor_stats (set[str]): Tensor stats to collect
+         path: Current path in the nested structure
+         parent_dict: Dictionary to store checksums in
+
+     Returns:
+         bytes | None: Digest of the overall checksum
+     """
+     nested_cks = hashlib.md5()
+     result: list[Any] | dict[Any, Any] | None = None
+     if hasattr(item, "items"):  # map-like
+         result = {}
+         for sub_key, sub_item in item.items():
+             nested: list[Any] = []
+             _process_nested(sub_key, tensor_stats, nested_cks, nested.append)
+             key = _ensure_hashable(nested.pop())
+             _process_nested(sub_item, tensor_stats, nested_cks, partial(result.__setitem__, key))
+     elif isinstance(item, list | tuple | set | frozenset):
+         result = []
+         for sub_item in item:
+             _process_nested(sub_item, tensor_stats, nested_cks, result.append)
+     if result is not None and parent_dict is not None and path:
+         parent_dict[path] = result
+     return nested_cks.digest()
+
+
+ def process_item(
+     item: Any, tensor_stats: set[str], path: str = "", parent_dict: dict[str, Any] | None = None
+ ) -> bytes | None:
+     """Recursively process items in the batch and calculate checksums.
+
+     Only tensor values are replaced with hashes, primitive values are kept as is.
+
+     Args:
+         item: The item to process
+         tensor_stats (set[str]): Tensor stats to collect
+         path: Current path in the nested structure
+         parent_dict: Dictionary to store checksums in
+
+     Returns:
+         bytes | None: Digest of the checksum if available, None otherwise
+     """
+     if isinstance(item, torch.Tensor | np.ndarray):
+         return _process_tensor_or_ndarray(item, tensor_stats, path, parent_dict)
+     elif isinstance(item, list | tuple | set | frozenset) or hasattr(item, "items"):
+         return _process_collection(item, tensor_stats, path, parent_dict)
+     elif not isinstance(item, str | int | float | bool) and item is not None:
+         logger.warning(f"Converting {type(item).__name__} to string for checksum")
+         item = str(item)
+     if parent_dict is not None:
+         parent_dict[path] = item
+     return str(item).encode("utf-8")
+
+
+ def tensor_checksum(tensor: torch.Tensor) -> Any:
+     """Tensor checksum hash, returns the same value for tensors with identical contents.
+
+     Args:
+         tensor (torch.Tensor): tensor to generate the checksum for
+
+     Returns:
+         hashlib.md5: checksum hash
+     """
+     tensor = tensor.to(torch.float32) if tensor.dtype == torch.bfloat16 else tensor
+     return numpy_checksum(tensor.detach().cpu().numpy())
+
+
+ def numpy_checksum(ndarray: ntp.NDArray[Any]) -> Any:
+     """NumPy NDArray checksum, returns the same value for ndarrays with identical contents.
+
+     Args:
+         ndarray (ntp.NDArray[Any]): array to generate the checksum for
+
+     Returns:
+         hashlib.md5: checksum hash
+     """
+     return hashlib.md5(ndarray.tobytes())
+
+
+ def _params_checksums(
+     model: torch.nn.Module,
+     params_checksum: bool,
+     grads_checksum: bool,
+     gradients: dict[str, Any],
+     tensor_stats: set[str],
+ ) -> dict[str, Any]:
+     parameters: dict[str, Any] = {}
+     grads_cks = hashlib.md5()
+     params_cks = hashlib.md5()
+     for name, param in model.named_parameters():
+         parameters[name] = {}
+         if params_checksum and param.data is not None:
+             param_cks = tensor_checksum(param.data)
+             parameters[name]["data"] = _format(param.data, param_cks.hexdigest(), tensor_stats)
+             params_cks.update(param_cks.digest())
+         grad_cks, grad_repr = gradients.get(name, (None, None))
+         if grad_cks is not None:
+             parameters[name]["grad"] = grad_repr
+             grads_cks.update(grad_cks.digest())
+     if params_checksum:
+         _add_digest(parameters, "__all_data__", params_cks)
+     if grads_checksum:
+         _add_digest(parameters, "__all_grads__", grads_cks)
+     return parameters
+
+
+ def _format(tensor: torch.Tensor, hash: str, tensor_stats: set[str]) -> str:  # noqa: C901
+     if tensor.dim() == 0:
+         return str(tensor)
+     chunks: list[str] = []
+     if "shape" in tensor_stats:
+         chunks.append("×".join(str(d) for d in tensor.shape))
+     if "dtype" in tensor_stats:
+         chunks.append(_get_dtype(tensor.dtype))
+     if "infs" in tensor_stats:
+         num_pos_infs = tensor.isposinf().sum()
+         chunks.append(f"{num_pos_infs}∞̟")
+         num_neg_infs = tensor.isneginf().sum()
+         chunks.append(f"{num_neg_infs}∞̠")
+     if "nans" in tensor_stats:
+         num_nans = tensor.isnan().sum()
+         chunks.append(f"{num_nans}⚠")
+     if "zeros" in tensor_stats:
+         num_zeros = tensor.numel() - tensor.count_nonzero().item()
+         chunks.append(f"{num_zeros}⌀")
+     if "med" in tensor_stats:
+         med = tensor.median().item()
+         chunks.append(f"{med}m̃")
+     if "mean" in tensor_stats:
+         mean = tensor.float().mean().item()
+         chunks.append(f"{mean}μ")
+     if "amean" in tensor_stats:
+         amean = tensor.abs().float().mean().item()
+         chunks.append(f"{amean}μ⁺")
+     if "std" in tensor_stats:
+         std = tensor.float().std().item()
+         chunks.append(f"{std}σ")
+     if "var" in tensor_stats:
+         var = tensor.float().var(unbiased=False).item()
+         chunks.append(f"{var}σ²")
+     if "uvar" in tensor_stats:
+         uvar = tensor.float().var(unbiased=True).item()
+         chunks.append(f"{uvar}s²")
+     if "skew" in tensor_stats:
+         t = tensor.float()
+         std = t.std().item()
+         skew = ((t - t.mean()) / std).pow(3).double().mean().item() if std != 0 else "?"
+         chunks.append(f"{skew}γ₁")
+     if "kurt" in tensor_stats:
+         t = tensor.float()
+         std = t.std().item()
+         kurt = ((t - t.mean()) / std).pow(4).double().mean().item() if std != 0 else "?"
+         chunks.append(f"{kurt}γ₂")
+     if "mode" in tensor_stats:
+         vals, counts = torch.unique(tensor, return_counts=True)
+         _, idx = counts.max(0)
+         mode = vals[idx]
+         chunks.append(f"{mode}Mo")
+     if "min" in tensor_stats:
+         mi = tensor.min().item()
+         chunks.append(f"{mi}↤")
+     if "max" in tensor_stats:
+         ma = tensor.max().item()
+         chunks.append(f"{ma}↦")
+     if "hash" in tensor_stats:
+         chunks.append(hash)
+     return "|".join(chunks)
+
+
+ def _buffers_checksums(model: torch.nn.Module, tensor_stats: set[str]) -> dict[str, Any]:
+     buffers: dict[str, Any] = {}
+     buffers_cks = hashlib.md5()
+     for name, buffer in model.named_buffers():
+         if buffer is not None:
+             buffer_cks = tensor_checksum(buffer)
+             buffers[name] = _format(buffer, buffer_cks.hexdigest(), tensor_stats)
+             buffers_cks.update(buffer_cks.digest())
+     _add_digest(buffers, "__all_buffers__", buffers_cks)
+     return buffers
+
+
+ def _rngs_checksums() -> dict[str, Any]:
+     rngs: dict[str, Any] = {}
+     rng_states = _collect_rng_states(include_cuda=True)  # type: ignore[attr-defined]
+     for name, state in rng_states.items():
+         if name == "torch":
+             cks = tensor_checksum(state).hexdigest()
+         elif name == "torch.cuda":
+             cks = tensor_checksum(torch.stack(state, dim=0)).hexdigest()
+         elif name == "numpy":
+             cks = {
+                 "algo": state[0],
+                 "state": numpy_checksum(state[1]).hexdigest(),
+                 "pos": state[2],
+                 "has_gauss": state[3],
+                 "cached_gaussian": state[4],
+             }
+         elif name == "python":
+             cks = {
+                 "version": state[0],
+                 "state": numpy_checksum(np.array(state[1])).hexdigest(),
+                 "gaussian": state[2],
+             }
+         else:
+             raise RuntimeError(f"Unsupported RNG state '{name}': {state}")
+         rngs[name] = cks
+     return rngs
+
+
+ def get_checksums(
+     trainer: L.Trainer,
+     model: torch.nn.Module,
+     gradients: dict[str, Any],
+     checksums: set[str],
+     tensor_stats: set[str],
+ ) -> dict[str, Any]:
+     """Checksums for internal model state and training context.
+
+     Args:
+         trainer (L.Trainer): trainer whose optimizers are checksummed
+         model (torch.nn.Module): model to generate the parameters checksum
+         gradients (dict[str, Any]): captured per-parameter gradient checksums
+         checksums (set[str]): checksums to collect
+         tensor_stats (set[str]): tensor stats to collect next to checksums
+
+     Returns:
+         dict: captured checksums
+     """
+     cks: dict[str, Any] = {}
+     params_checksum = "params" in checksums
+     grads_checksum = "grads" in checksums
+     if params_checksum or grads_checksum:
+         cks["parameters"] = _params_checksums(model, params_checksum, grads_checksum, gradients, tensor_stats)
+     if "buffers" in checksums:
+         cks["buffers"] = _buffers_checksums(model, tensor_stats)
+     if "optimizers" in checksums:
+         cks["optimizers"] = _optimizers_checksums(trainer, model, tensor_stats)
+     if "rngs" in checksums:
+         cks["rngs"] = _rngs_checksums()
+     return cks
+
+
+ def _batch_checksums(batch: Any, batch_idx: int, tensor_stats: set[str]) -> dict[str, Any]:
+     """Calculate checksums for batch data.
+
+     Args:
+         batch (Any): The batch data to generate checksums for
+         batch_idx (int): The batch index
+         tensor_stats (set[str]): Tensor stats to collect
+
+     Returns:
+         dict[str, Any]: Dictionary containing checksums for batch elements
+     """
+     checksums: dict[str, Any] = {"__batch_idx__": batch_idx}
+     _add_digest(checksums, "__all_batch__", process_item(batch, tensor_stats, "__batch__", checksums))
+     return checksums
+
+
+ def _add_digest(cks: dict[str, Any], name: str, digest: Any) -> None:
+     if digest is not None:
+         cks[name] = digest.hex() if isinstance(digest, bytes) else digest.hexdigest()
+
+
+ def _optimizers_checksums(trainer: L.Trainer, pl_module: torch.nn.Module, tensor_stats: set[str]) -> list[Any]:
+     """Extract and compute checksums of optimizer states"""
+     checksums = []
+     for opt in trainer.optimizers:
+         opt_cks = {"__type__": type(opt).__name__}
+         _add_digest(opt_cks, "__all_defaults__", process_item(opt.defaults, tensor_stats, "defaults", opt_cks))
+         _add_digest(
+             opt_cks, "__all_param_groups__", process_item(opt.param_groups, tensor_stats, "param_groups", opt_cks)
+         )
+         _add_digest(opt_cks, "__all_state__", process_item(opt.state, tensor_stats, "state", opt_cks))
+         checksums.append(opt_cks)
+     return checksums
+
+
+ class Introspection(L.Callback):
+     def __init__(
+         self,
+         checksums: set[str] | None = None,
+         tensor_stats: set[str] | None = None,
+         env_vars: bool = False,
+         pip_freeze: bool = False,
+         output_path_prefix: str | None = None,
+         schedule: Schedule | None = None,
+     ) -> None:
+         """Introspection PyTorch Lightning callback.
+
+         This callback captures internal model and training environment state
+         to help investigate model performance and other training regressions.
+         It publishes reports that help compare the internal state between
+         different steps and runs.
+
+         Args:
+             checksums (set[str] | None): Checksums to collect.
+                 Use any combination of the following options:
+                 ``"params"`` - capture checksum for every model parameter,
+                 ``"buffers"`` - capture checksum for every model buffer (auxiliary tensor),
+                 ``"grads"`` - capture checksum for every model parameter's gradient,
+                 ``"rngs"`` - capture checksum for every Random Number Generator (RNG) state,
+                 ``"optimizers"`` - capture checksum for optimizer state,
+                 ``"batch"`` - capture checksum for batch data (model input).
+                 Defaults to no checksum collection.
+             tensor_stats (set[str] | None): Tensor details to collect next to checksums.
+                 Use any combination of the following options:
+                 ``"shape"`` - tensor shape,
+                 ``"dtype"`` - tensor dtype,
+                 ``"infs"`` - count of positive and negative infinity elements (∞̟ and ∞̠),
+                 ``"nans"`` - count of NaN elements (⚠),
+                 ``"zeros"`` - count of 0 elements (⌀),
+                 ``"med"`` - median value (m̃),
+                 ``"mean"`` - mean value (μ),
+                 ``"amean"`` - absolute mean value (μ⁺),
+                 ``"std"`` - standard deviation (σ),
+                 ``"var"`` - biased variance (σ²),
+                 ``"uvar"`` - unbiased variance (s²),
+                 ``"skew"`` - skewness (γ₁),
+                 ``"kurt"`` - kurtosis (γ₂),
+                 ``"mode"`` - mode (Mo),
+                 ``"min"`` - min value (↤),
+                 ``"max"`` - max value (↦),
+                 ``"hash"`` - tensor content hash.
+                 Defaults to ``{"hash"}``.
+             env_vars (bool): capture environment variables when the training starts,
+                 defaults to ``False``.
+             pip_freeze (bool): capture installed pip packages when the training starts,
+                 defaults to ``False``.
+             output_path_prefix (str | None): output path prefix for generated reports;
+                 use it to persist these files locally. Defaults to a temporary location
+                 that is cleaned up as soon as the report is published by the logger.
+             schedule (Schedule | None): Controls when logging occurs during training.
+                 Defaults to :class:`Never`.
+         """
+         self.checksums = set(checksums or set())
+         self.tensor_stats = set(tensor_stats or {"hash"})
+         self.env_vars = env_vars
+         self.pip_freeze = pip_freeze
+         self.output_path_prefix = output_path_prefix
+         self.schedule = schedule or Never()
+         self._should_publish = False
+         self._cb_logger: LightningLogger | None = None
+         self.checksum: Any | None = None
+         self.grad_hooks: list[Any] = []
+         self.gradients: dict[str, tuple[Any, str]] = {}
+         self._hooks_registered = False
+
+     def _publish(self, file_name: str, path: str, data: dict[str, Any]) -> None:
+         with tempfile.TemporaryDirectory() as td:
+             output_file = os.path.join(self.output_path_prefix or td, file_name)
+             with open(output_file, "w", encoding="utf-8") as f:
+                 yaml.dump(data, f, sort_keys=False, indent=2, default_flow_style=False, allow_unicode=True, width=10**6)
+             self._cb_logger.log_artifact(output_file, path)  # type: ignore[union-attr]
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._cb_logger = CallbackLogger(trainer)
+         self.stage = stage
+         if self.env_vars:
+             self._publish("env_vars.yaml", PATH_PREFIX, dict(os.environ))
+         if self.pip_freeze:
+             self._publish(
+                 "pip_freeze.yaml",
+                 PATH_PREFIX,
+                 {p.metadata["Name"]: p.version for p in distributions()},
+             )
+
+     @override
+     def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         """Register hooks at the start of training to capture all gradients including the first step"""
+         self._register_hooks(pl_module)
+
+     @override
+     def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         """Remove hooks at the end of training"""
+         self._remove_hooks()
+
+     def _remove_hooks(self) -> None:
+         """Remove all registered hooks"""
+         for hook in self.grad_hooks:
+             hook.remove()
+         self.grad_hooks = []
+         self._hooks_registered = False
+
+     def _register_hooks(self, pl_module: "L.LightningModule") -> None:
+         """Register hooks on parameters to capture gradients"""
+         if self._hooks_registered or "grads" not in self.checksums:
+             return
+         for name, param in pl_module.named_parameters():
+             if param.requires_grad:
+
+                 def hook(param_name: str, grad: torch.Tensor) -> torch.Tensor:
+                     if grad is not None:
+                         g = grad.detach()
+                         grad_cks = tensor_checksum(g)
+                         self.gradients[param_name] = (grad_cks, _format(g, grad_cks.hexdigest(), self.tensor_stats))
+                     return grad
+
+                 handle = param.register_hook(partial(hook, name))
+                 self.grad_hooks.append(handle)
+         self._hooks_registered = True
+
+     @override
+     def on_train_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+     ) -> None:
+         if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self.checksum = {}
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         self.gradients.clear()
+         if self.checksum is None:
+             return
+         if "batch" in self.checksums:
+             self.checksum["batch"] = _batch_checksums(batch, batch_idx, self.tensor_stats)
+         self._publish(
+             f"rank{trainer.global_rank}.yaml",
+             f"{PATH_PREFIX}/{self.stage}/step={trainer.global_step}",
+             self.checksum,
+         )
+         self.checksum = None
+
+     @override
+     def on_before_optimizer_step(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         optimizer: torch.optim.Optimizer,
+     ) -> None:
+         if self.checksum is None:
+             return
+         self.checksum.update(
+             get_checksums(
+                 trainer,
+                 pl_module,
+                 self.gradients,
+                 self.checksums,
+                 self.tensor_stats,
+             )
+         )
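
For orientation, here is a minimal usage sketch of the Introspection callback added above. The EveryStep schedule, MyModel module, and Trainer wiring are illustrative assumptions, not part of the package; only the fkat import paths and constructor arguments come from the diff itself (the callback captures nothing under the default Never schedule).

    from typing import Any
    import lightning as L
    from fkat.pytorch.schedule import Schedule
    from fkat.pytorch.callbacks.debugging.introspection import Introspection

    class EveryStep(Schedule):
        # Hypothetical schedule: assumes Schedule subclasses implement check(),
        # which the callback calls with stage/batch_idx/step/trainer kwargs.
        def check(self, **kwargs: Any) -> bool:
            return True

    introspection = Introspection(
        checksums={"params", "grads", "batch"},           # states to checksum each scheduled step
        tensor_stats={"shape", "dtype", "nans", "hash"},  # stats rendered next to each checksum
        pip_freeze=True,                                  # also publish installed packages once
        schedule=EveryStep(),
    )
    trainer = L.Trainer(callbacks=[introspection])
    trainer.fit(MyModel())  # MyModel is a placeholder LightningModule
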
fkat/pytorch/callbacks/debugging/optimizer.py
@@ -0,0 +1,45 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Any
+ from typing_extensions import override
+ import datetime as dt
+
+ import torch
+ import fsspec
+ import lightning as L
+
+ from fkat.pytorch.schedule import Schedule, Never
+
+
+ class OptimizerSnapshot(L.Callback):
+     """
+     Callback that saves optimizer state at specified intervals during training.
+
+     This callback allows you to capture the state of optimizers at specific points
+     during training, which can be useful for debugging, analysis, or resuming training
+     from specific optimization states.
+
+     Args:
+         output_path_prefix (str): Output path prefix for generated optimizer snapshots.
+         schedule (Schedule | None): Schedule at which to take a snapshot of optimizers.
+             Defaults to ``Never``.
+     """
+
+     def __init__(
+         self,
+         output_path_prefix: str,
+         schedule: Schedule | None = None,
+     ) -> None:
+         self.output_path_prefix = output_path_prefix
+         self.schedule = schedule or Never()
+
+     @override
+     def on_train_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+     ) -> None:
+         if self.schedule.check(trainer=trainer, stage="train", batch_idx=batch_idx, step=trainer.global_step):
+             timestamp = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
+             for i, opt in enumerate(trainer.optimizers):
+                 path = f"{self.output_path_prefix}rank{trainer.global_rank}_opt{i}_{timestamp}.pt"
+                 with fsspec.open(path, "wb", makedirs=True) as f:
+                     torch.save(opt, f)
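
A similarly hedged sketch for OptimizerSnapshot; the bucket prefix, schedule, and Trainer wiring are assumptions. Since snapshots are written through fsspec, the prefix can be a local path or any fsspec-supported URL, and each snapshot lands at {prefix}rank{rank}_opt{i}_{timestamp}.pt.

    import lightning as L
    from fkat.pytorch.callbacks.debugging.optimizer import OptimizerSnapshot

    # The bucket and EveryStep schedule are placeholders (see the sketch above);
    # the default Never() schedule would take no snapshots at all.
    snapshot = OptimizerSnapshot(
        output_path_prefix="s3://my-bucket/opt-snapshots/",
        schedule=EveryStep(),
    )
    trainer = L.Trainer(callbacks=[snapshot])
    trainer.fit(MyModel())  # MyModel is a placeholder LightningModule
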