fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/gc.py
@@ -0,0 +1,146 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import gc
from time import perf_counter
from typing import Any
from typing_extensions import override

import lightning as L
from lightning.pytorch.utilities.types import STEP_OUTPUT

from fkat.pytorch.schedule import (
    Schedule,
    Never,
)
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger


class ManualGc(L.Callback):
    def __init__(
        self,
        schedule: Schedule | None = None,
        stats: dict[str, list[str]] | None = None,
    ) -> None:
        """
        PyTorch Lightning callback for manual garbage collection (GC) control.

        This callback allows fine-grained control over Python's garbage collection during training,
        validation, testing, and prediction. It can disable automatic garbage collection and instead
        perform manual collection at specified batch intervals.

        Args:
            schedule (Optional[Schedule]): When to invoke manual GC, defaults to :class:`Never`
            stats (Optional[dict[str, list[str]]]): The list of stats to log per generation.
                Defaults to all stats for all generations

        Example:
            >>> trainer = Trainer(callbacks=[ManualGc()])
        """
        self.schedule = schedule or Never()
        self.stats = stats or dict(
            zip("012", (["collections", "collected", "uncollected"] for _ in range(3)), strict=True)
        )
        self._cb_logger: LightningLogger | None = None

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = self._cb_logger or CallbackLogger(trainer)
        if not isinstance(self.schedule, Never):
            gc.disable()

    @override
    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Perform GC after training epoch."""
        self.maybe_collect()

    @override
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Perform GC after training batch if needed."""
        self.maybe_gc(trainer, "train", batch_idx)

    @override
    def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Perform GC after validation epoch."""
        self.maybe_collect()

    @override
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Perform GC after validation batch if needed."""
        self.maybe_gc(trainer, "validation", batch_idx)

    @override
    def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Perform GC after test epoch."""
        self.maybe_collect()

    @override
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Perform GC after prediction batch if needed."""
        self.maybe_gc(trainer, "predict", batch_idx)

    @override
    def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        """Perform GC after predict epoch."""
        self.maybe_collect()

    @override
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Perform GC after test batch if needed."""
        self.maybe_gc(trainer, "test", batch_idx)

    def maybe_collect(self) -> None:
        if not isinstance(self.schedule, Never):
            gc.collect()

    def maybe_gc(self, trainer: "L.Trainer", stage: str, batch_idx: int) -> None:
        """
        Perform garbage collection if conditions are met.

        Args:
            batch_idx (int): Current batch index
        """
        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            now = perf_counter()
            gc.collect()
            if self._cb_logger is not None:
                metrics = {f"gc/rank{trainer.global_rank}/time": perf_counter() - now}
                for i, stats in enumerate(gc.get_stats()):
                    gen = str(i)
                    for key, value in stats.items():
                        if key in self.stats[gen]:
                            metrics[f"gc/rank{trainer.global_rank}/gen{gen}/{key}"] = float(value)
                self._cb_logger.log_batch(
                    metrics=metrics, timestamp=int(now or perf_counter()), step=trainer.global_step
                )
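Usage note (not part of the package diff): ManualGc disables the automatic collector in setup() whenever a non-Never schedule is supplied, and then calls gc.collect() only when the schedule fires. The sketch below is hypothetical; the diff only reveals the keyword arguments of Schedule.check(), so the subclass and its constructor behavior are assumptions.

import lightning as L

from fkat.pytorch.callbacks.gc import ManualGc
from fkat.pytorch.schedule import Schedule


class EveryNTrainBatches(Schedule):  # hypothetical subclass; Schedule internals are not shown in this diff
    def __init__(self, n: int = 100) -> None:
        self.n = n

    def check(self, *, stage: str, batch_idx: int, step: int | None, trainer: "L.Trainer") -> bool:
        # Collect only during training, once every n batches.
        return stage == "train" and batch_idx % self.n == 0


trainer = L.Trainer(callbacks=[ManualGc(schedule=EveryNTrainBatches(100))])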
fkat/pytorch/callbacks/loggers.py
@@ -0,0 +1,211 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, TYPE_CHECKING
from typing_extensions import override

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from mlflow.entities import Metric, RunTag, Param
from mlflow.tracking import MlflowClient  # type: ignore[possibly-unbound-import]

if TYPE_CHECKING:
    pass

from fkat.pytorch.loggers import LightningLogger, _is_logger_type
from fkat.utils import assert_not_none
from fkat.utils.logging import rank0_logger
from fkat.utils.mlflow import broadcast_mlflow_run_id, mlflow_logger

log = rank0_logger(__name__)


class MLFlowCallbackLogger(LightningLogger):
    """
    Mlflow logger class that supports distributed logging of tags, metrics and artifacts.

    Args:
        trainer (L.Trainer): PTL trainer object
    """

    def __init__(
        self,
        trainer: "L.Trainer | None" = None,
        client: MlflowClient | None = None,
        synchronous: bool | None = None,
        run_id: str | None = None,
    ) -> None:
        super().__init__()
        if trainer:
            # Initialize logger and broadcast run_ids to all ranks
            logger = assert_not_none(mlflow_logger(trainer))
            broadcast_mlflow_run_id(logger, trainer)  # type: ignore[arg-type]
            # Set client and run_id
            self._client: MlflowClient = assert_not_none(getattr(logger, "_mlflow_client", None))
            self._synchronous = getattr(logger, "_log_batch_kwargs", {}).get("synchronous")
            self._run_id: str = assert_not_none(getattr(logger, "_run_id", None))
        else:
            assert client
            self._client = client
            self._synchronous = synchronous
            assert run_id
            self._run_id = run_id

    def log_tag(self, key: str, value: str) -> None:
        self._client.set_tag(run_id=self._run_id, key=key, value=value, synchronous=self._synchronous)

    def tags(self) -> dict[str, Any]:
        run = self._client.get_run(self._run_id)
        return run.data.tags

    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        ms = [Metric(k, v, timestamp, step) for k, v in metrics.items()] if metrics else []
        ps = [Param(k, v) for k, v in params.items()] if params else []
        ts = [RunTag(k, v) for k, v in tags.items()] if tags else []
        self._client.log_batch(
            run_id=self._run_id,
            metrics=ms,
            params=ps,
            tags=ts,
            synchronous=self._synchronous,
        )

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        # TODO: log directly to s3 uri
        # TODO: support async logging
        self._client.log_artifact(
            run_id=self._run_id,
            local_path=local_path,
            artifact_path=artifact_path,
        )


class TensorBoardCallbackLogger(LightningLogger):
    """TensorBoard logger for distributed logging."""

    def __init__(self, logger: TensorBoardLogger) -> None:
        self._logger = logger

    def log_tag(self, key: str, value: str) -> None:
        self._logger.experiment.add_text(key, value)

    def tags(self) -> dict[str, Any]:
        return {}

    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        if metrics:
            for k, v in metrics.items():
                self._logger.experiment.add_scalar(k, v, step)
        if tags:
            for k, v in tags.items():
                self._logger.experiment.add_text(k, v, step)

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        pass


class WandbCallbackLogger(LightningLogger):
    """WandB logger for distributed logging."""

    def __init__(self, logger: WandbLogger) -> None:
        self._logger = logger

    def log_tag(self, key: str, value: str) -> None:
        self._logger.experiment.config.update({key: value})

    def tags(self) -> dict[str, Any]:
        return dict(self._logger.experiment.config)

    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        log_dict = {}
        if metrics:
            log_dict.update(metrics)
        if tags:
            log_dict.update(tags)
        if log_dict:
            self._logger.experiment.log(log_dict, step=step)

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        self._logger.experiment.save(local_path)


class CallbackLogger(LightningLogger):
    """
    A wrapper on top of the collection of Logger instances,
    providing methods to log metrics, artifacts, and tags across all registered loggers
    simultaneously.

    Attributes:
        loggers (list[LightningLogger]): List of loggers

    Args:
        trainer (L.Trainer): PyTorch Lightning trainer instance used to initialize loggers
    """

    loggers: list[LightningLogger]

    def __init__(self, trainer: "L.Trainer | None", loggers: list[LightningLogger] | None = None) -> None:
        if trainer:
            self.loggers = []
            for logger in trainer.loggers:
                if _is_logger_type(logger, "MLFlowLogger"):
                    self.loggers.append(MLFlowCallbackLogger(trainer=trainer))
                elif _is_logger_type(logger, "TensorBoardLogger"):
                    self.loggers.append(TensorBoardCallbackLogger(logger=logger))  # type: ignore[arg-type]
                elif _is_logger_type(logger, "WandbLogger"):
                    self.loggers.append(WandbCallbackLogger(logger=logger))  # type: ignore[arg-type]
        else:
            assert loggers
            self.loggers = loggers

    def __str__(self) -> str:
        return str([type(obj).__name__ for obj in self.loggers])

    @override
    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        for logger in self.loggers:
            logger.log_artifact(local_path=local_path, artifact_path=artifact_path)

    @override
    def log_batch(
        self,
        metrics: dict[str, float] | None = None,
        params: dict[str, Any] | None = None,
        tags: dict[str, str] | None = None,
        timestamp: int | None = None,
        step: int | None = None,
    ) -> None:
        for logger in self.loggers:
            logger.log_batch(metrics=metrics, tags=tags, timestamp=timestamp, step=step)

    @override
    def tags(self) -> dict[str, Any]:
        tags = {}
        for logger in self.loggers:
            tags.update(logger.tags())
        return tags

    @override
    def log_tag(self, key: str, value: str) -> None:
        for logger in self.loggers:
            logger.log_tag(key=key, value=value)
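Usage note (not part of the package diff): CallbackLogger is the fan-out point the other callbacks in this diff rely on; Heartbeat and Throughput both construct it in setup() and call log_batch() on it, which forwards to whichever MLflow/TensorBoard/W&B loggers the Trainer was built with. Note that CallbackLogger.log_batch() forwards metrics, tags, timestamp and step but not params. A minimal sketch of the same pattern in a custom callback (the callback itself is hypothetical):

from time import time

import lightning as L

from fkat.pytorch.callbacks.loggers import CallbackLogger


class BatchIndexLogger(L.Callback):  # hypothetical example callback
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        # Wraps all loggers attached to the Trainer behind one interface.
        self._cb_logger = CallbackLogger(trainer)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        self._cb_logger.log_batch(
            metrics={"debug/batch_idx": float(batch_idx)},
            timestamp=int(time()),
            step=trainer.global_step,
        )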
fkat/pytorch/callbacks/logging/__init__.py
@@ -0,0 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .heartbeat import Heartbeat
from .throughput import Throughput
from .validation_metrics import ValidationMetrics


__all__ = [
    "Heartbeat",
    "Throughput",
    "ValidationMetrics",
]
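The __init__ above re-exports the three logging callbacks, so they can be imported from the subpackage directly, for example:

from fkat.pytorch.callbacks.logging import Heartbeat, Throughput, ValidationMetrics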
fkat/pytorch/callbacks/logging/heartbeat.py
@@ -0,0 +1,76 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime as dt
from typing import Any
from typing_extensions import override

import lightning as L
from lightning.pytorch.utilities import rank_zero_only

from fkat.pytorch.schedule import (
    Schedule,
    Elapsed,
)
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger
from fkat.utils.logging import rank0_logger

log = rank0_logger(__name__)


class Heartbeat(L.Callback):
    """Publishes tags indicating the time and step of the last heartbeat with the provided schedule."""

    def __init__(
        self,
        schedule: Schedule | None = None,
        last_check_in_time_tag: str = "last_check_in_time",
        last_check_in_step_tag: str = "last_check_in_step",
    ) -> None:
        self.last_check_in_time_tag = last_check_in_time_tag
        self.last_check_in_step_tag = last_check_in_step_tag
        self.schedule = schedule or Elapsed(interval=dt.timedelta(minutes=15))
        self._cb_logger: LightningLogger | None = None

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)

    def _publish_tags(self, stage: str, batch_idx: int, trainer: "L.Trainer") -> None:
        if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            assert self._cb_logger
            time = dt.datetime.now(dt.timezone.utc)
            self._cb_logger.log_batch(
                tags={
                    self.last_check_in_time_tag: str(time),
                    self.last_check_in_step_tag: str(trainer.global_step),
                }
            )

    @override
    @rank_zero_only
    def on_train_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        self._publish_tags("train", batch_idx, trainer)

    @override
    @rank_zero_only
    def on_test_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._publish_tags("test", batch_idx, trainer)

    @override
    @rank_zero_only
    def on_validation_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._publish_tags("validation", batch_idx, trainer)

    @override
    @rank_zero_only
    def on_predict_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        self._publish_tags("predict", batch_idx, trainer)
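Usage note (not part of the package diff): Heartbeat defaults to an Elapsed(interval=dt.timedelta(minutes=15)) schedule and publishes the last_check_in_time / last_check_in_step tags from rank zero, which an external watchdog could poll from the tracking backend to detect stalled jobs. A minimal sketch with a shorter interval:

import datetime as dt

import lightning as L

from fkat.pytorch.callbacks.logging import Heartbeat
from fkat.pytorch.schedule import Elapsed

# Check in every 5 minutes instead of the 15-minute default.
heartbeat = Heartbeat(schedule=Elapsed(interval=dt.timedelta(minutes=5)))
trainer = L.Trainer(callbacks=[heartbeat])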
fkat/pytorch/callbacks/logging/throughput.py
@@ -0,0 +1,253 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from time import time
from typing import Any, TYPE_CHECKING

import lightning as L
from lightning.pytorch.callbacks import LearningRateFinder, BatchSizeFinder
from lightning.pytorch.utilities.data import extract_batch_size
from lightning.pytorch.utilities import rank_zero_only

if TYPE_CHECKING:
    from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override

import torch
from fkat.pytorch.schedule import (
    Schedule,
    Never,
)
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger
from fkat.utils.logging import rank0_logger

logger = rank0_logger(__name__)


class Throughput(L.Callback):
    def __init__(self, dp_ranks: int | None = None, schedule: Schedule | None = None) -> None:
        """
        Throughput logging callback that measures the time spent processing the microbatches.
        Args:
            schedule (Optional[Schedule]): Controls when logging occurs. Defaults to ``Never``.
        """
        self.schedule = schedule or Never()
        self.dp_ranks: int | None = dp_ranks
        self.was_last_step_val = False
        self.publish = False
        self.step_start_time: dict[str, float] = {}
        self.step_time: dict[str, float] = {}
        self.total_time: dict[str, float] = {}
        self.step_samples: dict[str, float] = {}
        self.total_samples: dict[str, float] = {}
        self.epoch_start_time: dict[str, float] = {}
        self._cb_logger: LightningLogger | None = None

    @override
    def setup(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        stage: str,
    ) -> None:
        if not self._cb_logger:
            self._cb_logger = CallbackLogger(trainer)
        # ignoring special callbacks used for tuning
        callbacks = [
            c
            for c in trainer.callbacks  # type: ignore[attr-defined]
            if not isinstance(c, LearningRateFinder | BatchSizeFinder)
        ]
        tput_callbacks = [i for i, c in enumerate(callbacks) if isinstance(c, Throughput)]
        assert len(tput_callbacks) == 1, "There can only be one Throughput logging callback in operation"
        self.dp_ranks: int = self.dp_ranks or trainer.world_size
        self.step_start_time = {}
        self.step_time = {}
        self.total_time = {}
        self.step_samples = {}
        self.total_samples = {}
        self.epoch_start_time = {}

    def _start_epoch(self, stage: str) -> None:
        self.epoch_start_time[stage] = self.step_start_time[stage] = time()

    @override
    @rank_zero_only
    def on_train_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start_epoch("train")

    @override
    @rank_zero_only
    def on_validation_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start_epoch("validation")

    @override
    @rank_zero_only
    def on_test_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start_epoch("test")

    @override
    @rank_zero_only
    def on_predict_epoch_start(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._start_epoch("predict")

    def _report_epoch(self, trainer: "L.Trainer", stage: str) -> None:
        self.step_start_time[stage] = (now := time())
        metrics = {
            f"{stage}/epochs/epoch_time": now - self.epoch_start_time[stage],
        }
        if self._cb_logger:
            self._cb_logger.log_batch(metrics=metrics, timestamp=int(now), step=trainer.global_step)
        else:
            logger.info(metrics)

    @override
    @rank_zero_only
    def on_train_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._report_epoch(trainer, "train")

    @override
    @rank_zero_only
    def on_validation_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._report_epoch(trainer, "validation")

    @override
    @rank_zero_only
    def on_test_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._report_epoch(trainer, "test")

    @override
    @rank_zero_only
    def on_predict_epoch_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self._report_epoch(trainer, "predict")

    def _update(self, stage: str, batch: Any, batch_idx: int, step: int | None, trainer: "L.Trainer") -> None:
        # because of other callbacks we want to only measure within batch start/end
        # and make sure this callback is the first in the list
        now = time()
        self.step_time[stage] = self.step_time.get(stage, 0) + (now - self.step_start_time[stage])
        self.step_start_time[stage] = now
        num_samples = extract_batch_size(batch) if batch else 0
        self.step_samples[stage] = self.step_samples.get(stage, 0) + num_samples
        # train data points have to be at step boundaries or we will have multiple datapoints for the same step
        self.publish = stage != "train" and self.schedule.check(
            stage=stage, batch_idx=batch_idx, step=step, trainer=trainer
        )

    @override
    @rank_zero_only
    def on_before_zero_grad(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", optimizer: "torch.optim.Optimizer"
    ) -> None:
        """
        Report metrics for individual steps during training.
        """
        self.publish = True  # always log on step

    @override
    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # throughput is the first callback so it's safe to capture time here
        self._update("train", batch, batch_idx, trainer.global_step, trainer)
        self._report(trainer, "train")

    @override
    @rank_zero_only
    def on_validation_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._update("validation", batch, batch_idx, trainer.global_step, trainer)
        self._report(trainer, "validation")

    @override
    @rank_zero_only
    def on_test_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: "STEP_OUTPUT",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._update("test", batch, batch_idx, trainer.global_step, trainer)
        self._report(trainer, "test")

    @override
    @rank_zero_only
    def on_predict_batch_end(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._update("predict", batch, batch_idx, trainer.global_step, trainer)
        self._report(trainer, "predict")

    def _report(self, trainer: "L.Trainer", stage: str) -> None:
        if not self.publish:
            return
        if not self.step_time.get(stage):
            # can end up here outside of training loop e.g. when initializing precision plugin
            return
        self.total_time[stage] = self.total_time.get(stage, 0) + self.step_time[stage]
        self.total_samples[stage] = self.total_samples.get(stage, 0) + self.step_samples[stage]
        rank0_avg_tput = self.total_samples[stage] / self.total_time[stage]
        assert self.dp_ranks
        metrics = {
            f"{stage}/throughput/running_avg_rank0": rank0_avg_tput,
            f"{stage}/throughput/running_avg": self.dp_ranks * rank0_avg_tput,
        }
        if stage == "train":
            # we only have steps during fit
            metrics[f"{stage}/steps/step_time"] = self.step_time[stage]
            rank0_tput = self.step_samples[stage] / self.step_time[stage]
            metrics[f"{stage}/throughput/current_rank0"] = rank0_tput
            metrics[f"{stage}/throughput/current"] = self.dp_ranks * rank0_tput
        if self._cb_logger:
            self._cb_logger.log_batch(metrics=metrics, timestamp=int(time()), step=trainer.global_step)
        else:
            logger.info(metrics)
        self.step_time[stage] = 0.0
        self.step_samples[stage] = 0

    @override
    @rank_zero_only
    def on_train_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.publish = True
        self._report(trainer, "train")

    @override
    @rank_zero_only
    def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.publish = True
        self._report(trainer, "validation")

    @override
    @rank_zero_only
    def on_test_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.publish = True
        self._report(trainer, "test")

    @override
    @rank_zero_only
    def on_predict_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        self.publish = True
        self._report(trainer, "predict")
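Usage note (not part of the package diff): Throughput asserts that it is the only instance registered, and its inline comments assume it sits first in the callback list so the time captured in on_*_batch_end excludes other callbacks' work. During fit it publishes on every optimizer step via on_before_zero_grad; for validation, test and predict the supplied schedule (default Never) controls intermediate reports, with a final report at stage end. dp_ranks defaults to trainer.world_size in setup(). A minimal sketch (the value 8 is only an illustration):

import lightning as L

from fkat.pytorch.callbacks.logging import Throughput

# Register Throughput first; pin dp_ranks explicitly when data parallelism
# does not span the full world size, otherwise omit it.
trainer = L.Trainer(callbacks=[Throughput(dp_ranks=8)])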