fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import pickle
|
|
6
|
+
import tempfile
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from typing import Any
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
import lightning as L
|
|
12
|
+
import torch
|
|
13
|
+
from torch.cuda import memory
|
|
14
|
+
|
|
15
|
+
from fkat.pytorch.schedule import (
|
|
16
|
+
Schedule,
|
|
17
|
+
Never,
|
|
18
|
+
)
|
|
19
|
+
from fkat.pytorch.loggers import LightningLogger
|
|
20
|
+
from fkat.pytorch.callbacks.loggers import CallbackLogger
|
|
21
|
+
from fkat.utils import safe_timestamp
|
|
22
|
+
|
|
23
|
+
# Module-level logger used by the helpers and the MemoryObserver callback below.
logger: logging.Logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _artifact_path(root_dir: str, rank: int, file_type: str, ext: str) -> tuple[str, str]:
    """Build a per-rank artifact path and make sure its parent directory exists.

    The resulting layout is
    ``<root_dir>/torch.cuda.memory/rank<rank>/<file_type>/rank<rank>_<timestamp>.<ext>``,
    where ``<timestamp>`` comes from ``safe_timestamp()``.

    Args:
        root_dir: directory under which the ``torch.cuda.memory`` tree is created.
        rank: global rank, used both as a directory and as a file-name prefix.
        file_type: sub-directory name grouping artifacts of one kind (e.g. ``"snapshot"``).
        ext: file extension, without the leading dot.

    Returns:
        ``(base_dir, file_path)``: the ``torch.cuda.memory`` directory (which callers upload
        as a whole) and the full path of the file to write.
    """
    base = os.path.join(root_dir, "torch.cuda.memory")
    file_name = f"rank{rank}_{safe_timestamp()}.{ext}"
    target = os.path.join(base, f"rank{rank}/{file_type}/{file_name}")
    parent = os.path.dirname(target)
    os.makedirs(parent, exist_ok=True)
    return base, target
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _reset_recording(kwargs: dict[str, Any]) -> None:
|
|
35
|
+
if torch.cuda.is_available():
|
|
36
|
+
memory._record_memory_history(enabled=None)
|
|
37
|
+
# set the limitation of ring buffer ~100 G. Otherwise, the buffer might be too large and trigger CPU OOM.
|
|
38
|
+
kwargs.setdefault("max_entries", 1000000)
|
|
39
|
+
memory._record_memory_history(**kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _detect_tensor_cycles(cb_logger: CallbackLogger, rank: int) -> None:
    """Register a garbage observer that logs reference cycles containing CUDA tensors.

    Replaces ``torch.utils.viz._cycles.is_cuda_tensor`` with a more defensive predicate. It then
    registers an observer through ``_cycles.observe_garbage``. When observed garbage contains a
    CUDA tensor, an HTML visualization of the cycle graph is written under a temporary directory.
    That directory is then uploaded with ``cb_logger.log_artifact``.

    Args:
        cb_logger: logger used to upload the rendered HTML.
        rank: global rank, embedded in the artifact path.
    """
    from torch.utils.viz import _cycles

    def is_cuda_tensor(obj: Any) -> bool:
        # Arbitrary garbage objects are probed here, so any failure while inspecting one is
        # treated as "not a CUDA tensor". FakeTensor instances are excluded explicitly.
        try:
            return (
                isinstance(obj, torch.Tensor)
                and obj.device.type == "cuda"
                and not isinstance(obj, torch._subclasses.FakeTensor)
            )
        except Exception:
            return False

    # Patch the module-level predicate so _cycles' own helpers pick up the replacement too.
    _cycles.is_cuda_tensor = is_cuda_tensor  # type: ignore[invalid-assignment]

    def observer(garbage: Any) -> None:
        if not garbage:
            return
        if not any(_cycles.is_cuda_tensor(obj) for obj in garbage):
            logger.debug("No CUDA Tensors found in garbage")
            return
        logger.warning("Reference cycle includes a CUDA Tensor")
        html = _cycles.to_html(_cycles.create_graph(garbage))
        # The file is opened in binary mode, so make sure text output is encoded first.
        if isinstance(html, str):
            html = html.encode("utf-8")
        with tempfile.TemporaryDirectory() as temp_dir:
            base_dir, html_path = _artifact_path(temp_dir, rank, "cycles", "html")
            logger.debug(f"Saving tensor cycles to {html_path}")
            with open(html_path, "wb") as f:
                f.write(html)
            cb_logger.log_artifact(base_dir)

    _cycles.observe_garbage(observer)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class MemoryObserver(L.Callback):
    """This callback registers an observer to dump and log the CUDA memory snapshot.

    Snapshots are taken with ``torch.cuda.memory._snapshot`` and uploaded through a
    :class:`CallbackLogger`. A snapshot is taken when the CUDA allocator reports an
    Out-of-Memory event (if ``oom``), and whenever ``schedule`` fires at the start of a
    train/validation/test/predict batch.

    Args:
        oom (bool): whether to dump memory snapshot on Out-of-Memory (OOM) event. Defaults to ``True``
        flamegraph (bool): whether to save memory snapshot in flamegraph format. Defaults to ``True``
        reset_memory_history (bool): whether to reset memory history after snapshot. Defaults to ``False``
        snapshot_pickle (bool): whether to dump memory snapshot in pickle format. Defaults to ``False``
        tensor_cycles (bool): whether to detect and dump graphs with cycles containing tensors in the garbage.
            Defaults to ``False``.
        schedule (Optional[Schedule]): Controls when logging occurs besides OOM event. Defaults to :class:`Never`
        **kwargs (Any): Arbitrary keyword arguments passed as is to ``memory._record_memory_history``.
    """

    def __init__(
        self,
        flamegraph: bool = True,
        reset_memory_history: bool = False,
        snapshot_pickle: bool = False,
        tensor_cycles: bool = False,
        schedule: Schedule | None = None,
        oom: bool = True,
        **kwargs: Any,
    ) -> None:
        self.flamegraph = flamegraph
        self.reset_memory_history = reset_memory_history
        self.snapshot_pickle = snapshot_pickle
        self.tensor_cycles = tensor_cycles
        self.schedule = schedule or Never()
        self.oom = oom
        self.kwargs = kwargs
        # Created in setup(); stays None when CUDA is unavailable.
        self._cb_logger: LightningLogger | None = None
        # Start recording allocator history right away so later snapshots have history to show.
        _reset_recording(kwargs)

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        """Create the artifact logger and install the tensor-cycle and OOM observers (CUDA only)."""
        if not torch.cuda.is_available():
            logger.warning("No CUDA device is available")
            return
        self._cb_logger = CallbackLogger(trainer)
        if self.tensor_cycles:
            _detect_tensor_cycles(self._cb_logger, trainer.global_rank)
        if not self.oom:
            return
        if not hasattr(torch._C, "_cuda_attach_out_of_memory_observer"):
            logger.warning(
                f"Failed to register OOM observer because torch._C._cuda_attach_out_of_memory_observer "
                f"is missing in torch=={torch.__version__}"
            )
            return

        def oom_observer_func(device: Any, alloc: Any, device_alloc: Any, device_free: Any) -> None:
            logger.warning("OOM observer triggered")
            return self.dump_memory_snapshot(trainer.global_rank)

        torch._C._cuda_attach_out_of_memory_observer(oom_observer_func)
        logger.info("OOM observer registered successfully")

    def maybe_dump_memory_snapshot(
        self, trainer: "L.Trainer", stage: str | None = None, batch_idx: int | None = None
    ) -> None:
        """Dump a snapshot if ``schedule`` fires for the current stage, batch and step.

        Args:
            trainer: the running trainer; supplies ``global_step`` and ``global_rank``.
            stage: stage of the current batch (the batch hooks pass ``"train"``,
                ``"validation"``, ``"test"`` or ``"predict"``).
            batch_idx: index of the current batch.
        """
        if not torch.cuda.is_available():
            return
        # Forward the real stage so stage-aware schedules can tell training batches apart
        # from validation/test/predict batches.
        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self.dump_memory_snapshot(trainer.global_rank)

    def dump_memory_snapshot(self, rank: int) -> None:
        """Capture a CUDA memory snapshot and upload it in the configured formats.

        Writes a pickle (if ``snapshot_pickle``) and/or an SVG flamegraph (if ``flamegraph``
        and ``torch.cuda._memory_viz`` exists) into a temporary directory, then uploads that
        directory through the callback logger. Optionally restarts history recording
        afterwards (``reset_memory_history``).

        Args:
            rank: global rank, embedded in the artifact paths.
        """
        if not hasattr(memory, "_snapshot"):
            logger.warning(
                f"Failed to capture memory snapshot because memory._snapshot is missing in torch=={torch.__version__}"
            )
            return
        now = datetime.now(timezone.utc).isoformat()
        logger.debug(f"Capturing memory snapshot on rank {rank} at {now}")
        snapshot = memory._snapshot()
        if self.reset_memory_history:
            _reset_recording(self.kwargs)
        with tempfile.TemporaryDirectory() as temp_dir:
            # Set by whichever artifact is written; None means nothing to upload.
            base_dir: str | None = None
            if self.snapshot_pickle:
                base_dir, snapshot_path = _artifact_path(temp_dir, rank, "snapshot", "pickle")
                logger.debug(f"Saving memory snapshot to {snapshot_path}")
                with open(snapshot_path, "wb") as f:
                    pickle.dump(snapshot, f)
            if self.flamegraph:
                if hasattr(torch.cuda, "_memory_viz"):
                    flamegraph = torch.cuda._memory_viz.memory(snapshot)
                    base_dir, flamegraph_path = _artifact_path(temp_dir, rank, "flamegraph", "svg")
                    logger.debug(f"Saving memory flamegraph to {flamegraph_path}")
                    with open(flamegraph_path, "w") as f:
                        print(flamegraph, file=f)
                else:
                    logger.warning(
                        f"Failed to create flamegraph because torch.cuda._memory_viz "
                        f"is missing in torch=={torch.__version__}"
                    )
            if base_dir is not None:
                logger.debug(f"Logging memory snapshot files with {self._cb_logger}")
                assert self._cb_logger
                self._cb_logger.log_artifact(base_dir)
        logger.debug("Finished capturing memory snapshot")

    @override
    def on_train_batch_start(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        self.maybe_dump_memory_snapshot(trainer, stage="train", batch_idx=batch_idx)

    @override
    def on_test_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self.maybe_dump_memory_snapshot(trainer, stage="test", batch_idx=batch_idx)

    @override
    def on_validation_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self.maybe_dump_memory_snapshot(trainer, stage="validation", batch_idx=batch_idx)

    @override
    def on_predict_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self.maybe_dump_memory_snapshot(trainer, stage="predict", batch_idx=batch_idx)
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import gzip
|
|
6
|
+
import shutil
|
|
7
|
+
import tempfile
|
|
8
|
+
import atexit
|
|
9
|
+
import signal
|
|
10
|
+
from typing import Any, TYPE_CHECKING
|
|
11
|
+
from typing_extensions import override
|
|
12
|
+
from collections.abc import Sequence
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import lightning as L
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
19
|
+
|
|
20
|
+
from fkat.pytorch.schedule import (
|
|
21
|
+
Schedule,
|
|
22
|
+
Never,
|
|
23
|
+
)
|
|
24
|
+
from fkat.pytorch.utilities import get_rank
|
|
25
|
+
from fkat.pytorch.loggers import LightningLogger
|
|
26
|
+
from fkat.pytorch.callbacks.loggers import CallbackLogger
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def exec_with_nsys(kwargs: dict[str, str]) -> None:
    """Replace current process with nsys profiling of the specified script.

    Re-launches the current command line (``sys.argv``) as
    ``nsys profile --<key>=<value> ... <python> <script> <args...>`` via :func:`os.execvp`,
    so on success this function never returns.

    Collection is restricted to the range between explicit ``cudaProfilerStart`` and
    ``cudaProfilerStop`` calls. To do that, ``capture-range=cudaProfilerApi`` and
    ``capture-range-end=stop`` are forced, overriding any caller-supplied values for those keys.

    Args:
        kwargs: extra ``nsys profile`` options rendered as ``--key=value``; mutated in place.
    """
    # only capture between explicit API calls to start/stop profiling
    kwargs["capture-range"] = "cudaProfilerApi"
    kwargs["capture-range-end"] = "stop"

    script_path, args = sys.argv[0], sys.argv[1:]
    # Re-launch with the interpreter that is running now, so a virtualenv/conda env that was
    # invoked by absolute path (and is not first on PATH) is preserved.
    python = sys.executable or "python"
    nsys_cmd = ["nsys", "profile", *[f"--{k}={v}" for k, v in kwargs.items()], python, script_path, *args]

    # Prepend the current working dir to PYTHONPATH for module resolution. Entries must be
    # separated with os.pathsep: os.path.join would merge them into a single bogus path, or
    # drop the cwd entirely when the existing PYTHONPATH is absolute.
    existing = os.environ.get("PYTHONPATH")
    entries = [os.getcwd(), existing] if existing else [os.getcwd()]
    os.environ["PYTHONPATH"] = os.pathsep.join(entries)

    # replace current process with nsys
    os.execvp("nsys", nsys_cmd)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Nsys(L.Callback):
    def __init__(
        self,
        ranks: Sequence[int] | None = None,
        output_path_prefix: str | None = None,
        schedule: Schedule | None = None,
        compress: bool = True,
        record_shapes: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        [Nsys](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) PyTorch Lightning callback.
        This :class:`L.Callback` continuously traces the training process and publishes a report
        that helps examine the duration of individual calls through time.

        On a traced rank, constructing the callback re-executes the current script under
        ``nsys profile`` (see :func:`exec_with_nsys`) unless it is already running under it.
        The report path is handed to the re-executed process through the ``NSYS_OUTPUT``
        environment variable.

        Args:
            ranks (Optional[Sequence[int]]): Only trace the provided ranks, defaults to all ranks
            output_path_prefix (Optional[str]): output path prefix for generated reports,
                use to persist these files locally, defaults to temporary location that is cleaned as soon as possible
            schedule (Optional[Schedule]): Controls when tracing occurs during training.
                Defaults to :class:`Never` - no tracing
            compress (bool): Whether to compress the report.
                Defaults to ``True``
            record_shapes (bool): Whether to include tensor shapes in the report.
                Defaults to ``False``
            **kwargs (Any): Arbitrary keyword arguments passed as is to Nsys.
        """
        self.rank = get_rank()
        self.schedule = schedule or Never()
        self.output_path_prefix = output_path_prefix
        self.compress = compress
        self.record_shapes = record_shapes
        self._enabled = False
        # The emit_nvtx context entered by _start(), so _stop() exits the same instance.
        self._nvtx: "torch.autograd.profiler.emit_nvtx | None" = None
        self._cb_logger: LightningLogger | None = None
        self.stage: str | None = None
        # Path of the nsys report; stays None on ranks that are not traced.
        self.output_file: str | None = None
        self._traced = ranks is None or self.rank in ranks
        # nsys runs with capture-range-end=stop, so there is at most one report to publish.
        self._published = False

        if self._traced:
            # NSYS_OUTPUT is only present in the process re-executed under nsys; popping it
            # breaks the otherwise infinite re-exec recursion.
            self.output_file = os.environ.pop("NSYS_OUTPUT", None)
            if self.output_file is None:
                output_file = os.path.join(self.output_path_prefix or tempfile.mkdtemp(), f"rank{self.rank}.nsys-rep")
                os.environ["NSYS_OUTPUT"] = kwargs["output"] = output_file
                exec_with_nsys(kwargs)
            self._maybe_trace()

        # Publish on termination, then hand the signal on to whoever handled it before, so
        # SIGTERM still terminates the process and SIGINT still raises KeyboardInterrupt.
        self._previous_handlers: dict[int, Any] = {}
        for sig in (signal.SIGTERM, signal.SIGINT):  # terminate signal, keyboard interrupt
            try:
                self._previous_handlers[sig] = signal.signal(sig, self._handle_signal)
            except ValueError:
                # signal handlers can only be installed from the main thread
                pass
        atexit.register(self._terminate)

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)
        self.stage = stage
        self._maybe_trace(stage=stage)

    def _maybe_trace(
        self, trainer: "L.Trainer | None" = None, stage: str | None = None, batch_idx: int | None = None
    ) -> None:
        """Start or stop tracing depending on ``schedule``; no-op on ranks that are not traced."""
        if not self._traced:
            return
        should_run = self.schedule.check(
            stage=stage, batch_idx=batch_idx, step=trainer.global_step if trainer else None, trainer=trainer
        )
        if should_run:
            self._start()
        else:
            self._stop()

    def _start(self) -> None:
        """Open the CUDA profiler capture range and start emitting autograd NVTX ranges."""
        if self._enabled:
            return
        self._enabled = True
        torch.cuda.cudart().cudaProfilerStart()
        self._nvtx = torch.autograd.profiler.emit_nvtx(record_shapes=self.record_shapes)
        self._nvtx.__enter__()

    def _stop(self) -> None:
        """Tear down what _start() set up, in reverse order."""
        if not self._enabled:
            return
        if self._nvtx is not None:
            self._nvtx.__exit__(None, None, None)
            self._nvtx = None
        torch.cuda.cudart().cudaProfilerStop()
        self._enabled = False

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._maybe_trace(trainer, "train", batch_idx + 1)

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._maybe_trace(trainer, "validation", batch_idx + 1)

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._maybe_trace(trainer, "predict", batch_idx + 1)

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._maybe_trace(trainer, "test", batch_idx + 1)

    def _publish(self) -> None:
        """Stop tracing and upload the report (gzip-compressed when ``compress`` is set).

        Runs at most once. It is a no-op on ranks that are not traced. Local copies are
        deleted after upload unless ``output_path_prefix`` was given, in which case the
        final report is kept.
        """
        self._stop()
        if self._published or not self._traced or self.output_file is None:
            return
        self._published = True
        if not os.path.isfile(self.output_file):
            import warnings

            warnings.warn(f"Nsys report {self.output_file} not found, nothing to publish")
            return
        report = self.output_file
        if self.compress:
            report = self.output_file + ".gz"
            with open(self.output_file, "rb") as f_in, gzip.open(report, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
            # the report is a regular file: shutil.rmtree would silently leave it in place
            os.remove(self.output_file)
        assert self._cb_logger
        self._cb_logger.log_artifact(report, "nsys")
        if not self.output_path_prefix:
            os.remove(report)

    @override
    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._terminate()

    @override
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        self._terminate()

    def _handle_signal(self, signum: int, frame: Any) -> None:
        """Publish the report, then defer to the handler that was installed before this one."""
        self._terminate()
        previous = self._previous_handlers.get(signum)
        if callable(previous):
            previous(signum, frame)
        elif previous != signal.SIG_IGN:
            # Default disposition (or a handler not installed from Python): restore the default
            # and re-deliver the signal so the process terminates as it would have without us.
            signal.signal(signum, signal.SIG_DFL)
            os.kill(os.getpid(), signum)

    def _terminate(self, *_: Any) -> None:
        """Publish the report once, provided setup() has run; safe to call repeatedly."""
        if self.stage and not self._published:
            self._publish()
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any, TYPE_CHECKING
|
|
5
|
+
from typing_extensions import override
|
|
6
|
+
import inspect
|
|
7
|
+
|
|
8
|
+
import lightning as L
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import nvtx
|
|
16
|
+
except ImportError:
|
|
17
|
+
from torch.cuda import nvtx
|
|
18
|
+
|
|
19
|
+
# Keep a reference to the original ``mark`` implementation (from the ``nvtx`` package, or the
# ``torch.cuda.nvtx`` fallback imported above) before it is replaced by ``_conditional_mark`` below.
_mark = nvtx.mark
|
|
20
|
+
|
|
21
|
+
def _conditional_mark(message: str, *args: Any, **kwargs: Any) -> Any:
    """Call the original ``mark`` with only the keyword arguments it supports.

    The two possible backends (the ``nvtx`` package or the ``torch.cuda.nvtx`` fallback) may
    accept different keyword arguments. This wrapper inspects the backend's signature and
    forwards only the keywords it declares, or all of them if it takes ``**kwargs``.

    When ``domain`` is given without ``color``, the colour is looked up in ``DOMAIN_COLORS``.
    Domains missing from that table get no colour instead of raising: this wrapper replaces
    ``mark`` for every caller in the process, not only for this module.

    Args:
        message: marker text, passed positionally to the backend.
        *args: ignored.
        **kwargs: candidate keyword arguments, e.g. ``domain`` and ``color``.

    Returns:
        Whatever the backend ``mark`` returns.
    """
    params = inspect.signature(_mark).parameters
    if "domain" in kwargs and "color" not in kwargs:
        color = DOMAIN_COLORS.get(kwargs["domain"])
        if color is not None:
            kwargs["color"] = color
    accepts_any_kwarg = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    filtered_kwargs = {k: v for k, v in kwargs.items() if accepts_any_kwarg or k in params}
    return _mark(message, **filtered_kwargs)
|
|
32
|
+
|
|
33
|
+
# Monkeypatch the backend module's ``mark`` so calls like ``nvtx.mark(msg, domain=...)`` work with
# either backend. NOTE(review): the patch is process-wide; it also changes ``mark`` for any other
# code that calls it on the same module object.
nvtx.mark = _conditional_mark  # type: ignore[invalid-assignment]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Domain(str, Enum):
|
|
37
|
+
INIT = "init"
|
|
38
|
+
TRAIN = "train"
|
|
39
|
+
VALIDATION = "validation"
|
|
40
|
+
TEST = "test"
|
|
41
|
+
PREDICT = "predict"
|
|
42
|
+
TUNE = "tune"
|
|
43
|
+
ERROR = "error"
|
|
44
|
+
CHECKPOINT = "checkpoint"
|
|
45
|
+
|
|
46
|
+
@staticmethod
|
|
47
|
+
def from_stage(s: str) -> "Domain":
|
|
48
|
+
if s == "fit" or s == "train":
|
|
49
|
+
return Domain.TRAIN
|
|
50
|
+
if s == "validation":
|
|
51
|
+
return Domain.VALIDATION
|
|
52
|
+
if s == "test":
|
|
53
|
+
return Domain.TEST
|
|
54
|
+
if s == "predict":
|
|
55
|
+
return Domain.PREDICT
|
|
56
|
+
if s == "tune":
|
|
57
|
+
return Domain.TUNE
|
|
58
|
+
raise NotImplementedError(f"Unsupported stage: {s}")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Default marker colour per domain; _conditional_mark uses it when a call passes ``domain``
# without an explicit ``color``.
DOMAIN_COLORS: dict[Domain, str] = {
    Domain.INIT: "white",
    Domain.TUNE: "pink",
    Domain.TRAIN: "green",
    Domain.VALIDATION: "blue",
    Domain.TEST: "purple",
    Domain.PREDICT: "yellow",
    Domain.ERROR: "red",
    Domain.CHECKPOINT: "orange",
}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class Nvtx(L.Callback):
    """Emit an NVTX marker from every Lightning callback hook.

    Each hook calls ``nvtx.mark`` with a message naming the hook (plus ``batch_idx`` for
    batch-level hooks) and a :class:`Domain` identifying the phase. The markers can then be
    correlated with other activity in an NVTX-aware profiler timeline.
    """

    def __init__(self) -> None:
        nvtx.mark("__init__()", domain=Domain.INIT)  # type: ignore[unknown-argument]

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        domain = Domain.from_stage(stage)
        nvtx.mark(f"setup(stage={stage})", domain=domain)  # type: ignore[unknown-argument]

    @override
    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        domain = Domain.from_stage(stage)
        nvtx.mark(f"teardown(stage={stage})", domain=domain)  # type: ignore[unknown-argument]

    @override
    def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_train_start()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_train_epoch_start()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_train_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        nvtx.mark(
            f"on_train_batch_start(batch_idx={batch_idx})",
            domain=Domain.TRAIN,  # type: ignore[unknown-argument]
        )

    @override
    def on_before_zero_grad(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", optimizer: "torch.optim.Optimizer"
    ) -> None:
        nvtx.mark("on_before_zero_grad()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_before_backward(self, trainer: "L.Trainer", pl_module: "L.LightningModule", loss: "torch.Tensor") -> None:
        nvtx.mark("on_before_backward()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_after_backward(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_after_backward()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_before_optimizer_step(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", optimizer: "torch.optim.Optimizer"
    ) -> None:
        nvtx.mark("on_before_optimizer_step()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        nvtx.mark(f"on_train_batch_end(batch_idx={batch_idx})", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_train_epoch_end()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    @override
    def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_train_end()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]

    # Each sanity-check / validation hook marks its own name (these three were previously shifted).
    @override
    def on_sanity_check_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_sanity_check_start()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]

    @override
    def on_sanity_check_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_sanity_check_end()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]

    @override
    def on_validation_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_validation_start()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]

    @override
    def on_validation_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_validation_epoch_start()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]

    @override
    def on_validation_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        nvtx.mark(
            f"on_validation_batch_start(batch_idx={batch_idx})",
            domain=Domain.VALIDATION,  # type: ignore[unknown-argument]
        )

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        nvtx.mark(
            f"on_validation_batch_end(batch_idx={batch_idx})",
            domain=Domain.VALIDATION,  # type: ignore[unknown-argument]
        )

    @override
    def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_validation_epoch_end()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]

    @override
    def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_validation_end()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]

    @override
    def on_test_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_test_start()", domain=Domain.TEST)  # type: ignore[unknown-argument]

    @override
    def on_test_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_test_epoch_start()", domain=Domain.TEST)  # type: ignore[unknown-argument]

    @override
    def on_test_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        nvtx.mark(f"on_test_batch_start(batch_idx={batch_idx})", domain=Domain.TEST)  # type: ignore[unknown-argument]

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        nvtx.mark(f"on_test_batch_end(batch_idx={batch_idx})", domain=Domain.TEST)  # type: ignore[unknown-argument]

    @override
    def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_test_epoch_end()", domain=Domain.TEST)  # type: ignore[unknown-argument]

    @override
    def on_test_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_test_end()", domain=Domain.TEST)  # type: ignore[unknown-argument]

    @override
    def on_predict_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_predict_start()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]

    @override
    def on_predict_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_predict_epoch_start()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]

    @override
    def on_predict_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        nvtx.mark(
            f"on_predict_batch_start(batch_idx={batch_idx})",
            domain=Domain.PREDICT,  # type: ignore[unknown-argument]
        )

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        nvtx.mark(
            f"on_predict_batch_end(batch_idx={batch_idx})",
            domain=Domain.PREDICT,  # type: ignore[unknown-argument]
        )

    @override
    def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_predict_epoch_end()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]

    @override
    def on_predict_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        nvtx.mark("on_predict_end()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]

    @override
    def state_dict(self) -> dict[str, Any]:
        # The callback is stateless; only the checkpoint event itself is marked.
        nvtx.mark("state_dict()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]
        return {}

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        nvtx.mark("load_state_dict()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]

    @override
    def on_save_checkpoint(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        nvtx.mark("on_save_checkpoint()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]

    @override
    def on_load_checkpoint(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", checkpoint: dict[str, Any]
    ) -> None:
        nvtx.mark("on_load_checkpoint()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]

    @override
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        nvtx.mark(f"on_exception({type(exception)})", domain=Domain.ERROR)  # type: ignore[unknown-argument]
|