fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/gc.py
@@ -0,0 +1,146 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import gc
+from time import perf_counter
+from typing import Any
+from typing_extensions import override
+
+import lightning as L
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from fkat.pytorch.schedule import (
+    Schedule,
+    Never,
+)
+from fkat.pytorch.loggers import LightningLogger
+from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+
+class ManualGc(L.Callback):
+    def __init__(
+        self,
+        schedule: Schedule | None = None,
+        stats: dict[str, list[str]] | None = None,
+    ) -> None:
+        """
+        PyTorch Lightning callback for manual garbage collection (GC) control.
+
+        This callback allows fine-grained control over Python's garbage collection during training,
+        validation, testing, and prediction. It can disable automatic garbage collection and instead
+        perform manual collection at specified batch intervals.
+
+        Args:
+            schedule (Optional[Schedule]): When to invoke manual GC. Defaults to :class:`Never`.
+            stats (Optional[dict[str, list[str]]]): The list of stats to log per generation.
+                Defaults to all stats for all generations.
+
+        Example:
+            >>> trainer = Trainer(callbacks=[ManualGc()])
+        """
+        self.schedule = schedule or Never()
+        self.stats = stats or dict(
+            zip("012", (["collections", "collected", "uncollectable"] for _ in range(3)), strict=True)
+        )
+        self._cb_logger: LightningLogger | None = None
+
+    @override
+    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+        self._cb_logger = self._cb_logger or CallbackLogger(trainer)
+        if not isinstance(self.schedule, Never):
+            gc.disable()
+
+    @override
+    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform GC after training epoch."""
+        self.maybe_collect()
+
+    @override
+    def on_train_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        """Perform GC after training batch if needed."""
+        self.maybe_gc(trainer, "train", batch_idx)
+
+    @override
+    def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform GC after validation epoch."""
+        self.maybe_collect()
+
+    @override
+    def on_validation_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Perform GC after validation batch if needed."""
+        self.maybe_gc(trainer, "validation", batch_idx)
+
+    @override
+    def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform GC after test epoch."""
+        self.maybe_collect()
+
+    @override
+    def on_predict_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Perform GC after prediction batch if needed."""
+        self.maybe_gc(trainer, "predict", batch_idx)
+
+    @override
+    def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        """Perform GC after predict epoch."""
+        self.maybe_collect()
+
+    @override
+    def on_test_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        """Perform GC after test batch if needed."""
+        self.maybe_gc(trainer, "test", batch_idx)
+
+    def maybe_collect(self) -> None:
+        if not isinstance(self.schedule, Never):
+            gc.collect()
+
+    def maybe_gc(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
+        """
+        Perform garbage collection if conditions are met.
+
+        Args:
+            batch_idx (int): Current batch index
+        """
+        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+            now = perf_counter()
+            gc.collect()
+            if self._cb_logger is not None:
+                metrics = {f"gc/rank{trainer.global_rank}/time": perf_counter() - now}
+                for i, stats in enumerate(gc.get_stats()):
+                    gen = str(i)
+                    for key, value in stats.items():
+                        if key in self.stats[gen]:
+                            metrics[f"gc/rank{trainer.global_rank}/gen{gen}/{key}"] = float(value)
+                self._cb_logger.log_batch(
+                    metrics=metrics, timestamp=int(now or perf_counter()), step=trainer.global_step
+                )
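ManualGc disables the interpreter's automatic collector in setup() (unless the schedule is Never) and then calls gc.collect() only when its schedule approves, logging collection time and per-generation stats through CallbackLogger. A minimal usage sketch follows; EveryNBatches is a hypothetical schedule written for illustration, assuming Schedule can be subclassed by overriding check() with the keyword arguments used in maybe_gc above.

import lightning as L

from fkat.pytorch.callbacks.gc import ManualGc
from fkat.pytorch.schedule import Schedule


class EveryNBatches(Schedule):
    """Hypothetical schedule that fires on every n-th training batch."""

    def __init__(self, n: int) -> None:
        self.n = n

    def check(self, *, stage: str, batch_idx: int, step: int | None, trainer: "L.Trainer") -> bool:
        # Collect only during training, once every n batches.
        return stage == "train" and batch_idx > 0 and batch_idx % self.n == 0


# Automatic GC stays disabled between collections; gc.collect() then runs every 100 train
# batches and the measured collection time is forwarded to the trainer's experiment loggers.
trainer = L.Trainer(callbacks=[ManualGc(schedule=EveryNBatches(100))])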
fkat/pytorch/callbacks/loggers.py
@@ -0,0 +1,211 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, TYPE_CHECKING
+from typing_extensions import override
+
+import lightning as L
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from mlflow.entities import Metric, RunTag, Param
+from mlflow.tracking import MlflowClient  # type: ignore[possibly-unbound-import]
+
+if TYPE_CHECKING:
+    pass
+
+from fkat.pytorch.loggers import LightningLogger, _is_logger_type
+from fkat.utils import assert_not_none
+from fkat.utils.logging import rank0_logger
+from fkat.utils.mlflow import broadcast_mlflow_run_id, mlflow_logger
+
+log = rank0_logger(__name__)
+
+
+class MLFlowCallbackLogger(LightningLogger):
+    """
+    MLflow logger class that supports distributed logging of tags, metrics, and artifacts.
+
+    Args:
+        trainer (L.Trainer): PTL trainer object
+    """
+
+    def __init__(
+        self,
+        trainer: "L.Trainer | None" = None,
+        client: MlflowClient | None = None,
+        synchronous: bool | None = None,
+        run_id: str | None = None,
+    ) -> None:
+        super().__init__()
+        if trainer:
+            # Initialize logger and broadcast run_ids to all ranks
+            logger = assert_not_none(mlflow_logger(trainer))
+            broadcast_mlflow_run_id(logger, trainer)  # type: ignore[arg-type]
+            # Set client and run_id
+            self._client: MlflowClient = assert_not_none(getattr(logger, "_mlflow_client", None))
+            self._synchronous = getattr(logger, "_log_batch_kwargs", {}).get("synchronous")
+            self._run_id: str = assert_not_none(getattr(logger, "_run_id", None))
+        else:
+            assert client
+            self._client = client
+            self._synchronous = synchronous
+            assert run_id
+            self._run_id = run_id
+
+    def log_tag(self, key: str, value: str) -> None:
+        self._client.set_tag(run_id=self._run_id, key=key, value=value, synchronous=self._synchronous)
+
+    def tags(self) -> dict[str, Any]:
+        run = self._client.get_run(self._run_id)
+        return run.data.tags
+
+    def log_batch(
+        self,
+        metrics: dict[str, float] | None = None,
+        params: dict[str, Any] | None = None,
+        tags: dict[str, str] | None = None,
+        timestamp: int | None = None,
+        step: int | None = None,
+    ) -> None:
+        ms = [Metric(k, v, timestamp, step) for k, v in metrics.items()] if metrics else []
+        ps = [Param(k, v) for k, v in params.items()] if params else []
+        ts = [RunTag(k, v) for k, v in tags.items()] if tags else []
+        self._client.log_batch(
+            run_id=self._run_id,
+            metrics=ms,
+            params=ps,
+            tags=ts,
+            synchronous=self._synchronous,
+        )
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+        # TODO: log directly to s3 uri
+        # TODO: support async logging
+        self._client.log_artifact(
+            run_id=self._run_id,
+            local_path=local_path,
+            artifact_path=artifact_path,
+        )
+
+
+class TensorBoardCallbackLogger(LightningLogger):
+    """TensorBoard logger for distributed logging."""
+
+    def __init__(self, logger: TensorBoardLogger) -> None:
+        self._logger = logger
+
+    def log_tag(self, key: str, value: str) -> None:
+        self._logger.experiment.add_text(key, value)
+
+    def tags(self) -> dict[str, Any]:
+        return {}
+
+    def log_batch(
+        self,
+        metrics: dict[str, float] | None = None,
+        params: dict[str, Any] | None = None,
+        tags: dict[str, str] | None = None,
+        timestamp: int | None = None,
+        step: int | None = None,
+    ) -> None:
+        if metrics:
+            for k, v in metrics.items():
+                self._logger.experiment.add_scalar(k, v, step)
+        if tags:
+            for k, v in tags.items():
+                self._logger.experiment.add_text(k, v, step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+        pass
+
+
+class WandbCallbackLogger(LightningLogger):
+    """WandB logger for distributed logging."""
+
+    def __init__(self, logger: WandbLogger) -> None:
+        self._logger = logger
+
+    def log_tag(self, key: str, value: str) -> None:
+        self._logger.experiment.config.update({key: value})
+
+    def tags(self) -> dict[str, Any]:
+        return dict(self._logger.experiment.config)
+
+    def log_batch(
+        self,
+        metrics: dict[str, float] | None = None,
+        params: dict[str, Any] | None = None,
+        tags: dict[str, str] | None = None,
+        timestamp: int | None = None,
+        step: int | None = None,
+    ) -> None:
+        log_dict = {}
+        if metrics:
+            log_dict.update(metrics)
+        if tags:
+            log_dict.update(tags)
+        if log_dict:
+            self._logger.experiment.log(log_dict, step=step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+        self._logger.experiment.save(local_path)
+
+
+class CallbackLogger(LightningLogger):
+    """
+    A wrapper around a collection of LightningLogger instances,
+    providing methods to log metrics, artifacts, and tags across all registered loggers
+    simultaneously.
+
+    Attributes:
+        loggers (list[LightningLogger]): List of loggers
+
+    Args:
+        trainer (L.Trainer): PyTorch Lightning trainer instance used to initialize loggers
+    """
+
+    loggers: list[LightningLogger]
+
+    def __init__(self, trainer: "L.Trainer | None", loggers: list[LightningLogger] | None = None) -> None:
+        if trainer:
+            self.loggers = []
+            for logger in trainer.loggers:
+                if _is_logger_type(logger, "MLFlowLogger"):
+                    self.loggers.append(MLFlowCallbackLogger(trainer=trainer))
+                elif _is_logger_type(logger, "TensorBoardLogger"):
+                    self.loggers.append(TensorBoardCallbackLogger(logger=logger))  # type: ignore[arg-type]
+                elif _is_logger_type(logger, "WandbLogger"):
+                    self.loggers.append(WandbCallbackLogger(logger=logger))  # type: ignore[arg-type]
+        else:
+            assert loggers
+            self.loggers = loggers
+
+    def __str__(self) -> str:
+        return str([type(obj).__name__ for obj in self.loggers])
+
+    @override
+    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+        for logger in self.loggers:
+            logger.log_artifact(local_path=local_path, artifact_path=artifact_path)
+
+    @override
+    def log_batch(
+        self,
+        metrics: dict[str, float] | None = None,
+        params: dict[str, Any] | None = None,
+        tags: dict[str, str] | None = None,
+        timestamp: int | None = None,
+        step: int | None = None,
+    ) -> None:
+        for logger in self.loggers:
+            logger.log_batch(metrics=metrics, tags=tags, timestamp=timestamp, step=step)
+
+    @override
+    def tags(self) -> dict[str, Any]:
+        tags = {}
+        for logger in self.loggers:
+            tags.update(logger.tags())
+        return tags
+
+    @override
+    def log_tag(self, key: str, value: str) -> None:
+        for logger in self.loggers:
+            logger.log_tag(key=key, value=value)
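CallbackLogger inspects trainer.loggers, wraps each recognized backend (MLflow, TensorBoard, Weights & Biases) in the corresponding adapter above, and fans every log_batch/log_tag/log_artifact call out to all of them. A small sketch, assuming a TensorBoard-only setup and building the fan-out logger directly from wrapped instances via the loggers= path instead of from a trainer; the save_dir and the metric/tag values are arbitrary.

from lightning.pytorch.loggers import TensorBoardLogger

from fkat.pytorch.callbacks.loggers import CallbackLogger, TensorBoardCallbackLogger

# Wrap a plain Lightning TensorBoardLogger in the adapter defined above and register it.
tb = TensorBoardCallbackLogger(logger=TensorBoardLogger(save_dir="tb_logs"))
cb_logger = CallbackLogger(None, loggers=[tb])

# Every registered backend receives the same batch of metrics and tags.
cb_logger.log_batch(metrics={"train/loss": 0.42}, tags={"phase": "smoke_test"}, step=10)
cb_logger.log_tag("last_check_in_step", "10")

The same calls work unchanged when the wrappers are built from a trainer whose loggers include MLflow or WandB, which is how the callbacks in this package construct it.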
fkat/pytorch/callbacks/logging/__init__.py
@@ -0,0 +1,12 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from .heartbeat import Heartbeat
+from .throughput import Throughput
+from .validation_metrics import ValidationMetrics
+
+
+__all__ = [
+    "Heartbeat",
+    "Throughput",
+    "ValidationMetrics",
+]
fkat/pytorch/callbacks/logging/heartbeat.py
@@ -0,0 +1,76 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import datetime as dt
+from typing import Any
+from typing_extensions import override
+
+import lightning as L
+from lightning.pytorch.utilities import rank_zero_only
+
+from fkat.pytorch.schedule import (
+    Schedule,
+    Elapsed,
+)
+from fkat.pytorch.loggers import LightningLogger
+from fkat.pytorch.callbacks.loggers import CallbackLogger
+from fkat.utils.logging import rank0_logger
+
+log = rank0_logger(__name__)
+
+
+class Heartbeat(L.Callback):
+    """Publishes tags recording the time and step of the last heartbeat, on the provided schedule."""
+
+    def __init__(
+        self,
+        schedule: Schedule | None = None,
+        last_check_in_time_tag: str = "last_check_in_time",
+        last_check_in_step_tag: str = "last_check_in_step",
+    ) -> None:
+        self.last_check_in_time_tag = last_check_in_time_tag
+        self.last_check_in_step_tag = last_check_in_step_tag
+        self.schedule = schedule or Elapsed(interval=dt.timedelta(minutes=15))
+        self._cb_logger: LightningLogger | None = None
+
+    @override
+    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+        self._cb_logger = CallbackLogger(trainer)
+
+    def _publish_tags(self, stage: str, batch_idx: int, trainer: "L.Trainer") -> None:
+        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+            assert self._cb_logger
+            time = dt.datetime.now(dt.timezone.utc)
+            self._cb_logger.log_batch(
+                tags={
+                    self.last_check_in_time_tag: str(time),
+                    self.last_check_in_step_tag: str(trainer.global_step),
+                }
+            )
+
+    @override
+    @rank_zero_only
+    def on_train_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
+        self._publish_tags("train", batch_idx, trainer)
+
+    @override
+    @rank_zero_only
+    def on_test_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._publish_tags("test", batch_idx, trainer)
+
+    @override
+    @rank_zero_only
+    def on_validation_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._publish_tags("validation", batch_idx, trainer)
+
+    @override
+    @rank_zero_only
+    def on_predict_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._publish_tags("predict", batch_idx, trainer)
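Heartbeat writes a last_check_in_time/last_check_in_step tag pair on rank zero whenever its schedule fires (every 15 minutes by default), so an external watchdog can detect a run that has stopped making progress. A short sketch with a tighter interval and custom tag names, assuming the trainer has a tag-capable experiment logger (e.g. MLflow) configured so the tags have somewhere to land; the watchdog/* tag names are illustrative.

import datetime as dt

import lightning as L

from fkat.pytorch.callbacks.logging import Heartbeat
from fkat.pytorch.schedule import Elapsed

# Check in every 5 minutes under watchdog-specific tag names.
heartbeat = Heartbeat(
    schedule=Elapsed(interval=dt.timedelta(minutes=5)),
    last_check_in_time_tag="watchdog/last_time",
    last_check_in_step_tag="watchdog/last_step",
)
trainer = L.Trainer(callbacks=[heartbeat])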
fkat/pytorch/callbacks/logging/throughput.py
@@ -0,0 +1,253 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from time import time
+from typing import Any, TYPE_CHECKING
+
+import lightning as L
+from lightning.pytorch.callbacks import LearningRateFinder, BatchSizeFinder
+from lightning.pytorch.utilities.data import extract_batch_size
+from lightning.pytorch.utilities import rank_zero_only
+
+if TYPE_CHECKING:
+    from lightning.pytorch.utilities.types import STEP_OUTPUT
+from typing_extensions import override
+
+import torch
+from fkat.pytorch.schedule import (
+    Schedule,
+    Never,
+)
+from fkat.pytorch.loggers import LightningLogger
+from fkat.pytorch.callbacks.loggers import CallbackLogger
+from fkat.utils.logging import rank0_logger
+
+logger = rank0_logger(__name__)
+
+
+class Throughput(L.Callback):
+    def __init__(self, dp_ranks: int | None = None, schedule: Schedule | None = None) -> None:
+        """
+        Throughput logging callback that measures the time spent processing microbatches.
+        Args:
+            schedule (Optional[Schedule]): Controls when logging occurs. Defaults to ``Never``.
+        """
+        self.schedule = schedule or Never()
+        self.dp_ranks: int | None = dp_ranks
+        self.was_last_step_val = False
+        self.publish = False
+        self.step_start_time: dict[str, float] = {}
+        self.step_time: dict[str, float] = {}
+        self.total_time: dict[str, float] = {}
+        self.step_samples: dict[str, float] = {}
+        self.total_samples: dict[str, float] = {}
+        self.epoch_start_time: dict[str, float] = {}
+        self._cb_logger: LightningLogger | None = None
+
+    @override
+    def setup(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        stage: str,
+    ) -> None:
+        if not self._cb_logger:
+            self._cb_logger = CallbackLogger(trainer)
+        # ignoring special callbacks used for tuning
+        callbacks = [
+            c
+            for c in trainer.callbacks  # type: ignore[attr-defined]
+            if not isinstance(c, LearningRateFinder | BatchSizeFinder)
+        ]
+        tput_callbacks = [i for i, c in enumerate(callbacks) if isinstance(c, Throughput)]
+        assert len(tput_callbacks) == 1, "There can only be one Throughput logging callback in operation"
+        self.dp_ranks: int = self.dp_ranks or trainer.world_size
+        self.step_start_time = {}
+        self.step_time = {}
+        self.total_time = {}
+        self.step_samples = {}
+        self.total_samples = {}
+        self.epoch_start_time = {}
+
+    def _start_epoch(self, stage: str) -> None:
+        self.epoch_start_time[stage] = self.step_start_time[stage] = time()
+
+    @override
+    @rank_zero_only
+    def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._start_epoch("train")
+
+    @override
+    @rank_zero_only
+    def on_validation_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._start_epoch("validation")
+
+    @override
+    @rank_zero_only
+    def on_test_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._start_epoch("test")
+
+    @override
+    @rank_zero_only
+    def on_predict_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._start_epoch("predict")
+
+    def _report_epoch(self, trainer: "L.Trainer", stage: str) -> None:
+        self.step_start_time[stage] = (now := time())
+        metrics = {
+            f"{stage}/epochs/epoch_time": now - self.epoch_start_time[stage],
+        }
+        if self._cb_logger:
+            self._cb_logger.log_batch(metrics=metrics, timestamp=int(now), step=trainer.global_step)
+        else:
+            logger.info(metrics)
+
+    @override
+    @rank_zero_only
+    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._report_epoch(trainer, "train")
+
+    @override
+    @rank_zero_only
+    def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._report_epoch(trainer, "validation")
+
+    @override
+    @rank_zero_only
+    def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._report_epoch(trainer, "test")
+
+    @override
+    @rank_zero_only
+    def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self._report_epoch(trainer, "predict")
+
+    def _update(self, stage: str, batch: Any, batch_idx: int, step: int | None, trainer: "L.Trainer") -> None:
+        # because of other callbacks we want to only measure within batch start/end
+        # and make sure this callback is the first in the list
+        now = time()
+        self.step_time[stage] = self.step_time.get(stage, 0) + (now - self.step_start_time[stage])
+        self.step_start_time[stage] = now
+        num_samples = extract_batch_size(batch) if batch else 0
+        self.step_samples[stage] = self.step_samples.get(stage, 0) + num_samples
+        # train data points have to be at step boundaries or we will have multiple datapoints for the same step
+        self.publish = stage != "train" and self.schedule.check(
+            stage=stage, batch_idx=batch_idx, step=step, trainer=trainer
+        )
+
+    @override
+    @rank_zero_only
+    def on_before_zero_grad(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", optimizer: "torch.optim.Optimizer"
+    ) -> None:
+        """
+        Report metrics for individual steps during training.
+        """
+        self.publish = True  # always log on step
+
+    @override
+    @rank_zero_only
+    def on_train_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        # throughput is the first callback so it's safe to capture time here
+        self._update("train", batch, batch_idx, trainer.global_step, trainer)
+        self._report(trainer, "train")
+
+    @override
+    @rank_zero_only
+    def on_validation_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._update("validation", batch, batch_idx, trainer.global_step, trainer)
+        self._report(trainer, "validation")
+
+    @override
+    @rank_zero_only
+    def on_test_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: "STEP_OUTPUT",
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._update("test", batch, batch_idx, trainer.global_step, trainer)
+        self._report(trainer, "test")
+
+    @override
+    @rank_zero_only
+    def on_predict_batch_end(
+        self,
+        trainer: "L.Trainer",
+        pl_module: "L.LightningModule",
+        outputs: Any,
+        batch: Any,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        self._update("predict", batch, batch_idx, trainer.global_step, trainer)
+        self._report(trainer, "predict")
+
+    def _report(self, trainer: "L.Trainer", stage: str) -> None:
+        if not self.publish:
+            return
+        if not self.step_time.get(stage):
+            # can end up here outside of training loop e.g. when initializing precision plugin
+            return
+        self.total_time[stage] = self.total_time.get(stage, 0) + self.step_time[stage]
+        self.total_samples[stage] = self.total_samples.get(stage, 0) + self.step_samples[stage]
+        rank0_avg_tput = self.total_samples[stage] / self.total_time[stage]
+        assert self.dp_ranks
+        metrics = {
+            f"{stage}/throughput/running_avg_rank0": rank0_avg_tput,
+            f"{stage}/throughput/running_avg": self.dp_ranks * rank0_avg_tput,
+        }
+        if stage == "train":
+            # we only have steps during fit
+            metrics[f"{stage}/steps/step_time"] = self.step_time[stage]
+            rank0_tput = self.step_samples[stage] / self.step_time[stage]
+            metrics[f"{stage}/throughput/current_rank0"] = rank0_tput
+            metrics[f"{stage}/throughput/current"] = self.dp_ranks * rank0_tput
+        if self._cb_logger:
+            self._cb_logger.log_batch(metrics=metrics, timestamp=int(time()), step=trainer.global_step)
+        else:
+            logger.info(metrics)
+        self.step_time[stage] = 0.0
+        self.step_samples[stage] = 0
+
+    @override
+    @rank_zero_only
+    def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self.publish = True
+        self._report(trainer, "train")
+
+    @override
+    @rank_zero_only
+    def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self.publish = True
+        self._report(trainer, "validation")
+
+    @override
+    @rank_zero_only
+    def on_test_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self.publish = True
+        self._report(trainer, "test")
+
+    @override
+    @rank_zero_only
+    def on_predict_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+        self.publish = True
+        self._report(trainer, "predict")
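Throughput accumulates rank-0 step time and sample counts, then reports samples per second both for rank 0 alone and scaled by dp_ranks as an estimate of global throughput; during training it publishes on every optimizer step (via on_before_zero_grad), while the schedule controls reporting for the other stages. A sketch, assuming a model-parallel job where the number of data-parallel replicas differs from world_size and is therefore passed explicitly; the 8-GPU/2-way tensor-parallel numbers and the one-minute Elapsed schedule are illustrative.

import datetime as dt

import lightning as L

from fkat.pytorch.callbacks.logging import Throughput
from fkat.pytorch.schedule import Elapsed

# 8 GPUs with 2-way tensor parallelism leaves 4 data-parallel replicas, so global
# throughput is rank-0 throughput multiplied by dp_ranks=4 rather than world_size=8.
throughput = Throughput(dp_ranks=4, schedule=Elapsed(interval=dt.timedelta(minutes=1)))

# Per the comments in _update, keep Throughput first in the callbacks list so that
# time spent in other callbacks is not counted toward its batch timings.
trainer = L.Trainer(callbacks=[throughput])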