fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/cuda/memory.py
@@ -0,0 +1,200 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import logging
+ import os
+ import pickle
+ import tempfile
+ from datetime import datetime, timezone
+ from typing import Any
+ from typing_extensions import override
+
+ import lightning as L
+ import torch
+ from torch.cuda import memory
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+ from fkat.utils import safe_timestamp
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ def _artifact_path(root_dir: str, rank: int, file_type: str, ext: str) -> tuple[str, str]:
+     base_dir = os.path.join(root_dir, "torch.cuda.memory")
+     timestamp = safe_timestamp()
+     file_path = os.path.join(base_dir, f"rank{rank}/{file_type}/rank{rank}_{timestamp}.{ext}")
+     os.makedirs(os.path.dirname(file_path), exist_ok=True)
+     return base_dir, file_path
+
+
+ def _reset_recording(kwargs: dict[str, Any]) -> None:
+     if torch.cuda.is_available():
+         memory._record_memory_history(enabled=None)
+         # Cap the ring buffer (~100 GB); otherwise it could grow too large and trigger CPU OOM.
+         kwargs.setdefault("max_entries", 1000000)
+         memory._record_memory_history(**kwargs)
+
+
+ def _detect_tensor_cycles(cb_logger: CallbackLogger, rank: int) -> None:
+     from torch.utils.viz import _cycles
+
+     def is_cuda_tensor(obj: Any) -> bool:
+         try:
+             return (
+                 isinstance(obj, torch.Tensor)
+                 and obj.device.type == "cuda"
+                 and not isinstance(obj, torch._subclasses.FakeTensor)
+             )
+         except:  # noqa: E722
+             return False
+
+     _cycles.is_cuda_tensor = is_cuda_tensor  # type: ignore[invalid-assignment]
+
+     def observer(garbage: Any) -> None:
+         if garbage:
+             if not any(_cycles.is_cuda_tensor(obj) for obj in garbage):
+                 logger.debug("No CUDA Tensors found in garbage")
+                 return
+             logger.warning("Reference cycle includes a CUDA Tensor")
+             with tempfile.TemporaryDirectory() as temp_dir:
+                 base_dir, html_path = _artifact_path(temp_dir, rank, "cycles", "html")
+                 logger.debug(f"Saving tensor cycles to {html_path}")
+                 with open(html_path, "wb") as f:
+                     f.write(_cycles.to_html(_cycles.create_graph(garbage)))
+                 cb_logger.log_artifact(base_dir)
+
+     _cycles.observe_garbage(observer)
+
+
+ class MemoryObserver(L.Callback):
+     """This callback registers an observer to dump and log the CUDA memory snapshot.
+
+     Args:
+         oom (bool): whether to dump a memory snapshot on an Out-of-Memory (OOM) event. Defaults to ``True``
+         flamegraph (bool): whether to save the memory snapshot in flamegraph format. Defaults to ``True``
+         reset_memory_history (bool): whether to reset memory history after a snapshot. Defaults to ``False``
+         snapshot_pickle (bool): whether to dump the memory snapshot in pickle format. Defaults to ``False``
+         tensor_cycles (bool): whether to detect and dump graphs with cycles containing tensors in the garbage.
+             Defaults to ``False``.
+         schedule (Optional[Schedule]): Controls when logging occurs besides the OOM event. Defaults to :class:`Never`
+         **kwargs (Any): Arbitrary keyword arguments passed as is to ``memory._record_memory_history``.
+     """
+
+     def __init__(
+         self,
+         flamegraph: bool = True,
+         reset_memory_history: bool = False,
+         snapshot_pickle: bool = False,
+         tensor_cycles: bool = False,
+         schedule: Schedule | None = None,
+         oom: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         self.flamegraph = flamegraph
+         self.reset_memory_history = reset_memory_history
+         self.snapshot_pickle = snapshot_pickle
+         self.tensor_cycles = tensor_cycles
+         self.schedule = schedule or Never()
+         self.oom = oom
+         self.kwargs = kwargs
+         self._cb_logger: LightningLogger | None = None
+         _reset_recording(kwargs)
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         if not torch.cuda.is_available():
+             logger.warning("No CUDA device is available")
+             return
+         self._cb_logger = CallbackLogger(trainer)
+         if self.tensor_cycles:
+             _detect_tensor_cycles(self._cb_logger, trainer.global_rank)
+         if self.oom:
+             if hasattr(torch._C, "_cuda_attach_out_of_memory_observer"):
+
+                 def oom_observer_func(device: Any, alloc: Any, device_alloc: Any, device_free: Any) -> None:
+                     logger.warning("OOM observer triggered")
+                     return self.dump_memory_snapshot(trainer.global_rank)
+
+                 torch._C._cuda_attach_out_of_memory_observer(oom_observer_func)
+                 logger.info("OOM observer registered successfully")
+             else:
+                 logger.warning(
+                     f"Failed to register OOM observer because torch._C._cuda_attach_out_of_memory_observer "
+                     f"is missing in torch=={torch.__version__}"
+                 )
+
+     def maybe_dump_memory_snapshot(
+         self, trainer: "L.Trainer", stage: str | None = None, batch_idx: int | None = None
+     ) -> None:
+         if not torch.cuda.is_available():
+             return
+         if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self.dump_memory_snapshot(trainer.global_rank)
+
+     def dump_memory_snapshot(self, rank: int) -> None:
+         if not hasattr(memory, "_snapshot"):
+             logger.warning(
+                 f"Failed to capture memory snapshot because memory._snapshot is missing in torch=={torch.__version__}"
+             )
+             return
+         now = datetime.now(timezone.utc).isoformat()
+         logger.debug(f"Capturing memory snapshot on rank {rank} at {now}")
+         snapshot = memory._snapshot()
+         if self.reset_memory_history:
+             _reset_recording(self.kwargs)
+         with tempfile.TemporaryDirectory() as temp_dir:
+             base_dir: str | None = None
+             if self.snapshot_pickle:
+                 base_dir, snapshot_path = _artifact_path(temp_dir, rank, "snapshot", "pickle")
+                 logger.debug(f"Saving memory snapshot to {snapshot_path}")
+                 with open(snapshot_path, "wb") as f:
+                     pickle.dump(snapshot, f)
+             if self.flamegraph:
+                 if hasattr(torch.cuda, "_memory_viz"):
+                     flamegraph = torch.cuda._memory_viz.memory(snapshot)
+                     base_dir, flamegraph_path = _artifact_path(temp_dir, rank, "flamegraph", "svg")
+                     logger.debug(f"Saving memory flamegraph to {flamegraph_path}")
+                     with open(flamegraph_path, "w") as f:
+                         print(flamegraph, file=f)
+                 else:
+                     logger.warning(
+                         f"Failed to create flamegraph because torch.cuda._memory_viz "
+                         f"is missing in torch=={torch.__version__}"
+                     )
+             if base_dir is not None:
+                 logger.debug(f"Logging memory snapshot files with {self._cb_logger}")
+                 assert self._cb_logger
+                 self._cb_logger.log_artifact(base_dir)
+         logger.debug("Finished capturing memory snapshot")
+
+     @override
+     def on_train_batch_start(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         self.maybe_dump_memory_snapshot(trainer, stage="train", batch_idx=batch_idx)
+
+     @override
+     def on_test_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         self.maybe_dump_memory_snapshot(trainer, stage="test", batch_idx=batch_idx)
+
+     @override
+     def on_validation_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         self.maybe_dump_memory_snapshot(trainer, stage="validation", batch_idx=batch_idx)
+
+     @override
+     def on_predict_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         self.maybe_dump_memory_snapshot(trainer, stage="predict", batch_idx=batch_idx)
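The docstring above lists MemoryObserver's options; a minimal usage sketch follows. It is not taken from the package itself: the import path is inferred from the file layout in this diff, and the callback wiring is illustrative.

import lightning as L
from fkat.pytorch.callbacks.cuda.memory import MemoryObserver

# Hypothetical wiring: dump a flamegraph snapshot when a CUDA OOM is observed,
# keeping the default Never() schedule for all other dump points.
trainer = L.Trainer(callbacks=[MemoryObserver(oom=True, flamegraph=True)])

With snapshot_pickle=True the raw snapshot is also pickled, and tensor_cycles=True additionally registers the garbage-cycle observer defined in _detect_tensor_cycles.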
fkat/pytorch/callbacks/cuda/nsys.py
@@ -0,0 +1,199 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import sys
+ import gzip
+ import shutil
+ import tempfile
+ import atexit
+ import signal
+ from typing import Any, TYPE_CHECKING
+ from typing_extensions import override
+ from collections.abc import Sequence
+
+ import torch
+ import lightning as L
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+ from fkat.pytorch.utilities import get_rank
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+
+ def exec_with_nsys(kwargs: dict[str, str]) -> None:
+     """Replace the current process with an nsys profiling run of the specified script."""
+     # only capture between explicit API calls to start/stop profiling
+     kwargs["capture-range"] = "cudaProfilerApi"
+     kwargs["capture-range-end"] = "stop"
+
+     script_path, args = sys.argv[0], sys.argv[1:]
+     nsys_cmd = ["nsys", "profile", *[f"--{k}={v}" for k, v in kwargs.items()], "python", script_path] + args
+
+     # add current working dir for module resolution
+     os.environ["PYTHONPATH"] = os.path.join(
+         os.getcwd(), *([os.environ["PYTHONPATH"]] if "PYTHONPATH" in os.environ else [])
+     )
+
+     # replace current process with nsys
+     os.execvp("nsys", nsys_cmd)
+
+
+ class Nsys(L.Callback):
+     def __init__(
+         self,
+         ranks: Sequence[int] | None = None,
+         output_path_prefix: str | None = None,
+         schedule: Schedule | None = None,
+         compress: bool = True,
+         record_shapes: bool = False,
+         **kwargs: Any,
+     ) -> None:
+         """
+         [Nsys](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) PyTorch Lightning callback.
+         This :class:`L.Callback` continuously traces the training process and publishes a report
+         that helps examine the duration of individual calls over time.
+
+         Args:
+             ranks (Optional[Sequence[int]]): Only trace the provided ranks; defaults to all ranks
+             output_path_prefix (Optional[str]): output path prefix for generated reports,
+                 used to persist these files locally; defaults to a temporary location that is cleaned up as soon as possible
+             schedule (Optional[Schedule]): Controls when tracing occurs during training.
+                 Defaults to :class:`Never` - no tracing
+             compress (bool): Whether to compress the report.
+                 Defaults to ``True``
+             record_shapes (bool): Whether to include tensor shapes in the report.
+                 Defaults to ``False``
+             **kwargs (Any): Arbitrary keyword arguments passed as is to Nsys.
+         """
+         self.rank = get_rank()
+         self.schedule = schedule or Never()
+         self.output_path_prefix = output_path_prefix
+         self.compress = compress
+         self.record_shapes = record_shapes
+         self._enabled = False
+
+         if ranks is None or self.rank in ranks:
+             # break infinite recursion
+             self.output_file = os.environ.pop("NSYS_OUTPUT", None)
+             if self.output_file is None:
+                 output_file = os.path.join(self.output_path_prefix or tempfile.mkdtemp(), f"rank{self.rank}.nsys-rep")
+                 os.environ["NSYS_OUTPUT"] = kwargs["output"] = output_file
+                 exec_with_nsys(kwargs)
+             self._maybe_trace()
+         self._cb_logger: LightningLogger | None = None
+         self.stage: str | None = None
+
+         signal.signal(signal.SIGTERM, self._terminate)  # terminate signal
+         signal.signal(signal.SIGINT, self._terminate)  # keyboard interrupt
+         atexit.register(self._terminate)
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._cb_logger = CallbackLogger(trainer)
+         self.stage = stage
+         self._maybe_trace(stage=stage)
+
+     def _maybe_trace(
+         self, trainer: "L.Trainer | None" = None, stage: str | None = None, batch_idx: int | None = None
+     ) -> None:
+         should_run = self.schedule.check(
+             stage=stage, batch_idx=batch_idx, step=trainer.global_step if trainer else None, trainer=trainer
+         )
+         if should_run:
+             self._start()
+         else:
+             self._stop()
+
+     def _start(self) -> None:
+         if self._enabled:
+             return
+         self._enabled = True
+         torch.cuda.cudart().cudaProfilerStart()
+         torch.autograd.profiler.emit_nvtx(record_shapes=self.record_shapes).__enter__()
+
+     def _stop(self) -> None:
+         if not self._enabled:
+             return
+         torch.cuda.cudart().cudaProfilerStop()
+         torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
+         self._enabled = False
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         self._maybe_trace(trainer, "train", batch_idx + 1)
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._maybe_trace(trainer, "validation", batch_idx + 1)
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._maybe_trace(trainer, "predict", batch_idx + 1)
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._maybe_trace(trainer, "test", batch_idx + 1)
+
+     def _publish(self) -> None:
+         self._stop()
+         assert self.output_file
+         os.makedirs(os.path.dirname(self.output_file), exist_ok=True)
+         if self.compress:
+             with open(self.output_file, "rb") as f_in:
+                 output_file = self.output_file + ".gz"
+                 with gzip.open(output_file, "wb") as f_out:
+                     shutil.copyfileobj(f_in, f_out)
+             shutil.rmtree(self.output_file, ignore_errors=True)
+         assert self._cb_logger
+         self._cb_logger.log_artifact(output_file, "nsys")
+         if not self.output_path_prefix:
+             shutil.rmtree(output_file, ignore_errors=True)
+
+     @override
+     def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._terminate()
+
+     @override
+     def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+         self._terminate()
+
+     def _terminate(self, *_: Any) -> None:
+         if self.stage:
+             self._publish()
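For orientation, a minimal sketch of attaching this callback; the import path is inferred from the file list above and the arguments are illustrative, not an example shipped with the package.

import lightning as L
from fkat.pytorch.callbacks.cuda.nsys import Nsys

# Hypothetical wiring: trace rank 0 only and keep the compressed .nsys-rep.gz
# report under ./nsys_reports instead of a throwaway temp directory.
trainer = L.Trainer(callbacks=[Nsys(ranks=[0], output_path_prefix="./nsys_reports")])

Under the hood, exec_with_nsys re-execs the training script as roughly "nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop --output=<NSYS_OUTPUT> python <script> ...". The NSYS_OUTPUT environment variable set before the exec is what stops the re-launched process from exec-ing nsys again, and _start/_stop toggle the actual capture window via cudaProfilerStart/Stop according to the schedule.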
fkat/pytorch/callbacks/cuda/nvtx.py
@@ -0,0 +1,288 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from enum import Enum
+ from typing import Any, TYPE_CHECKING
+ from typing_extensions import override
+ import inspect
+
+ import lightning as L
+ import torch
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ try:
+     import nvtx
+ except ImportError:
+     from torch.cuda import nvtx
+
+ _mark = nvtx.mark
+
+ def _conditional_mark(message: str, *args: Any, **kwargs: Any) -> Any:
+     sig = inspect.signature(_mark)
+     filtered_kwargs = {}
+
+     if "domain" in kwargs and "color" not in kwargs:
+         kwargs["color"] = DOMAIN_COLORS[kwargs["domain"]]
+
+     for param in ["color", "domain"]:
+         if param in sig.parameters and param in kwargs:
+             filtered_kwargs[param] = kwargs[param]
+     return _mark(message, **filtered_kwargs)
+
+ nvtx.mark = _conditional_mark  # type: ignore[invalid-assignment]
+
+
+ class Domain(str, Enum):
+     INIT = "init"
+     TRAIN = "train"
+     VALIDATION = "validation"
+     TEST = "test"
+     PREDICT = "predict"
+     TUNE = "tune"
+     ERROR = "error"
+     CHECKPOINT = "checkpoint"
+
+     @staticmethod
+     def from_stage(s: str) -> "Domain":
+         if s == "fit" or s == "train":
+             return Domain.TRAIN
+         if s == "validation":
+             return Domain.VALIDATION
+         if s == "test":
+             return Domain.TEST
+         if s == "predict":
+             return Domain.PREDICT
+         if s == "tune":
+             return Domain.TUNE
+         raise NotImplementedError(f"Unsupported stage: {s}")
+
+
+ DOMAIN_COLORS = {
+     Domain.INIT: "white",
+     Domain.TUNE: "pink",
+     Domain.TRAIN: "green",
+     Domain.VALIDATION: "blue",
+     Domain.TEST: "purple",
+     Domain.PREDICT: "yellow",
+     Domain.ERROR: "red",
+     Domain.CHECKPOINT: "orange",
+ }
+
+
+ class Nvtx(L.Callback):
+     def __init__(self) -> None:
+         nvtx.mark("__init__()", domain=Domain.INIT)  # type: ignore[unknown-argument]
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         domain = Domain.from_stage(stage)
+         nvtx.mark(f"setup(stage={stage})", domain=domain)  # type: ignore[unknown-argument]
+
+     @override
+     def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         domain = Domain.from_stage(stage)
+         nvtx.mark(f"teardown(stage={stage})", domain=domain)  # type: ignore[unknown-argument]
+
+     @override
+     def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_train_start()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_train_epoch_start()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_train_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+     ) -> None:
+         nvtx.mark(
+             f"on_train_batch_start(batch_idx={batch_idx})",
+             domain=Domain.TRAIN,  # type: ignore[unknown-argument]
+         )
+
+     @override
+     def on_before_zero_grad(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", optimizer: "torch.optim.Optimizer"
+     ) -> None:
+         nvtx.mark("on_before_zero_grad()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_before_backward(self, trainer: "L.Trainer", pl_module: "L.LightningModule", loss: "torch.Tensor") -> None:
+         nvtx.mark("on_before_backward()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_after_backward(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_after_backward()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_before_optimizer_step(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", optimizer: "torch.optim.Optimizer"
+     ) -> None:
+         nvtx.mark("on_before_optimizer_step()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         nvtx.mark(f"on_train_batch_end(batch_idx={batch_idx})", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_train_epoch_end()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_train_end()", domain=Domain.TRAIN)  # type: ignore[unknown-argument]
+
+     @override
+     def on_sanity_check_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_sanity_check_start()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]
+
+     def on_sanity_check_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_sanity_check_end()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]
+
+     @override
+     def on_validation_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_validation_start()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]
+
+     @override
+     def on_validation_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_validation_epoch_start()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]
+
+     @override
+     def on_validation_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         nvtx.mark(
+             f"on_validation_batch_start(batch_idx={batch_idx})",
+             domain=Domain.VALIDATION,  # type: ignore[unknown-argument]
+         )
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         nvtx.mark(
+             f"on_validation_batch_end(batch_idx={batch_idx})",
+             domain=Domain.VALIDATION,  # type: ignore[unknown-argument]
+         )
+
+     @override
+     def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_validation_epoch_end()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]
+
+     @override
+     def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_validation_end()", domain=Domain.VALIDATION)  # type: ignore[unknown-argument]
+
+     @override
+     def on_test_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_test_start()", domain=Domain.TEST)  # type: ignore[unknown-argument]
+
+     @override
+     def on_test_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_test_epoch_start()", domain=Domain.TEST)  # type: ignore[unknown-argument]
+
+     @override
+     def on_test_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         nvtx.mark(f"on_test_batch_start(batch_idx={batch_idx})", domain=Domain.TEST)  # type: ignore[unknown-argument]
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         nvtx.mark(f"on_test_batch_end(batch_idx={batch_idx})", domain=Domain.TEST)  # type: ignore[unknown-argument]
+
+     @override
+     def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_test_epoch_end()", domain=Domain.TEST)  # type: ignore[unknown-argument]
+
+     @override
+     def on_test_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_test_end()", domain=Domain.TEST)  # type: ignore[unknown-argument]
+
+     @override
+     def on_predict_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_predict_start()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]
+
+     @override
+     def on_predict_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_predict_epoch_start()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]
+
+     @override
+     def on_predict_batch_start(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         nvtx.mark(
+             f"on_predict_batch_start(batch_idx={batch_idx})",
+             domain=Domain.PREDICT,  # type: ignore[unknown-argument]
+         )
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         nvtx.mark(
+             f"on_predict_batch_end(batch_idx={batch_idx})",
+             domain=Domain.PREDICT,  # type: ignore[unknown-argument]
+         )
+
+     @override
+     def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_predict_epoch_end()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]
+
+     @override
+     def on_predict_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         nvtx.mark("on_predict_end()", domain=Domain.PREDICT)  # type: ignore[unknown-argument]
+
+     @override
+     def state_dict(self) -> dict[str, Any]:
+         nvtx.mark("state_dict()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]
+         return {}
+
+     @override
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         nvtx.mark("load_state_dict()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]
+
+     @override
+     def on_save_checkpoint(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", checkpoint: dict[str, Any]
+     ) -> None:
+         nvtx.mark("on_save_checkpoint()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]
+
+     @override
+     def on_load_checkpoint(
+         self, trainer: "L.Trainer", pl_module: "L.LightningModule", checkpoint: dict[str, Any]
+     ) -> None:
+         nvtx.mark("on_load_checkpoint()", domain=Domain.CHECKPOINT)  # type: ignore[unknown-argument]
+
+     @override
+     def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+         nvtx.mark(f"on_exception({type(exception)})", domain=Domain.ERROR)  # type: ignore[unknown-argument]