fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/monitoring/shutdown.py
@@ -0,0 +1,170 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import os
+import time
+import random
+import signal
+import logging
+import ast
+import multiprocessing
+import datetime as dt
+from typing import Any
+from typing_extensions import override
+
+import lightning as L
+from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
+from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
+from lightning.pytorch.utilities import rank_zero_only
+
+from fkat.pytorch.schedule import (
+    Schedule,
+    Elapsed,
+)
+from fkat.pytorch.loggers import LightningLogger
+from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+log = logging.getLogger(__name__)
+
+
+def start_shutdown_detection_process(
+    logger: LightningLogger | None,
+    shutdown_tag: str,
+    trainer: "L.Trainer",
+) -> multiprocessing.Process | None:
+    """
+    Create a process that monitors trainer errors by periodically checking for
+    the ``shutdown`` tag, terminating the application if the tag is detected.
+    The process is spawned on local rank 0 only, to minimize overhead.
+    """
+    process: multiprocessing.Process | None = None
+    if logger is not None and trainer.local_rank == 0:
+        log.info("Starting a shutdown detection process...")
+        process = multiprocessing.Process(
+            target=detect_shutdown_from_logger, args=(logger, shutdown_tag, os.getpid(), 60)
+        )
+        process.daemon = True
+        process.start()
+    return process
+
+
+def detect_shutdown_from_logger(
+    logger: LightningLogger,
+    shutdown_tag: str,
+    pid: int,
+    detection_interval_secs: int,
+) -> None:
+    """
+    Check for the ``shutdown_tag`` tag periodically. If the tag is found, send SIGABRT
+    to the process with the provided pid to shut down the training process.
+    """
+    sleep_duration = int(os.getenv("SHUTDOWN_DETECTION_INTERVAL", default=str(detection_interval_secs)))
+    log.debug(f"Shutdown detection frequency is {sleep_duration} secs")
+    try:
+        while True:
+            random_delay = random.uniform(0, sleep_duration * 0.5)
+            time.sleep(random_delay)
+            tags = logger.tags()
+            if shutdown_tag in tags:
+                log.info(f"Found {shutdown_tag}={tags[shutdown_tag]} tag. Shutting down process {pid}.")
+                os.kill(pid, signal.SIGABRT)
+            time.sleep(sleep_duration - random_delay)
+    except Exception as e:
+        log.error(f"Got error when querying mlflow SHUTDOWN tag: {e}")
+
+
+class GracefulShutdown(L.Callback):
+    def __init__(
+        self,
+        schedule: Schedule | None = None,
+        shutdown_tag: str = "shutdown",
+        shutdown_info_tag: str = "shutdown_info",
+    ) -> None:
+        self.shutdown_tag = shutdown_tag
+        self.shutdown_info_tag = shutdown_info_tag
+        self.schedule = schedule or Elapsed(dt.timedelta(minutes=5))
+        self._cb_logger: LightningLogger | None = None
+        self._process: multiprocessing.Process | None = None
+
+    @override
+    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+        self._cb_logger = CallbackLogger(trainer)
+        self._process = start_shutdown_detection_process(self._cb_logger, self.shutdown_tag, trainer)
+
+    def _maybe_stop(self, stage: str, trainer: "L.Trainer", batch_idx: int) -> None:
+        if (
+            trainer.should_stop
+            or not self._cb_logger
+            or not self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer)
+        ):
+            return
+        tags = self._cb_logger.tags()
+        shutdown_tag = tags.get(self.shutdown_tag)
+        trainer.should_stop = shutdown_tag is not None
+        if trainer.should_stop:
+            info_tag = tags.get(self.shutdown_info_tag)
+            if info_tag:
+                info = ast.literal_eval(info_tag)[-1]
+                strategy = info["Strategy"].upper()
+                log.info(f"Shutdown signal received. Using shutdown strategy {strategy}")
+            self._cb_logger.log_tag(self.shutdown_tag, "SHUTTING_DOWN")
+
+    def _update_shutdown_status(self, trainer: "L.Trainer", status: str) -> None:
+        if self._cb_logger:
+            log.info(f"Updating shutdown status to {status} to indicate the job finished.")
+            self._cb_logger.log_tag(self.shutdown_tag, status)
+
+    @override
+    @rank_zero_only
+    def on_train_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
+        self._maybe_stop("train", trainer, batch_idx)
+
+    @rank_zero_only
+    def on_test_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._maybe_stop("test", trainer, batch_idx)
+
+    @override
+    @rank_zero_only
+    def on_validation_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._maybe_stop("validation", trainer, batch_idx)
+
+    @rank_zero_only
+    def on_predict_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._maybe_stop("predict", trainer, batch_idx)
+
+    @override
+    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+        if trainer.global_rank == 0:
+            if not self._tuning(trainer):
+                log.info("Updating job status before the job finishes")
+                self._update_shutdown_status(trainer, "JOB_FINISHED")
+            self._terminate_monitor()
+
+    @override
+    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+        self._terminate_monitor()
+
+    def _terminate_monitor(self) -> None:
+        """
+        Terminate the separate process used for monitoring trainer errors if it is alive.
+        """
+        if self._process and self._process.is_alive():
+            log.info("\nTerminating error monitor...")
+            self._process.kill()
+
+    def _tuning(self, trainer: "L.Trainer") -> bool:
+        num_tuning_cbs = sum(
+            isinstance(
+                cb,
+                LearningRateFinder | BatchSizeFinder,
+            )
+            for cb in trainer.callbacks  # type: ignore[attr-defined]
+        )
+        return num_tuning_cbs > 0
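
For orientation: `GracefulShutdown` above is an ordinary Lightning callback, so it is attached through `L.Trainer(callbacks=...)`. The sketch below is an assumed usage example, not code from the package; `MyModule` is a placeholder, while `GracefulShutdown`, `Elapsed`, and the five-minute default schedule come from the diff above.

    # Assumed usage sketch: wire GracefulShutdown into a Lightning Trainer.
    # Local rank 0 then polls the experiment tracker's "shutdown" tag in a
    # background process and flips trainer.should_stop when the tag appears.
    import datetime as dt

    import lightning as L

    from fkat.pytorch.callbacks.monitoring.shutdown import GracefulShutdown
    from fkat.pytorch.schedule import Elapsed

    trainer = L.Trainer(
        max_epochs=10,
        callbacks=[
            # Check the tag at most once every 5 minutes (the default schedule).
            GracefulShutdown(schedule=Elapsed(dt.timedelta(minutes=5))),
        ],
    )
    trainer.fit(MyModule())  # MyModule is a hypothetical LightningModule

The background process's 60-second polling interval can be overridden via the `SHUTDOWN_DETECTION_INTERVAL` environment variable, and the error message in `detect_shutdown_from_logger` suggests the tags are read from an MLflow run, so an operator would request a stop by setting the `shutdown` tag on that run.
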
fkat/pytorch/callbacks/profiling/__init__.py
@@ -0,0 +1,13 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from .viztracer import VizTracer
+from .memray import Memray
+from .flops import Flops
+from .torch import PyTorch
+
+__all__ = [
+    "PyTorch",
+    "VizTracer",
+    "Memray",
+    "Flops",
+]