fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/debugging/introspection.py

@@ -0,0 +1,569 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import hashlib
import yaml
import logging
import tempfile
from importlib.metadata import distributions
from collections.abc import Callable, Hashable
from functools import partial
from typing import Any, TYPE_CHECKING
from typing_extensions import override

import numpy as np
import numpy.typing as ntp
import torch
import lightning as L
from lightning.pytorch.utilities.seed import _collect_rng_states  # type: ignore[attr-defined]

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT

from fkat.pytorch.schedule import (
    Schedule,
    Never,
)
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger

logger: logging.Logger = logging.getLogger(__name__)

PATH_PREFIX = "introspection"
DTYPES = {
    "float32": "fp32",
    "float64": "fp64",
    "float16": "fp16",
    "bfloat16": "bf16",
    "int64": "i64",
    "int32": "i32",
    "int16": "i16",
    "int8": "i8",
    "uint8": "u8",
    "bool": "b",
    "complex64": "c64",
    "complex128": "c128",
    "qint8": "qi8",
    "quint8": "qui8",
    "qint32": "qi32",
    "float8_e4m3fn": "fp8_e4m3fn",
    "float8_e4m3fnuz": "fp8_e4m3fnuz",
    "float8_e5m2": "fp8_e5m2",
    "float8_e5m2fnuz": "fp8_e5m2fnuz",
    "float8_e8m0fnu": "fp8_e8m0fnu",
}


def _get_dtype(dtype: torch.dtype) -> str:
    for k, v in DTYPES.items():
        if getattr(torch, k) == dtype:
            return v
    return str(dtype)


def _process_tensor_or_ndarray(
    item: Any, tensor_stats: set[str], path: str, parent_dict: dict[str, Any] | None
) -> bytes | None:
    """Process tensor or ndarray items and calculate checksums.

    Args:
        item: The tensor or ndarray to process
        tensor_stats (set[str]): Tensor stats to collect
        path: Current path in the nested structure
        parent_dict: Dictionary to store checksums in

    Returns:
        bytes | None: Digest of the checksum
    """
    if isinstance(item, torch.Tensor):
        cks = tensor_checksum(item)
    else:  # np.ndarray
        cks = numpy_checksum(item)

    if parent_dict is not None:
        parent_dict[path] = _format(item, cks.hexdigest(), tensor_stats)
    return cks.digest()


def _process_nested(sub_item: Any, tensor_stats: set[str], nested_cks: Any, fn: Callable[[Any], Any]) -> None:
    nested: dict[str, Any] = {}
    nested_key = None
    if isinstance(sub_item, torch.Tensor | np.ndarray):
        digest = _process_tensor_or_ndarray(sub_item, tensor_stats, (nested_key := "temp"), nested)
    elif isinstance(sub_item, list | tuple | set | frozenset) or hasattr(sub_item, "items"):
        digest = _process_collection(sub_item, tensor_stats, (nested_key := "temp"), nested)
    if nested_key:
        processed = None
        if digest is not None:
            nested_cks.update(digest)
            processed = nested[nested_key]
        fn(processed)
    else:
        fn(sub_item)
        digest = str(sub_item).encode("utf-8")
        nested_cks.update(digest)


def _ensure_hashable(item: Any) -> Hashable:
    if isinstance(item, Hashable):
        return item
    if isinstance(item, list | tuple):
        return tuple(_ensure_hashable(i) for i in item)
    if isinstance(item, set | frozenset):
        return frozenset(_ensure_hashable(i) for i in item)
    return str(item)


def _process_collection(
    item: Any,
    tensor_stats: set[str],
    path: str,
    parent_dict: dict[str, Any] | None,
) -> bytes | None:
    """Process collection items (list, tuple, dict) and calculate checksums.

    Maintains the nested structure of collections while replacing only tensor values with hashes.
    Primitive values are kept as is.

    Args:
        item: The collection to process
        tensor_stats (set[str]): Tensor stats to collect
        path: Current path in the nested structure
        parent_dict: Dictionary to store checksums in

    Returns:
        bytes | None: Digest of the overall checksum
    """
    nested_cks = hashlib.md5()
    result: list[Any] | dict[Any, Any] | None = None
    if hasattr(item, "items"):  # map-like
        result = {}
        for sub_key, sub_item in item.items():
            nested: list[str] = []
            _process_nested(sub_key, tensor_stats, nested_cks, nested.append)
            key = _ensure_hashable(nested.pop())
            _process_nested(sub_item, tensor_stats, nested_cks, partial(result.__setitem__, key))
    elif isinstance(item, list | tuple | set | frozenset):
        result = []
        for sub_item in item:
            _process_nested(sub_item, tensor_stats, nested_cks, result.append)
    if result is not None and parent_dict is not None and path:
        parent_dict[path] = result
    return nested_cks.digest()


def process_item(
    item: Any, tensor_stats: set[str], path: str = "", parent_dict: dict[str, Any] | None = None
) -> bytes | None:
    """Recursively process items in the batch and calculate checksums.

    Only tensor values are replaced with hashes; primitive values are kept as is.

    Args:
        item: The item to process
        tensor_stats (set[str]): Tensor stats to collect
        path: Current path in the nested structure
        parent_dict: Dictionary to store checksums in

    Returns:
        bytes | None: Digest of the checksum if available, None otherwise
    """
    if isinstance(item, torch.Tensor | np.ndarray):
        return _process_tensor_or_ndarray(item, tensor_stats, path, parent_dict)
    elif isinstance(item, list | tuple | set | frozenset) or hasattr(item, "items"):
        return _process_collection(item, tensor_stats, path, parent_dict)
    elif not isinstance(item, str | int | float | bool) and item is not None:
        logger.warning(f"Converting {type(item).__name__} to string for checksum")
        item = str(item)
    if parent_dict is not None:
        parent_dict[path] = item
    return str(item).encode("utf-8")


def tensor_checksum(tensor: torch.Tensor) -> Any:
    """Tensor checksum hash; returns the same value for tensors with identical contents.

    Args:
        tensor (torch.Tensor): tensor to generate the checksum for

    Returns:
        hashlib.md5: checksum hash
    """
    tensor = tensor.to(torch.float32) if tensor.dtype == torch.bfloat16 else tensor
    return numpy_checksum(tensor.detach().cpu().numpy())


def numpy_checksum(ndarray: ntp.NDArray[Any]) -> Any:
    """NumPy NDArray checksum; returns the same value for ndarrays with identical contents.

    Args:
        ndarray (ntp.NDArray[Any]): array to generate the checksum for

    Returns:
        hashlib.md5: checksum hash
    """
    return hashlib.md5(ndarray.tobytes())


def _params_checksums(
    model: torch.nn.Module,
    params_checksum: bool,
    grads_checksum: bool,
    gradients: dict[str, Any],
    tensor_stats: set[str],
) -> dict[str, Any]:
    parameters: dict[str, Any] = {}
    grads_cks = hashlib.md5()
    params_cks = hashlib.md5()
    for name, param in model.named_parameters():
        parameters[name] = {}
        if params_checksum and param.data is not None:
            param_cks = tensor_checksum(param.data)
            parameters[name]["data"] = _format(param.data, param_cks.hexdigest(), tensor_stats)
            params_cks.update(param_cks.digest())
        grad_cks, grad_repr = gradients.get(name, (None, None))
        if grad_cks is not None:
            parameters[name]["grad"] = grad_repr
            grads_cks.update(grad_cks.digest())
    if params_checksum:
        _add_digest(parameters, "__all_data__", params_cks)
    if grads_checksum:
        _add_digest(parameters, "__all_grads__", grads_cks)
    return parameters


def _format(tensor: torch.Tensor, hash: str, tensor_stats: set[str]) -> str:  # noqa: C901
    if tensor.dim() == 0:
        return str(tensor)
    chunks: list[str] = []
    if "shape" in tensor_stats:
        chunks.append("×".join(str(d) for d in tensor.shape))
    if "dtype" in tensor_stats:
        chunks.append(_get_dtype(tensor.dtype))
    if "infs" in tensor_stats:
        num_pos_infs = tensor.isposinf().sum()
        chunks.append(f"{num_pos_infs}∞̟")
        num_neg_infs = tensor.isneginf().sum()
        chunks.append(f"{num_neg_infs}∞̠")
    if "nans" in tensor_stats:
        num_nans = tensor.isnan().sum()
        chunks.append(f"{num_nans}⚠")
    if "zeros" in tensor_stats:
        num_zeros = tensor.numel() - tensor.count_nonzero().item()
        chunks.append(f"{num_zeros}⌀")
    if "med" in tensor_stats:
        med = tensor.median().item()
        chunks.append(f"{med}m̃")
    if "mean" in tensor_stats:
        mean = tensor.float().mean().item()
        chunks.append(f"{mean}μ")
    if "amean" in tensor_stats:
        amean = tensor.abs().float().mean().item()
        chunks.append(f"{amean}μ⁺")
    if "std" in tensor_stats:
        std = tensor.float().std().item()
        chunks.append(f"{std}σ")
    if "var" in tensor_stats:
        var = tensor.float().var(unbiased=False).item()
        chunks.append(f"{var}σ²")
    if "uvar" in tensor_stats:
        uvar = tensor.float().var(unbiased=True).item()
        chunks.append(f"{uvar}s²")
    if "skew" in tensor_stats:
        t = tensor.float()
        std = t.std().item()
        skew = ((t - t.mean()) / std).pow(3).double().mean().item() if std != 0 else "?"
        chunks.append(f"{skew}γ₁")
    if "kurt" in tensor_stats:
        t = tensor.float()
        std = t.std().item()
        kurt = ((t - t.mean()) / std).pow(4).double().mean().item() if std != 0 else "?"
        chunks.append(f"{kurt}γ₂")
    if "mode" in tensor_stats:
        vals, counts = torch.unique(tensor, return_counts=True)
        _, idx = counts.max(0)
        mode = vals[idx]
        chunks.append(f"{mode}Mo")
    if "min" in tensor_stats:
        mi = tensor.min().item()
        chunks.append(f"{mi}↤")
    if "max" in tensor_stats:
        ma = tensor.max().item()
        chunks.append(f"{ma}↦")
    if "hash" in tensor_stats:
        chunks.append(hash)
    return "|".join(chunks)


def _buffers_checksums(model: torch.nn.Module, tensor_stats: set[str]) -> dict[str, Any]:
    buffers = {}
    buffers_cks = hashlib.md5()
    for name, buffer in model.named_buffers():
        if buffer is not None:
            buffer_cks = tensor_checksum(buffer)
            buffers[name] = _format(buffer, buffer_cks.hexdigest(), tensor_stats)
            buffers_cks.update(buffer_cks.digest())
    _add_digest(buffers, "__all_buffers__", buffers_cks)
    return buffers


def _rngs_checksums() -> dict[str, Any]:
    rngs: dict[str, Any] = {}
    rng_states = _collect_rng_states(include_cuda=True)  # type: ignore[attr-defined]
    for name, state in rng_states.items():
        if name == "torch":
            cks = tensor_checksum(state).hexdigest()
        elif name == "torch.cuda":
            cks = tensor_checksum(torch.stack(state, dim=0)).hexdigest()
        elif name == "numpy":
            cks = {
                "algo": state[0],
                "state": numpy_checksum(state[1]).hexdigest(),
                "pos": state[2],
                "has_gauss": state[3],
                "cached_gaussian": state[4],
            }
        elif name == "python":
            cks = {
                "version": state[0],
                "state": numpy_checksum(np.array(state[1])).hexdigest(),
                "gaussian": state[2],
            }
        else:
            raise RuntimeError(f"Unsupported RNG state '{name}': {state}")
        rngs[name] = cks
    return rngs


def get_checksums(
    trainer: L.Trainer,
    model: torch.nn.Module,
    gradients: dict[str, Any],
    checksums: set[str],
    tensor_stats: set[str],
) -> dict[str, Any]:
    """Checksums for internal model state and training context.

    Args:
        trainer (L.Trainer): trainer whose optimizer states are checksummed
        model (torch.nn.Module): model to generate the parameters checksum
        gradients (dict[str, Any]): gradient checksums captured by parameter hooks
        checksums (set[str]): checksums to collect
        tensor_stats (set[str]): tensor stats to collect

    Returns:
        dict: captured checksums
    """
    cks: dict[str, Any] = {}
    params_checksum = "params" in checksums
    grads_checksum = "grads" in checksums
    if params_checksum or grads_checksum:
        cks["parameters"] = _params_checksums(model, params_checksum, grads_checksum, gradients, tensor_stats)
    if "buffers" in checksums:
        cks["buffers"] = _buffers_checksums(model, tensor_stats)
    if "optimizers" in checksums:
        cks["optimizers"] = _optimizers_checksums(trainer, model, tensor_stats)
    if "rngs" in checksums:
        cks["rngs"] = _rngs_checksums()
    return cks


def _batch_checksums(batch: Any, batch_idx: int, tensor_stats: set[str]) -> dict[str, Any]:
    """Calculate checksums for batch data.

    Args:
        batch (Any): The batch data to generate checksums for
        batch_idx (int): The batch index

    Returns:
        dict[str, Any]: Dictionary containing checksums for batch elements
    """
    checksums: dict[str, Any] = {"__batch_idx__": batch_idx}
    _add_digest(checksums, "__all_batch__", process_item(batch, tensor_stats, "__batch__", checksums))
    return checksums


def _add_digest(cks: dict[str, Any], name: str, digest: Any) -> None:
    if digest is not None:
        cks[name] = digest.hex() if isinstance(digest, bytes) else digest.hexdigest()


def _optimizers_checksums(trainer: L.Trainer, pl_module: torch.nn.Module, tensor_stats: set[str]) -> list[Any]:
    """Extract and compute checksums of optimizer states."""
    checksums = []
    for opt in trainer.optimizers:
        opt_cks = {"__type__": type(opt).__name__}
        _add_digest(opt_cks, "__all_defaults__", process_item(opt.defaults, tensor_stats, "defaults", opt_cks))
        _add_digest(
            opt_cks, "__all_param_groups__", process_item(opt.param_groups, tensor_stats, "param_groups", opt_cks)
        )
        _add_digest(opt_cks, "__all_state__", process_item(opt.state, tensor_stats, "state", opt_cks))
        checksums.append(opt_cks)
    return checksums


class Introspection(L.Callback):
    def __init__(
        self,
        checksums: set[str] | None = None,
        tensor_stats: set[str] | None = None,
        env_vars: bool = False,
        pip_freeze: bool = False,
        output_path_prefix: str | None = None,
        schedule: Schedule | None = None,
    ) -> None:
        """Introspection PyTorch Lightning callback.

        This callback captures internal model and training environment state
        to help investigate model performance and other training regressions.
        It publishes reports that help compare the internal state between
        different steps and runs.

        Args:
            checksums (set[str] | None): Checksums to collect.
                Use any combination of the following options:
                ``"params"`` - capture checksum for every model parameter,
                ``"buffers"`` - capture checksum for every model buffer (auxiliary tensor),
                ``"grads"`` - capture checksum for every model parameter's gradient,
                ``"rngs"`` - capture checksum for every Random Number Generator (RNG) state,
                ``"optimizers"`` - capture checksum for optimizers state,
                ``"batch"`` - capture checksum for batch data (model input).
                Defaults to no checksum collection.
            tensor_stats (set[str] | None): Tensor details to collect next to checksums.
                Use any combination of the following options:
                ``"shape"`` - tensor shape,
                ``"dtype"`` - tensor dtype,
                ``"infs"`` - count of positive and negative infinity elements (∞̟ and ∞̠),
                ``"nans"`` - count of NaN elements (⚠),
                ``"zeros"`` - count of zero elements (⌀),
                ``"med"`` - median value (m̃),
                ``"mean"`` - mean value (μ),
                ``"amean"`` - absolute mean value (μ⁺),
                ``"std"`` - standard deviation (σ),
                ``"var"`` - variance (biased) (σ²),
                ``"uvar"`` - unbiased variance (s²),
                ``"skew"`` - skewness (γ₁),
                ``"kurt"`` - kurtosis (γ₂),
                ``"mode"`` - mode (Mo),
                ``"min"`` - min value (↤),
                ``"max"`` - max value (↦),
                ``"hash"`` - tensor content hash.
                Defaults to ``{"hash"}``.
            env_vars (bool): capture environment variables when the training starts,
                defaults to ``False``
            pip_freeze (bool): capture installed pip packages when the training starts,
                defaults to ``False``
            output_path_prefix (str | None): output path prefix for generated reports;
                use it to persist these files locally. Defaults to a temporary location
                that is cleaned up as soon as the report is published by the logger.
            schedule (Schedule | None): Controls when logging occurs during training.
                Defaults to :class:`Never`.
        """
        self.checksums = set(checksums or set())
        self.tensor_stats = set(tensor_stats or {"hash"})
        self.env_vars = env_vars
        self.pip_freeze = pip_freeze
        self.output_path_prefix = output_path_prefix
        self.schedule = schedule or Never()
        self._should_publish = False
        self._cb_logger: LightningLogger | None = None
        self.checksum: Any | None = None
        self.grad_hooks: list[Any] = []
        self.gradients: dict[str, tuple[Any, str]] = {}
        self._hooks_registered = False

    def _publish(self, file_name: str, path: str, data: dict[str, Any]) -> None:
        with tempfile.TemporaryDirectory() as td:
            output_file = os.path.join(self.output_path_prefix or td, file_name)
            with open(output_file, "w", encoding="utf-8") as f:
                yaml.dump(data, f, sort_keys=False, indent=2, default_flow_style=False, allow_unicode=True, width=10**6)
            self._cb_logger.log_artifact(output_file, path)  # type: ignore[union-attr]

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)
        self.stage = stage
        if self.env_vars:
            self._publish("env_vars.yaml", PATH_PREFIX, dict(os.environ))
        if self.pip_freeze:
            self._publish(
                "pip_freeze.yaml",
                PATH_PREFIX,
                {p.metadata["Name"]: p.version for p in distributions()},
            )

    @override
    def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Register hooks at the start of training to capture all gradients, including the first step"""
        self._register_hooks(pl_module)

    @override
    def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Remove hooks at the end of training"""
        self._remove_hooks()

    def _remove_hooks(self) -> None:
        """Remove all registered hooks"""
        for hook in self.grad_hooks:
            hook.remove()
        self.grad_hooks = []
        self._hooks_registered = False

    def _register_hooks(self, pl_module: "L.LightningModule") -> None:
        """Register hooks on parameters to capture gradients"""
        if self._hooks_registered or "grads" not in self.checksums:
            return
        for name, param in pl_module.named_parameters():
            if param.requires_grad:

                def hook(param_name: str, grad: torch.Tensor) -> torch.Tensor:
                    if grad is not None:
                        g = grad.detach()
                        grad_cks = tensor_checksum(g)
                        self.gradients[param_name] = (grad_cks, _format(g, grad_cks.hexdigest(), self.tensor_stats))
                    return grad

                handle = param.register_hook(partial(hook, name))
                self.grad_hooks.append(handle)
        self._hooks_registered = True

    @override
    def on_train_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self.checksum = {}

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
    ) -> None:
        self.gradients.clear()
        if self.checksum is None:
            return
        if "batch" in self.checksums:
            self.checksum["batch"] = _batch_checksums(batch, batch_idx, self.tensor_stats)
        self._publish(
            f"rank{trainer.global_rank}.yaml",
            f"{PATH_PREFIX}/{self.stage}/step={trainer.global_step}",
            self.checksum,
        )
        self.checksum = None

    @override
    def on_before_optimizer_step(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        optimizer: torch.optim.Optimizer,
    ) -> None:
        if self.checksum is None:
            return
        self.checksum.update(
            get_checksums(
                trainer,
                pl_module,
                self.gradients,
                self.checksums,
                self.tensor_stats,
            )
        )
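
The helpers above compose into a single entry point: process_item walks an arbitrary nested batch, replaces every tensor with the stats string produced by _format, keeps primitives as-is, and folds everything into one MD5 digest. A minimal sketch of that behavior (the batch values are illustrative; only names defined in the file above are used):

import torch

from fkat.pytorch.callbacks.debugging.introspection import process_item

batch = {"input_ids": torch.arange(6).reshape(2, 3), "label": 1}
report: dict = {}
digest = process_item(batch, {"shape", "dtype", "hash"}, path="batch", parent_dict=report)

# report["batch"] mirrors the dict, with the tensor replaced by a
# "shape|dtype|md5" summary, e.g. {"input_ids": "2×3|i64|<md5 hex>", "label": 1}.
# digest is an MD5 over the whole structure, so two ranks (or two runs) that
# see identical batches produce identical digests.
print(digest.hex(), report)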
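Wiring the callback into a Lightning Trainer then amounts to choosing checksums, tensor_stats, and a schedule. Only Schedule and Never appear in this diff, so the periodic schedule below is a hypothetical stand-in, sketched against the keyword signature the callback passes to check():

from typing import Any

import lightning as L

from fkat.pytorch.callbacks.debugging.introspection import Introspection
from fkat.pytorch.schedule import Schedule


class EveryN(Schedule):  # hypothetical; not part of this diff
    def __init__(self, n: int) -> None:
        self.n = n

    def check(self, **kwargs: Any) -> bool:
        # Mirrors the keyword style of the callback's schedule.check(...) calls.
        return kwargs["step"] % self.n == 0


trainer = L.Trainer(
    callbacks=[
        Introspection(
            checksums={"params", "grads", "batch", "rngs"},
            tensor_stats={"shape", "dtype", "nans", "hash"},
            pip_freeze=True,
            schedule=EveryN(100),  # publish rank{N}.yaml reports every 100 steps
        )
    ]
)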

fkat/pytorch/callbacks/debugging/optimizer.py

@@ -0,0 +1,45 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from typing_extensions import override
import datetime as dt

import torch
import fsspec
import lightning as L

from fkat.pytorch.schedule import Schedule, Never


class OptimizerSnapshot(L.Callback):
    """
    Callback that saves optimizer state at specified intervals during training.

    This callback allows you to capture the state of optimizers at specific points
    during training, which can be useful for debugging, analysis, or resuming training
    from specific optimization states.

    Args:
        output_path_prefix (str): Output path prefix for generated optimizer snapshots.
        schedule (Schedule | None): Schedule at which to take a snapshot of optimizers.
            Defaults to ``Never``.
    """

    def __init__(
        self,
        output_path_prefix: str,
        schedule: Schedule | None = None,
    ) -> None:
        self.output_path_prefix = output_path_prefix
        self.schedule = schedule or Never()

    @override
    def on_train_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        if self.schedule.check(trainer=trainer, stage="train", batch_idx=batch_idx, step=trainer.global_step):
            timestamp = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
            for i, opt in enumerate(trainer.optimizers):
                path = f"{self.output_path_prefix}rank{trainer.global_rank}_opt{i}_{timestamp}.pt"
                with fsspec.open(path, "wb", makedirs=True) as f:
                    torch.save(opt, f)
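
Because snapshots are written through fsspec, output_path_prefix can point at local disk or object storage alike; the prefix is concatenated verbatim into the file name, so a trailing slash (or separator) matters. A short usage sketch (the bucket name and the EveryN schedule are illustrative, as in the Introspection sketch above):

import lightning as L

from fkat.pytorch.callbacks.debugging.optimizer import OptimizerSnapshot

trainer = L.Trainer(
    callbacks=[
        # writes e.g. s3://my-bucket/opt-snapshots/rank0_opt0_2025-01-01T00-00-00Z.pt
        OptimizerSnapshot(
            output_path_prefix="s3://my-bucket/opt-snapshots/",
            schedule=EveryN(1000),  # hypothetical schedule from the sketch above
        )
    ]
)

# Note that torch.save(opt, f) pickles the optimizer object itself (not just
# its state_dict()), so a snapshot is read back for offline inspection with
# torch.load(path, weights_only=False).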