fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/profiling/memray.py
@@ -0,0 +1,212 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import tempfile
+ import atexit
+ import signal
+ import gzip
+ import shutil
+ from concurrent.futures import ThreadPoolExecutor
+ from pathlib import Path
+ from collections.abc import Sequence
+ from typing import Any, TYPE_CHECKING
+ from typing_extensions import override
+
+ import lightning as L
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+ memray = None
+
+
+ class Memray(L.Callback):
+     def __init__(
+         self,
+         ranks: Sequence[int] | None = None,
+         flamegraph: bool = False,
+         output_path_prefix: str | None = None,
+         schedule: Schedule | None = None,
+         compress: bool = False,
+         **kwargs: Any,
+     ) -> None:
+         """
+         [Memray](https://bloomberg.github.io/memray/api.html) PyTorch Lightning callback.
+         This callback traces host RAM (DRAM) allocations and publishes a report to help identify
+         potential memory leaks and investigate OOM errors.
+
+         Args:
+             ranks (Optional[Sequence[int]]): only trace the provided ranks, defaults to all ranks
+             flamegraph (bool): whether to generate a [Flamegraph](https://www.brendangregg.com/flamegraphs.html)
+                 for the traced allocations, generates an HTML report that can be viewed without installing `memray`
+             output_path_prefix (Optional[str]): output path prefix for generated reports,
+                 use to persist these files locally, defaults to a temporary location that is cleaned as soon as possible
+             schedule (Optional[Schedule]): Controls when logging occurs during training.
+                 Defaults to Never - no logging
+             compress (bool): publish reports as compressed files, defaults to publishing raw files
+         """
+         self.ranks = ranks
+         self.flamegraph = flamegraph
+         self.compress = compress
+         self.rank: int | None = None
+         self.stage: str | None = None
+         self.kwargs = kwargs
+
+         self.output_path_prefix = output_path_prefix
+         self.schedule = schedule or Never()
+         self._cb_logger: LightningLogger | None = None
+
+         self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Memray")
+
+         global memray
+         import memray # type: ignore[unresolved-import]
+
+         self.tracker: memray.Tracker | None = None # type: ignore
+         self.dir = self.tmp_dir = "/tmp"
+
+         signal.signal(signal.SIGTERM, self._terminate) # terminate signal
+         signal.signal(signal.SIGINT, self._terminate) # keyboard interrupt
+         atexit.register(self._terminate)
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._cb_logger = CallbackLogger(trainer)
+         self.rank = trainer.global_rank
+         self.stage = stage
+
+     @override
+     def on_train_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         self._start(trainer)
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         self._on_batch_end(trainer, "train", batch_idx + 1)
+
+     @override
+     def on_validation_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         self._start(trainer)
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "validation", batch_idx + 1)
+
+     @override
+     def on_predict_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         self._start(trainer)
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "predict", batch_idx + 1)
+
+     @override
+     def on_test_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         self._start(trainer)
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "test", batch_idx + 1)
+
+     def _on_batch_end(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
+         if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self._stop(str(batch_idx))
+             self._start(trainer)
+
+     @override
+     def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._terminate()
+
+     @override
+     def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+         self._terminate()
+
+     def _terminate(self, *args: Any, **kwargs: Any) -> None:
+         # calling synchronously since this can be called during interpreter shutdown
+         self._stop("last", sync=True)
+         self.executor.shutdown()
+
+     def _start(self, trainer: "L.Trainer") -> None:
+         if self.ranks is not None and trainer.global_rank not in self.ranks:
+             return
+         if not self.tracker:
+             self.tmp_dir = tempfile.mkdtemp()
+             self.dir = self.output_path_prefix or self.tmp_dir
+             path = os.path.join(self.dir, f"rank{trainer.global_rank}.bin")
+             assert memray
+             self.tracker = memray.Tracker(path, **self.kwargs)
+             self.tracker.__enter__()
+         assert self.tracker
+
+     def _stop(self, suffix: str, sync: bool = False) -> None:
+         if not self.tracker:
+             return
+         # create reports synchronously
+         self.tracker.__exit__(None, None, None)
+         self.tracker = None
+         if self.flamegraph:
+             for f in os.listdir(self.dir):
+                 results = os.path.join(self.dir, f)
+                 from memray.commands.flamegraph import FlamegraphCommand # type: ignore[unresolved-import]
+
+                 # creating this report synchronously because it uses a global memray lock
+                 FlamegraphCommand().write_report(Path(results), Path(results + ".html"), True, -1, False)
+         # process reports asynchronously
+         artifacts_path = f"memray/{self.stage}/{suffix}"
+         if sync:
+             self._process(artifacts_path, self.dir, self.tmp_dir)
+         else:
+             self.executor.submit(self._process, artifacts_path, self.dir, self.tmp_dir)
+
+     def _process(
+         self,
+         artifacts_path: str,
+         report_dir: str,
+         tmp_dir: str,
+     ) -> None:
+         assert self._cb_logger
+         for f in os.listdir(report_dir):
+             output_file = os.path.join(report_dir, f)
+             if self.compress:
+                 with open(output_file, "rb") as f_in:
+                     output_file = output_file + ".gz"
+                     with gzip.open(output_file, "wb") as f_out:
+                         shutil.copyfileobj(f_in, f_out)
+             self._cb_logger.log_artifact(output_file, artifacts_path)
+         shutil.rmtree(tmp_dir, ignore_errors=True)
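
Usage note (a minimal sketch, not part of the packaged file above): wiring the `Memray` callback from `fkat/pytorch/callbacks/profiling/memray.py` into a Lightning `Trainer`, assuming `fkat` and `memray` are installed; the output path, model, and datamodule names are placeholders.

    import lightning as L
    from fkat.pytorch.callbacks.profiling.memray import Memray

    # Trace host (DRAM) allocations on rank 0 only, keep reports locally,
    # and gzip them before they are published through the callback logger.
    memray_cb = Memray(
        ranks=[0],
        flamegraph=True,                   # also emit a standalone HTML flamegraph
        output_path_prefix="/tmp/memray",  # hypothetical local path
        compress=True,
    )
    trainer = L.Trainer(max_epochs=1, callbacks=[memray_cb])
    # trainer.fit(model, datamodule=dm)
    # With the default schedule (Never), the report is published once, at teardown,
    # under the artifact path "memray/<stage>/last".
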
fkat/pytorch/callbacks/profiling/torch.py
@@ -0,0 +1,197 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import gzip
+ import shutil
+ import tempfile
+ import atexit
+ import signal
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, TYPE_CHECKING
+ from collections.abc import Sequence
+ from typing_extensions import override
+
+ import lightning as L
+ import torch
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+ from fkat.pytorch.utilities import get_rank
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+
+ class PyTorch(L.Callback):
+     def __init__(
+         self,
+         ranks: Sequence[int] | None = None,
+         output_path_prefix: str | None = None,
+         schedule: Schedule | None = None,
+         compress: bool = True,
+         **kwargs: Any,
+     ) -> None:
+         """
+         [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) Lightning callback.
+         This :class:`L.Callback` continuously traces the training process and publishes a report
+         that helps examine the duration of individual calls through time.
+
+         Args:
+             ranks (Optional[Sequence[int]]): only trace the provided ranks, defaults to all ranks
+             output_path_prefix (Optional[str]): output path prefix for generated reports,
+                 use to persist these files locally, defaults to a temporary location that is cleaned as soon as possible
+             schedule (Optional[Schedule]): Controls when logging occurs during training.
+                 Defaults to :class:`Never` - no intermediate logging
+             compress (bool): compress the report
+                 Defaults to ``True``
+             **kwargs (Any): Arbitrary keyword arguments passed as is to PyTorch Profiler
+                 except for ``execution_trace_observer`` and ``on_trace_ready``.
+         """
+         self.rank = get_rank()
+         self.compress = compress
+         self.schedule = schedule or Never()
+         self.output_path_prefix = output_path_prefix
+
+         self.trace_observer: torch.profiler.ExecutionTraceObserver | None = None
+         self.trace_file: str | None = None
+         self.profiler: torch.profiler.profile | None = None
+         if ranks is None or self.rank in ranks:
+             self.trace_file = os.path.join(self.output_path_prefix or tempfile.mkdtemp(), f"rank{self.rank}.json")
+             self.trace_observer = torch.profiler.ExecutionTraceObserver()
+             kwargs.pop("execution_trace_observer", None)
+             kwargs.pop("on_trace_ready", None)
+             self.profiler = torch.profiler.profile(
+                 schedule=lambda step: torch.profiler.ProfilerAction.RECORD_AND_SAVE,
+                 on_trace_ready=self._publish,
+                 execution_trace_observer=self.trace_observer,
+                 **kwargs,
+             )
+             self._start_profiler()
+         self._cb_logger: LightningLogger | None = None
+         self.stage: str | None = None
+         self.batch_idx = "?"
+         self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="PyTorchProfiler")
+
+         signal.signal(signal.SIGTERM, self._terminate) # terminate signal
+         signal.signal(signal.SIGINT, self._terminate) # keyboard interrupt
+         atexit.register(self._terminate)
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._cb_logger = CallbackLogger(trainer)
+         self.stage = stage
+
+     def _on_batch_end(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
+         if self.profiler and self.schedule.check(
+             stage=stage, batch_idx=batch_idx + 1, step=trainer.global_step if trainer else None, trainer=trainer
+         ):
+             self.batch_idx = str(batch_idx + 1)
+             self.profiler.step()
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         self._on_batch_end(trainer, "train", batch_idx)
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "validation", batch_idx)
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "predict", batch_idx)
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "test", batch_idx)
+
+     def _publish(self, prof: torch.profiler.profile) -> None:
+         # create report synchronously
+         assert self.trace_file
+         prof.export_chrome_trace(self.trace_file)
+         base_path = self.output_path_prefix or os.path.dirname(self.trace_file)
+         output_file = os.path.join(base_path, self.batch_idx, os.path.basename(self.trace_file))
+         os.makedirs(os.path.dirname(output_file), exist_ok=True)
+         shutil.move(self.trace_file, output_file)
+         self._start_profiler()
+         # process report asynchronously
+         artifact_path = f"pt_profiler/{self.stage}/{self.batch_idx}"
+         sync = self.profiler is None
+         if sync:
+             # calling synchronously since this can be called during interpreter shutdown
+             self._process(output_file, artifact_path)
+         else:
+             self.executor.submit(self._process, output_file, artifact_path)
+
+     def _process(self, output_file: str, artifacts_path: str) -> None:
+         assert self._cb_logger
+         if self.compress:
+             with open(output_file, "rb") as f_in:
+                 output_file = output_file + ".gz"
+                 with gzip.open(output_file, "wb") as f_out:
+                     shutil.copyfileobj(f_in, f_out)
+         self._cb_logger.log_artifact(output_file, artifacts_path)
+         shutil.rmtree(output_file, ignore_errors=True)
+
+     @override
+     def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._terminate()
+
+     @override
+     def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+         self._terminate()
+
+     def _terminate(self, *_: Any) -> None:
+         if self.profiler and self.stage:
+             self.batch_idx = "last"
+             self._stop_profiler()
+             self.profiler = None
+         self.executor.shutdown()
+         if self.trace_file:
+             shutil.rmtree(self.trace_file, ignore_errors=True)
+
+     def _start_profiler(self) -> None:
+         assert self.trace_file and self.trace_observer and self.profiler
+         shutil.rmtree(self.trace_file, ignore_errors=True)
+         self.trace_observer.register_callback(self.trace_file)
+         self.profiler.start()
+
+     def _stop_profiler(self) -> None:
+         if self.profiler:
+             self.profiler.stop()
+         if self.trace_observer:
+             self.trace_observer.unregister_callback()
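
Usage note (a sketch, not from the package documentation): attaching the `PyTorch` profiler callback defined above. Keyword arguments other than `execution_trace_observer` and `on_trace_ready` are forwarded to `torch.profiler.profile`; the profiler options shown are standard `torch.profiler` arguments, and the output path is a placeholder.

    import torch
    import lightning as L
    from fkat.pytorch.callbacks.profiling.torch import PyTorch

    profiler_cb = PyTorch(
        ranks=[0],                              # profile global rank 0 only
        output_path_prefix="/tmp/pt_profiler",  # hypothetical local path
        compress=True,                          # gzip the chrome trace before publishing
        # forwarded as-is to torch.profiler.profile():
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
    )
    trainer = L.Trainer(max_epochs=1, callbacks=[profiler_cb])
    # trainer.fit(model)
    # A Schedule from fkat.pytorch.schedule decides when profiler.step() is called and a
    # trace is cut; with the default Never, the trace is published when the callback terminates.
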
fkat/pytorch/callbacks/profiling/viztracer.py
@@ -0,0 +1,197 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import json
+ import gzip
+ import shutil
+ import tempfile
+ import atexit
+ import signal
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, TYPE_CHECKING
+ from collections.abc import Sequence
+ from typing_extensions import override
+
+ import lightning as L
+
+ if TYPE_CHECKING:
+     from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+ from fkat.pytorch.schedule import (
+     Schedule,
+     Never,
+ )
+ from fkat.pytorch.utilities import get_rank
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+ if TYPE_CHECKING:
+     import viztracer
+ else:
+     viztracer = None
+
+
+ class VizTracer(L.Callback):
+     def __init__(
+         self,
+         ranks: Sequence[int] | None = None,
+         output_path_prefix: str | None = None,
+         schedule: Schedule | None = None,
+         compress: bool = False,
+         patch: bool = False,
+         **kwargs: Any,
+     ) -> None:
+         """
+         [VizTracer](https://viztracer.readthedocs.io/en/latest/) PyTorch Lightning callback.
+         This :class:`L.Callback` continuously traces the training process and publishes a report
+         that helps examine the duration of individual calls through time.
+
+         Args:
+             ranks (Optional[Sequence[int]]): only trace the provided ranks, defaults to all ranks
+             output_path_prefix (Optional[str]): output path prefix for generated reports,
+                 use to persist these files locally, defaults to a temporary location that is cleaned as soon as possible
+             schedule (Optional[Schedule]): Controls when logging occurs during training.
+                 Defaults to :class:`Never` - no logging
+             compress (bool): publish reports as compressed binaries
+                 (need to be decompressed via `viztracer --decompress <REPORT>`),
+                 if ``True`` saves reports using viztracer's own compression that requires a `viztracer` installation,
+                 defaults to ``False`` and publishes gzipped HTML reports which require no `viztracer` installation
+             patch (bool): whether to let VizTracer patch internal Python hooks: subprocess, multiprocessing, etc.
+                 Defaults to ``False``
+             **kwargs (Any): Arbitrary keyword arguments passed as is to VizTracer.
+         """
+         self.rank = get_rank()
+         self.schedule = schedule or Never()
+         self.output_path_prefix = output_path_prefix
+
+         global viztracer
+         import viztracer
+         from viztracer.vcompressor import VCompressor
+
+         self.compressor = VCompressor() if compress else None
+
+         self.tracer: viztracer.VizTracer | None = None # type: ignore[no-any-unimported]
+         if ranks is None or self.rank in ranks:
+             kwargs["output_file"] = f"rank{self.rank}.json"
+             kwargs["verbose"] = 0
+             self.tracer = viztracer.VizTracer(**kwargs)
+             assert self.tracer
+             if patch:
+                 args = [v for k, v in kwargs.items() for v in ("--" * min(2, len(k)) + k, v)]
+                 from viztracer.patch import install_all_hooks
+
+                 install_all_hooks(self.tracer, args)
+             self.tracer.start()
+         self._cb_logger: LightningLogger | None = None
+         self.stage: str | None = None
+         self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="VizTracer")
+
+         signal.signal(signal.SIGTERM, self._terminate) # terminate signal
+         signal.signal(signal.SIGINT, self._terminate) # keyboard interrupt
+         atexit.register(self._terminate)
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._cb_logger = CallbackLogger(trainer)
+         self.stage = stage
+
+     def _on_batch_end(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
+         if self.tracer and self.schedule.check(
+             stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer
+         ):
+             self._publish(str(batch_idx))
+             self.tracer.start()
+
+     @override
+     def on_train_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         self._on_batch_end(trainer, "train", batch_idx + 1)
+
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "validation", batch_idx + 1)
+
+     @override
+     def on_predict_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "predict", batch_idx + 1)
+
+     @override
+     def on_test_batch_end(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         outputs: "STEP_OUTPUT",
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         self._on_batch_end(trainer, "test", batch_idx + 1)
+
+     def _publish(self, suffix: str, sync: bool = False) -> None:
+         assert self.tracer
+         self.tracer.stop()
+         # create report synchronously
+         tmp_dir = tempfile.mkdtemp()
+         output_stem = os.path.join(self.output_path_prefix or tmp_dir, suffix, f"rank{self.rank}")
+         output_file = output_stem + (".json" if self.compressor else ".html")
+         self.tracer.save(output_file=output_file, verbose=0)
+         self.tracer.clear()
+         # process report asynchronously
+         artifact_path = f"viztracer/{self.stage}/{suffix}"
+         if sync:
+             self._process(tmp_dir, output_file, artifact_path)
+         else:
+             self.executor.submit(self._process, tmp_dir, output_file, artifact_path)
+
+     def _process(self, tmp_dir: str, output_file: str, artifacts_path: str) -> None:
+         if not self.compressor:
+             with open(output_file, "rb") as f_in:
+                 output_file = output_file + ".gz"
+                 with gzip.open(output_file, "wb") as f_out:
+                     shutil.copyfileobj(f_in, f_out)
+         else:
+             with open(output_file) as f:
+                 data = json.load(f)
+             output_file = os.path.splitext(output_file)[0] + ".cvf"
+             self.compressor.compress(data, output_file)
+         assert self._cb_logger
+         self._cb_logger.log_artifact(output_file, artifacts_path)
+         shutil.rmtree(tmp_dir, ignore_errors=True)
+
+     @override
+     def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         self._terminate()
+
+     @override
+     def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+         self._terminate()
+
+     def _terminate(self, *_: Any) -> None:
+         if self.tracer and self.stage:
+             # calling synchronously since this can be called during interpreter shutdown
+             self._publish("last", sync=True)
+             self.tracer.terminate()
+             self.tracer = None
+         self.executor.shutdown()
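
Usage note (a sketch under assumptions): attaching the `VizTracer` callback from above. Extra keyword arguments are forwarded to `viztracer.VizTracer`; `max_stack_depth` is one such VizTracer option, and the output path is a placeholder.

    import lightning as L
    from fkat.pytorch.callbacks.profiling.viztracer import VizTracer

    viztracer_cb = VizTracer(
        ranks=[0],                            # trace global rank 0 only
        output_path_prefix="/tmp/viztracer",  # hypothetical local path
        compress=False,                       # publish gzipped HTML reports
        max_stack_depth=20,                   # forwarded as-is to viztracer.VizTracer()
    )
    trainer = L.Trainer(max_epochs=1, callbacks=[viztracer_cb])
    # trainer.fit(model)
    # With the default schedule (Never), a final report is published at teardown under
    # the artifact path "viztracer/<stage>/last".
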