fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
@@ -0,0 +1,212 @@ fkat/pytorch/callbacks/profiling/memray.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
import atexit
import signal
import gzip
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from collections.abc import Sequence
from typing import Any, TYPE_CHECKING
from typing_extensions import override

import lightning as L

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT

from fkat.pytorch.schedule import (
    Schedule,
    Never,
)
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger

memray = None


class Memray(L.Callback):
    def __init__(
        self,
        ranks: Sequence[int] | None = None,
        flamegraph: bool = False,
        output_path_prefix: str | None = None,
        schedule: Schedule | None = None,
        compress: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        [Memray](https://bloomberg.github.io/memray/api.html) PyTorch Lightning callback.
        This callback traces host RAM (DRAM) allocations and publishes a report to help identify
        potential memory leaks and investigate OOM errors.

        Args:
            ranks (Optional[Sequence[int]]): only trace the provided ranks, defaults to all ranks
            flamegraph (bool): whether to generate a [Flamegraph](https://www.brendangregg.com/flamegraphs.html)
                for the traced allocations; generates an HTML report that can be viewed without installing `memray`
            output_path_prefix (Optional[str]): output path prefix for generated reports,
                use to persist these files locally, defaults to a temporary location that is cleaned as soon as possible
            schedule (Optional[Schedule]): Controls when logging occurs during training.
                Defaults to Never - no logging
            compress (bool): publish reports as compressed files, defaults to publishing raw files
        """
        self.ranks = ranks
        self.flamegraph = flamegraph
        self.compress = compress
        self.rank: int | None = None
        self.stage: str | None = None
        self.kwargs = kwargs

        self.output_path_prefix = output_path_prefix
        self.schedule = schedule or Never()
        self._cb_logger: LightningLogger | None = None

        self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Memray")

        global memray
        import memray  # type: ignore[unresolved-import]

        self.tracker: memray.Tracker | None = None  # type: ignore
        self.dir = self.tmp_dir = "/tmp"

        signal.signal(signal.SIGTERM, self._terminate)  # terminate signal
        signal.signal(signal.SIGINT, self._terminate)  # keyboard interrupt
        atexit.register(self._terminate)

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)
        self.rank = trainer.global_rank
        self.stage = stage

    @override
    def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start(trainer)

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._on_batch_end(trainer, "train", batch_idx + 1)

    @override
    def on_validation_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start(trainer)

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "validation", batch_idx + 1)

    @override
    def on_predict_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start(trainer)

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "predict", batch_idx + 1)

    @override
    def on_test_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start(trainer)

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "test", batch_idx + 1)

    def _on_batch_end(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self._stop(str(batch_idx))
            self._start(trainer)

    @override
    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._terminate()

    @override
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        self._terminate()

    def _terminate(self, *args: Any, **kwargs: Any) -> None:
        # calling synchronously since this can be called during interpreter shutdown
        self._stop("last", sync=True)
        self.executor.shutdown()

    def _start(self, trainer: "L.Trainer") -> None:
        if self.ranks is not None and trainer.global_rank not in self.ranks:
            return
        if not self.tracker:
            self.tmp_dir = tempfile.mkdtemp()
            self.dir = self.output_path_prefix or self.tmp_dir
            path = os.path.join(self.dir, f"rank{trainer.global_rank}.bin")
            assert memray
            self.tracker = memray.Tracker(path, **self.kwargs)
            self.tracker.__enter__()
        assert self.tracker

    def _stop(self, suffix: str, sync: bool = False) -> None:
        if not self.tracker:
            return
        # create reports synchronously
        self.tracker.__exit__(None, None, None)
        self.tracker = None
        if self.flamegraph:
            for f in os.listdir(self.dir):
                results = os.path.join(self.dir, f)
                from memray.commands.flamegraph import FlamegraphCommand  # type: ignore[unresolved-import]

                # creating this report synchronously because it uses a global memray lock
                FlamegraphCommand().write_report(Path(results), Path(results + ".html"), True, -1, False)
        # process reports asynchronously
        artifacts_path = f"memray/{self.stage}/{suffix}"
        if sync:
            self._process(artifacts_path, self.dir, self.tmp_dir)
        else:
            self.executor.submit(self._process, artifacts_path, self.dir, self.tmp_dir)

    def _process(
        self,
        artifacts_path: str,
        report_dir: str,
        tmp_dir: str,
    ) -> None:
        assert self._cb_logger
        for f in os.listdir(report_dir):
            output_file = os.path.join(report_dir, f)
            if self.compress:
                with open(output_file, "rb") as f_in:
                    output_file = output_file + ".gz"
                    with gzip.open(output_file, "wb") as f_out:
                        shutil.copyfileobj(f_in, f_out)
            self._cb_logger.log_artifact(output_file, artifacts_path)
        shutil.rmtree(tmp_dir, ignore_errors=True)
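For orientation, a minimal usage sketch (not part of the package) showing how the `Memray` callback could be attached to a Lightning `Trainer`. The import path follows the file layout listed above, `MyLightningModule` is a hypothetical placeholder for a user-defined `LightningModule`, and with the default `schedule` (`Never`) only the final "last" report is published when the trainer tears down.

```python
# Illustrative sketch only: attach the Memray callback to a Lightning Trainer.
# `MyLightningModule` is a hypothetical user-defined LightningModule.
import lightning as L

from fkat.pytorch.callbacks.profiling.memray import Memray

memray_cb = Memray(
    ranks=[0],                         # trace rank 0 only
    flamegraph=True,                   # also publish a standalone HTML flamegraph
    output_path_prefix="/tmp/memray",  # keep the raw .bin reports locally
    compress=False,                    # publish raw (uncompressed) reports
)

trainer = L.Trainer(max_epochs=1, callbacks=[memray_cb])
# trainer.fit(MyLightningModule())  # reports are published via the trainer's logger
```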
@@ -0,0 +1,197 @@ fkat/pytorch/callbacks/profiling/torch.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import gzip
import shutil
import tempfile
import atexit
import signal
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TYPE_CHECKING
from collections.abc import Sequence
from typing_extensions import override

import lightning as L
import torch

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT

from fkat.pytorch.schedule import (
    Schedule,
    Never,
)
from fkat.pytorch.utilities import get_rank
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger


class PyTorch(L.Callback):
    def __init__(
        self,
        ranks: Sequence[int] | None = None,
        output_path_prefix: str | None = None,
        schedule: Schedule | None = None,
        compress: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) Lightning callback.
        This :class:`L.Callback` continuously traces the training process and publishes a report
        that helps examine the duration of individual calls over time.

        Args:
            ranks (Optional[Sequence[int]]): only trace the provided ranks, defaults to all ranks
            output_path_prefix (Optional[str]): output path prefix for generated reports,
                use to persist these files locally, defaults to a temporary location that is cleaned as soon as possible
            schedule (Optional[Schedule]): Controls when logging occurs during training.
                Defaults to :class:`Never` - no intermediate logging
            compress (bool): compress the report.
                Defaults to ``True``
            **kwargs (Any): Arbitrary keyword arguments passed as is to the PyTorch Profiler,
                except for ``execution_trace_observer`` and ``on_trace_ready``.
        """
        self.rank = get_rank()
        self.compress = compress
        self.schedule = schedule or Never()
        self.output_path_prefix = output_path_prefix

        self.trace_observer: torch.profiler.ExecutionTraceObserver | None = None
        self.trace_file: str | None
        self.profiler: torch.profiler.profile | None = None
        if ranks is None or self.rank in ranks:
            self.trace_file = os.path.join(self.output_path_prefix or tempfile.mkdtemp(), f"rank{self.rank}.json")
            self.trace_observer = torch.profiler.ExecutionTraceObserver()
            kwargs.pop("execution_trace_observer", None)
            kwargs.pop("on_trace_ready", None)
            self.profiler = torch.profiler.profile(
                schedule=lambda step: torch.profiler.ProfilerAction.RECORD_AND_SAVE,
                on_trace_ready=self._publish,
                execution_trace_observer=self.trace_observer,
                **kwargs,
            )
            self._start_profiler()
        self._cb_logger: LightningLogger | None = None
        self.stage: str | None = None
        self.batch_idx = "?"
        self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="PyTorchProfiler")

        signal.signal(signal.SIGTERM, self._terminate)  # terminate signal
        signal.signal(signal.SIGINT, self._terminate)  # keyboard interrupt
        atexit.register(self._terminate)

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)
        self.stage = stage

    def _on_batch_end(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
        if self.profiler and self.schedule.check(
            stage=stage, batch_idx=batch_idx + 1, step=trainer.global_step if trainer else None, trainer=trainer
        ):
            self.batch_idx = str(batch_idx + 1)
            self.profiler.step()

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._on_batch_end(trainer, "train", batch_idx)

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "validation", batch_idx)

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "predict", batch_idx)

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "test", batch_idx)

    def _publish(self, prof: torch.profiler.profile) -> None:
        # create report synchronously
        assert self.trace_file
        prof.export_chrome_trace(self.trace_file)
        base_path = self.output_path_prefix or os.path.dirname(self.trace_file)
        output_file = os.path.join(base_path, self.batch_idx, os.path.basename(self.trace_file))
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        shutil.move(self.trace_file, output_file)
        self._start_profiler()
        # process report asynchronously
        artifact_path = f"pt_profiler/{self.stage}/{self.batch_idx}"
        sync = self.profiler is None
        if sync:
            # calling synchronously since this can be called during interpreter shutdown
            self._process(output_file, artifact_path)
        else:
            self.executor.submit(self._process, output_file, artifact_path)

    def _process(self, output_file: str, artifacts_path: str) -> None:
        assert self._cb_logger
        if self.compress:
            with open(output_file, "rb") as f_in:
                output_file = output_file + ".gz"
                with gzip.open(output_file, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
        self._cb_logger.log_artifact(output_file, artifacts_path)
        shutil.rmtree(output_file, ignore_errors=True)

    @override
    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._terminate()

    @override
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        self._terminate()

    def _terminate(self, *_: Any) -> None:
        if self.profiler and self.stage:
            self.batch_idx = "last"
            self._stop_profiler()
            self.profiler = None
            self.executor.shutdown()
            if self.trace_file:
                shutil.rmtree(self.trace_file, ignore_errors=True)

    def _start_profiler(self) -> None:
        assert self.trace_file and self.trace_observer and self.profiler
        shutil.rmtree(self.trace_file, ignore_errors=True)
        self.trace_observer.register_callback(self.trace_file)
        self.profiler.start()

    def _stop_profiler(self) -> None:
        if self.profiler:
            self.profiler.stop()
        if self.trace_observer:
            self.trace_observer.unregister_callback()
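As above, a hedged usage sketch (not from the package) for the `PyTorch` profiler callback: extra keyword arguments are forwarded to `torch.profiler.profile`, so standard profiler options such as `activities` can be passed through, while `execution_trace_observer` and `on_trace_ready` are reserved by the callback and dropped if supplied. `MyLightningModule` remains a hypothetical placeholder.

```python
# Illustrative sketch only: forward standard torch.profiler arguments through
# **kwargs; `execution_trace_observer` and `on_trace_ready` are managed by the
# callback. `MyLightningModule` is a hypothetical user-defined LightningModule.
import lightning as L
import torch

from fkat.pytorch.callbacks.profiling.torch import PyTorch

profiler_cb = PyTorch(
    ranks=[0],                                         # profile rank 0 only
    output_path_prefix="/tmp/pt_profiler",             # keep Chrome traces locally
    compress=True,                                     # publish gzipped traces (default)
    activities=[torch.profiler.ProfilerActivity.CPU],  # forwarded to torch.profiler.profile
)

trainer = L.Trainer(max_epochs=1, callbacks=[profiler_cb])
# trainer.fit(MyLightningModule())  # traces are logged under pt_profiler/<stage>/<batch>
```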
@@ -0,0 +1,197 @@ fkat/pytorch/callbacks/profiling/viztracer.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import json
import gzip
import shutil
import tempfile
import atexit
import signal
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TYPE_CHECKING
from collections.abc import Sequence
from typing_extensions import override

import lightning as L

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT

from fkat.pytorch.schedule import (
    Schedule,
    Never,
)
from fkat.pytorch.utilities import get_rank
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger

if TYPE_CHECKING:
    import viztracer
else:
    viztracer = None


class VizTracer(L.Callback):
    def __init__(
        self,
        ranks: Sequence[int] | None = None,
        output_path_prefix: str | None = None,
        schedule: Schedule | None = None,
        compress: bool = False,
        patch: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        [VizTracer](https://viztracer.readthedocs.io/en/latest/) PyTorch Lightning callback.
        This :class:`L.Callback` continuously traces the training process and publishes a report
        that helps examine the duration of individual calls over time.

        Args:
            ranks (Optional[Sequence[int]]): only trace the provided ranks, defaults to all ranks
            output_path_prefix (Optional[str]): output path prefix for generated reports,
                use to persist these files locally, defaults to a temporary location that is cleaned as soon as possible
            schedule (Optional[Schedule]): Controls when logging occurs during training.
                Defaults to :class:`Never` - no logging
            compress (bool): publish reports as compressed binaries
                (need to be decompressed via `viztracer --decompress <REPORT>`),
                if ``True`` saves reports using viztracer's own compression, which requires a `viztracer` installation;
                defaults to ``False`` and publishes gzipped HTML reports which require no `viztracer` installation
            patch (bool): whether to let VizTracer patch internal Python hooks: subprocess, multiprocessing, etc.
                Defaults to ``False``
            **kwargs (Any): Arbitrary keyword arguments passed as is to VizTracer.
        """
        self.rank = get_rank()
        self.schedule = schedule or Never()
        self.output_path_prefix = output_path_prefix

        global viztracer
        import viztracer
        from viztracer.vcompressor import VCompressor

        self.compressor = VCompressor() if compress else None

        self.tracer: viztracer.VizTracer | None = None  # type: ignore[no-any-unimported]
        if ranks is None or self.rank in ranks:
            kwargs["output_file"] = f"rank{self.rank}.json"
            kwargs["verbose"] = 0
            self.tracer = viztracer.VizTracer(**kwargs)
            assert self.tracer
            if patch:
                args = [v for k, v in kwargs.items() for v in ("--" * min(2, len(k)) + k, v)]
                from viztracer.patch import install_all_hooks

                install_all_hooks(self.tracer, args)
            self.tracer.start()
        self._cb_logger: LightningLogger | None = None
        self.stage: str | None = None
        self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="VizTracer")

        signal.signal(signal.SIGTERM, self._terminate)  # terminate signal
        signal.signal(signal.SIGINT, self._terminate)  # keyboard interrupt
        atexit.register(self._terminate)

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)
        self.stage = stage

    def _on_batch_end(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
        if self.tracer and self.schedule.check(
            stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer
        ):
            self._publish(str(batch_idx))
            self.tracer.start()

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._on_batch_end(trainer, "train", batch_idx + 1)

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "validation", batch_idx + 1)

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "predict", batch_idx + 1)

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._on_batch_end(trainer, "test", batch_idx + 1)

    def _publish(self, suffix: str, sync: bool = False) -> None:
        assert self.tracer
        self.tracer.stop()
        # create report synchronously
        tmp_dir = tempfile.mkdtemp()
        output_stem = os.path.join(self.output_path_prefix or tmp_dir, suffix, f"rank{self.rank}")
        output_file = output_stem + (".json" if self.compressor else ".html")
        self.tracer.save(output_file=output_file, verbose=0)
        self.tracer.clear()
        # process report asynchronously
        artifact_path = f"viztracer/{self.stage}/{suffix}"
        if sync:
            self._process(tmp_dir, output_file, artifact_path)
        else:
            self.executor.submit(self._process, tmp_dir, output_file, artifact_path)

    def _process(self, tmp_dir: str, output_file: str, artifacts_path: str) -> None:
        if not self.compressor:
            with open(output_file, "rb") as f_in:
                output_file = output_file + ".gz"
                with gzip.open(output_file, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
        else:
            with open(output_file) as f:
                data = json.load(f)
            output_file = os.path.splitext(output_file)[0] + ".cvf"
            self.compressor.compress(data, output_file)
        assert self._cb_logger
        self._cb_logger.log_artifact(output_file, artifacts_path)
        shutil.rmtree(tmp_dir, ignore_errors=True)

    @override
    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._terminate()

    @override
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        self._terminate()

    def _terminate(self, *_: Any) -> None:
        if self.tracer and self.stage:
            # calling synchronously since this can be called during interpreter shutdown
            self._publish("last", sync=True)
            self.tracer.terminate()
            self.tracer = None
            self.executor.shutdown()
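Finally, an illustrative sketch (not from the package) for the `VizTracer` callback: remaining keyword arguments go straight to `viztracer.VizTracer` (here `max_stack_depth`, assumed to be a valid VizTracer option), while the callback overrides `output_file` and `verbose`. With `compress=False` the published report is a gzipped HTML file viewable without a `viztracer` installation.

```python
# Illustrative sketch only: extra keyword arguments are passed as-is to
# viztracer.VizTracer; `output_file` and `verbose` are set by the callback.
# `MyLightningModule` is a hypothetical user-defined LightningModule.
import lightning as L

from fkat.pytorch.callbacks.profiling.viztracer import VizTracer

viztracer_cb = VizTracer(
    ranks=[0],           # trace rank 0 only
    compress=False,      # publish gzipped HTML reports (no viztracer needed to view)
    patch=False,         # do not patch subprocess/multiprocessing hooks
    max_stack_depth=20,  # forwarded as-is to viztracer.VizTracer (assumed option)
)

trainer = L.Trainer(max_epochs=1, callbacks=[viztracer_cb])
# trainer.fit(MyLightningModule())  # the final "last" report is published at teardown
```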