fkat-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/cuda/xid.py
@@ -0,0 +1,173 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
import multiprocessing
from typing import Any
from typing_extensions import override

import lightning as L
from lightning.pytorch.utilities import rank_zero_only

from fkat.utils.cuda.xid import detect_xid_errors
from fkat.pytorch.actions import LightningAction
from fkat.pytorch.schedule import Schedule
from fkat.pytorch.utilities import local_rank_zero_only

log = logging.getLogger(__name__)


class Xid(L.Callback):
    """
    A callback to monitor and log Xid errors in a separate process during training.

    It utilizes a separate process to monitor these errors, ensuring that the main training process remains unaffected.
    The monitoring process is started at the beginning of training and terminated either
    upon an exception in training or at the end of the training/validation stage.
    """

    monitor: multiprocessing.Process | None = None

    def __init__(self, actions: dict[str, LightningAction], schedule: Schedule) -> None:
        """
        Arguments:
            actions: Dictionary mapping Xid ranges to actions
                Format: {
                    "0-100": fkat.actions.log,
                    "13,43,63-64,48,79,95": fkat.actions.ec2.reboot,
                    "81": fkat.actions.ec2.terminate,
                }
        """
        super().__init__()
        self.actions = self._parse_xid_ranges(actions)
        self.schedule = schedule
        self.xid_errors: multiprocessing.Queue[set[int]] = multiprocessing.Queue()
        self.xid_check: multiprocessing.Event = multiprocessing.Event()  # type: ignore[attr-defined]

    def _parse_xid_ranges(self, xid_actions: dict[str, LightningAction]) -> dict[int, LightningAction]:
        actions = {}
        for xid_range, action in xid_actions.items():
            parts = xid_range.split(",")
            for part in parts:
                part = part.strip()
                is_range = "-" in part
                if is_range:
                    start, end = map(int, part.split("-"))
                    for xid in range(start, end + 1):
                        actions[xid] = action
                else:
                    try:
                        actions[int(part)] = action
                    except ValueError:
                        print(f"Warning: Invalid XID format: {part}")
        return actions

    @override
    @local_rank_zero_only
    def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
        """
        Initializes the Xid error monitoring process at the start of the training stage.

        This method is automatically invoked by the PyTorch Lightning framework. It starts
        a separate background process dedicated to monitoring Xid errors.

        Args:
            trainer (L.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (L.LightningModule): The PyTorch Lightning module being trained.
            stage (str): The stage of the training process (e.g., 'fit', 'test').

        Returns:
            None.
        """
        log.info("Checking for Xid errors ...")
        self.monitor = multiprocessing.Process(target=detect_xid_errors, args=(self.xid_check, self.xid_errors))
        self.monitor.start()

    def _terminate_monitor(self) -> None:
        """
        Terminates the separate process used for monitoring Xid errors if it is alive.

        This is an internal method that checks if the monitoring process is active and,
        if so, terminates it to clean up resources.

        Returns:
            None.
        """
        if self.monitor and self.monitor.is_alive():
            log.info("Terminating Xid errors monitor")
            self.monitor.kill()

    @override
    @local_rank_zero_only
    def on_exception(self, trainer: L.Trainer, pl_module: L.LightningModule, exception: BaseException) -> None:
        """
        Callback method to handle exceptions during training.

        If an exception occurs during the training process, this method ensures that the
        Xid error monitoring process is terminated to prevent resource leakage.

        Args:
            trainer (L.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (L.LightningModule): The PyTorch Lightning module being trained.
            exception (BaseException): The exception that occurred during training.

        Returns:
            None.
        """
        self._terminate_monitor()

    @override
    @local_rank_zero_only
    def teardown(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
        """
        Ensures the Xid error monitoring process is terminated at the end of training.

        This method is automatically called by the PyTorch Lightning framework at the
        end of the training or validation stage to clean up the monitoring process.

        Args:
            trainer (L.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (L.LightningModule): The PyTorch Lightning module being trained.
            stage (str): The stage of the training process (e.g., 'fit', 'test').

        Returns:
            None.
        """
        self._terminate_monitor()

    @override
    @rank_zero_only
    def on_train_batch_start(
        self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int
    ) -> None:
        self.check(trainer, "train", batch_idx)

    @override
    @rank_zero_only
    def on_test_batch_start(
        self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self.check(trainer, "test", batch_idx)

    @override
    @rank_zero_only
    def on_validation_batch_start(
        self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self.check(trainer, "validation", batch_idx)

    @override
    @rank_zero_only
    def on_predict_batch_start(
        self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self.check(trainer, "predict", batch_idx)

    def check(self, trainer: L.Trainer, stage: str, batch_idx: int) -> None:
        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step):
            self.xid_check.set()
            while not self.xid_errors.empty():
                xids = self.xid_errors.get()
                for xid in xids:
                    if action := self.actions.get(xid):
                        action.perform(trainer=trainer, xid=xid)
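
For orientation, below is a minimal usage sketch of how this callback could be wired into a Lightning Trainer. Only the Xid constructor signature, the XID-range key syntax, and the schedule.check(stage=..., batch_idx=..., step=...) / action.perform(trainer=..., xid=...) call sites are taken from the code above. The EveryNBatches and LogAction classes are hypothetical duck-typed stand-ins: a real configuration would use a Schedule implementation from fkat.pytorch.schedule and the actions under fkat.pytorch.actions, whose constructors are not shown in this diff.

# Hypothetical usage sketch -- EveryNBatches and LogAction are illustrative
# stand-ins that only match the call sites used by Xid.check().
import logging

import lightning as L

from fkat.pytorch.callbacks.cuda.xid import Xid

log = logging.getLogger(__name__)


class EveryNBatches:
    """Stand-in schedule: tells Xid to poll the monitor every n batches."""

    def __init__(self, n: int = 100) -> None:
        self.n = n

    def check(self, stage: str, batch_idx: int, step: int) -> bool:
        return batch_idx % self.n == 0


class LogAction:
    """Stand-in action: Xid.check() calls .perform(trainer=..., xid=...) on a match."""

    def perform(self, trainer: L.Trainer, xid: int) -> None:
        log.error("Xid %d detected at global step %d", xid, trainer.global_step)


# Keys use the range syntax parsed by Xid._parse_xid_ranges():
# comma-separated single Xids and inclusive "start-end" ranges.
xid_cb = Xid(
    actions={"13,43,63-64,48,79,95": LogAction()},
    schedule=EveryNBatches(n=100),
)

trainer = L.Trainer(callbacks=[xid_cb], max_epochs=1)
# trainer.fit(model, datamodule=dm)  # model and datamodule defined elsewhere

Because the schedule gates how often the queue is drained, a coarse schedule (e.g. every few hundred batches) keeps the per-batch overhead negligible while still surfacing Xid errors soon after the background process reports them.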