fkat-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/logging/validation_metrics.py

@@ -0,0 +1,94 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import tempfile
from typing_extensions import override

import fsspec
import lightning as L
from lightning.pytorch.utilities import rank_zero_only

from fkat.utils.logging import rank0_logger
from fkat.pytorch.callbacks.loggers import CallbackLogger

logger = rank0_logger(__name__)


class ValidationMetrics(L.Callback):
    """
    Saves validation metrics after each validation epoch.

    This callback persists validation metrics in JSON format, creating both
    a versioned file (with epoch and step) and a "latest" file for easy access.

    Args:
        output_path (str | None): Directory path where metrics will be saved.
            Supports any fsspec-compatible filesystem (local, s3://, gcs://, etc.).
            If None, logs to MLflow artifacts. Defaults to None.

    Example:
        >>> # MLflow artifacts (default)
        >>> callback = ValidationMetrics()
        >>> # Local storage
        >>> callback = ValidationMetrics(output_path="/tmp/metrics")
        >>> # S3 storage
        >>> callback = ValidationMetrics(output_path="s3://my-bucket/metrics")
        >>> trainer = L.Trainer(callbacks=[callback])
    """

    def __init__(self, output_path: str | None = None) -> None:
        self.output_path = output_path
        self._cb_logger: CallbackLogger | None = None

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        if self.output_path is None:
            self._cb_logger = CallbackLogger(trainer)

    @override
    @rank_zero_only
    def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
        if trainer.sanity_checking:
            logger.info("Skipping validation metrics save during sanity checking")
            return

        # Extract metrics
        metrics_dict = {k: v.item() for k, v in trainer.logged_metrics.items()}
        metrics_json = json.dumps(metrics_dict)

        # Filenames
        versioned_name = f"validation-metrics-epoch={pl_module.current_epoch}-step={pl_module.global_step}.json"
        latest_name = "validation-metrics-latest.json"

        if self.output_path is None:
            # Log to MLflow artifacts
            logger.info("Saving validation metrics to MLflow artifacts")
            with tempfile.TemporaryDirectory() as tmpdir:
                versioned_file = f"{tmpdir}/{versioned_name}"
                latest_file = f"{tmpdir}/{latest_name}"

                with open(versioned_file, "w") as f:
                    f.write(metrics_json)
                with open(latest_file, "w") as f:
                    f.write(metrics_json)

                if self._cb_logger:
                    self._cb_logger.log_artifact(versioned_file, "validation_metrics")
                    self._cb_logger.log_artifact(latest_file, "validation_metrics")

                logger.info("Validation metrics saved to MLflow artifacts")
        else:
            # Use fsspec for filesystem
            logger.info(f"Saving validation metrics to {self.output_path}")
            fs, _, paths = fsspec.get_fs_token_paths(self.output_path)
            base_path = paths[0] if paths else self.output_path

            versioned_file = f"{base_path}/{versioned_name}"
            latest_file = f"{base_path}/{latest_name}"

            with fs.open(versioned_file, "w") as f:
                f.write(metrics_json)
            with fs.open(latest_file, "w") as f:
                f.write(metrics_json)

            logger.info(f"Validation metrics saved to {versioned_file} and {latest_file}")
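
A minimal usage sketch for the callback above, assuming the logging subpackage re-exports ValidationMetrics and using hypothetical module/datamodule names:

    import json

    import lightning as L

    from fkat.pytorch.callbacks.logging import ValidationMetrics  # assumed re-export

    # Persist metrics locally after each validation epoch.
    trainer = L.Trainer(callbacks=[ValidationMetrics(output_path="/tmp/metrics")])
    trainer.fit(my_module, my_datamodule)  # hypothetical LightningModule / DataModule

    # The "latest" file is overwritten each epoch, so it can be polled directly.
    with open("/tmp/metrics/validation-metrics-latest.json") as f:
        print(json.load(f))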

fkat/pytorch/callbacks/monitoring/__init__.py

@@ -0,0 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .crash import CrashDetector
from .dp import DpSyncMonitor
from .hardware_stats import HardwareStats
from .shutdown import GracefulShutdown


__all__ = [
    "CrashDetector",
    "DpSyncMonitor",
    "GracefulShutdown",
    "HardwareStats",
]
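
For reference, a short sketch stacking two of the callbacks exported above on one Trainer; the constructors shown are the ones that appear in this diff, and each callback hooks into the Lightning lifecycle independently:

    import lightning as L

    from fkat.pytorch.callbacks.monitoring import CrashDetector, HardwareStats

    # Callbacks are independent; Lightning invokes each one's hooks in order.
    trainer = L.Trainer(callbacks=[CrashDetector(), HardwareStats(accelerator="gpu")])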

fkat/pytorch/callbacks/monitoring/crash.py

@@ -0,0 +1,162 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import logging
import traceback
import multiprocessing
import datetime as dt
import tempfile
from pathlib import Path
from typing_extensions import override

import lightning as L
from lightning.pytorch.utilities import rank_zero_only

from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger

log = logging.getLogger(__name__)


def _monitor_process(queue: multiprocessing.Queue, parent_pid: int, rank: int) -> None:
    """Monitor parent process and report crash info."""
    try:
        _, status = os.waitpid(parent_pid, 0)
        exit_code = os.WEXITSTATUS(status) if os.WIFEXITED(status) else -1
        signal_num = os.WTERMSIG(status) if os.WIFSIGNALED(status) else None

        crash_info = {
            "pid": parent_pid,
            "rank": rank,
            "exit_code": exit_code,
            "signal": signal_num,
            "timestamp": str(dt.datetime.now(dt.timezone.utc)),
        }
        queue.put(crash_info)
    except Exception as e:
        log.error(f"Error monitoring process {parent_pid}: {e}")


class CrashDetector(L.Callback):
    """
    Detects process crashes and logs detailed error information.

    Monitors the training process and any spawned subprocesses for crashes.
    Captures PID, rank, error details, and stack traces, logging them to
    the configured Lightning logger.

    Args:
        error_tag: Tag for error messages (default: "error")
        crash_info_tag: Tag for crash details (default: "crash_info")

    Example:
        >>> from fkat.pytorch.callbacks.monitoring import CrashDetector
        >>> callback = CrashDetector()
        >>> trainer = L.Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        error_tag: str = "error",
        crash_info_tag: str = "crash_info",
    ) -> None:
        self.error_tag = error_tag
        self.crash_info_tag = crash_info_tag
        self._cb_logger: LightningLogger | None = None
        self._processes: list[multiprocessing.Process] = []
        self._queue: multiprocessing.Queue | None = None

    @override
    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        """Initialize crash detection."""
        if trainer.local_rank == 0:
            self._cb_logger = CallbackLogger(trainer)
        self._queue = multiprocessing.Queue()

        # Monitor main process
        process = multiprocessing.Process(
            target=_monitor_process, args=(self._queue, os.getpid(), trainer.global_rank)
        )
        process.daemon = True
        process.start()
        self._processes.append(process)

    @override
    @rank_zero_only
    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
        """Log exception details."""
        if not self._cb_logger:
            return

        exc_type = type(exception)
        stacktrace = "".join(traceback.format_exception(exc_type, exception, exception.__traceback__))

        error_msg = f"[{exc_type.__name__}]: {exception}"
        crash_info = {
            "pid": os.getpid(),
            "rank": trainer.global_rank,
            "error": error_msg,
            "stacktrace": stacktrace,
            "timestamp": str(dt.datetime.now(dt.timezone.utc)),
        }

        log.error(f"Exception: {error_msg}\n{stacktrace}")
        self._cb_logger.log_tag(self.error_tag, error_msg)
        self._cb_logger.log_tag(self.crash_info_tag, str(crash_info))

        # Log to MLflow artifacts if available
        self._log_to_mlflow_artifact(trainer, crash_info)

    @override
    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
        """Check for crashes and cleanup."""
        if trainer.global_rank == 0 and self._queue:
            # Check for any crash reports
            while not self._queue.empty():
                crash_info = self._queue.get_nowait()
                log.warning(f"Detected crash: {crash_info}")
                if self._cb_logger:
                    self._cb_logger.log_tag(self.crash_info_tag, str(crash_info))
                self._log_to_mlflow_artifact(trainer, crash_info)

        self._terminate_monitors()

    def _log_to_mlflow_artifact(self, trainer: "L.Trainer", crash_info: dict) -> None:
        """Log crash info to MLflow artifacts."""
        try:
            from lightning.pytorch.loggers import MLFlowLogger

            mlflow_logger = next((logger for logger in trainer.loggers if isinstance(logger, MLFlowLogger)), None)
            if not mlflow_logger:
                return

            # Create filename: rank0-timestamp.txt
            rank = crash_info.get("rank", 0)
            timestamp = crash_info.get("timestamp", "unknown")
            # Convert timestamp to filename-safe format
            timestamp_safe = timestamp.replace(" ", "_").replace(":", "-")
            filename = f"rank{rank}-{timestamp_safe}.txt"

            # Create temp file with crash info
            temp_dir = Path(tempfile.gettempdir())
            temp_path = temp_dir / filename

            with open(temp_path, "w") as f:
                f.write("Crash Information\n")
                f.write("=" * 80 + "\n\n")
                for key, value in crash_info.items():
                    f.write(f"{key}: {value}\n")

            # Log as artifact
            mlflow_logger.experiment.log_artifact(mlflow_logger.run_id, str(temp_path), "crashes")
            temp_path.unlink()
            log.info(f"Logged crash info to MLflow artifacts: {filename}")
        except Exception as e:
            log.warning(f"Failed to log crash info to MLflow: {e}")

    def _terminate_monitors(self) -> None:
        """Terminate all monitoring processes."""
        for process in self._processes:
            if process.is_alive():
                process.kill()
        self._processes.clear()
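
`_monitor_process` decodes the `os.waitpid` status word with the `os.WIF*` macros. A self-contained sketch of that decoding (POSIX only, and waiting in the conventional parent-on-child direction, since `waitpid` is defined for child processes):

    import os

    pid = os.fork()
    if pid == 0:
        os._exit(3)  # child terminates with exit code 3

    # Parent blocks until the child exits, then unpacks the status word.
    _, status = os.waitpid(pid, 0)
    exit_code = os.WEXITSTATUS(status) if os.WIFEXITED(status) else -1
    signal_num = os.WTERMSIG(status) if os.WIFSIGNALED(status) else None
    print(exit_code, signal_num)  # -> 3 None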

fkat/pytorch/callbacks/monitoring/dp.py

@@ -0,0 +1,130 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import time
from typing import Any, Protocol
from typing_extensions import override

import lightning as L
import torch
import torch.distributed as dist

from fkat.pytorch.schedule import Schedule, Never
from fkat.pytorch.loggers import LightningLogger
from fkat.pytorch.callbacks.loggers import CallbackLogger
from fkat.utils.logging import rank0_logger

logger = rank0_logger(__name__)


class DpGroupStrategy(Protocol):
    """Protocol for getting DP group info for the current rank."""

    def dp_group_info(self) -> tuple[int, int]:
        """Return (group_id, rank_in_group) for the current rank."""
        ...


class DistDpGroup(DpGroupStrategy):
    """Calculates DP group info based on dp_size using distributed rank."""

    def __init__(self, dp_size: int) -> None:
        self.dp_size = dp_size

    def dp_group_info(self) -> tuple[int, int]:
        return divmod(dist.get_rank(), self.dp_size)  # type: ignore[possibly-unbound-attribute]


class EnvDpGroup(DpGroupStrategy):
    """Calculates DP group info based on dp_size using environment variables."""

    def __init__(self, dp_size: int) -> None:
        self.dp_size = dp_size

    def dp_group_info(self) -> tuple[int, int]:
        rank = int(os.environ.get("RANK", 0))
        return divmod(rank, self.dp_size)


class MegatronDpGroup(DpGroupStrategy):
    """Gets DP group info from Megatron parallel_state."""

    def dp_group_info(self) -> tuple[int, int]:
        from megatron.core import parallel_state  # type: ignore[import-not-found]

        group = parallel_state.get_data_parallel_group()
        rank_in_group = dist.get_rank(group)  # type: ignore[possibly-unbound-attribute]
        # For Megatron, we need to calculate group_id differently
        # This assumes we can derive it from global rank and group size
        global_rank = dist.get_rank()  # Get global rank separately  # type: ignore[possibly-unbound-attribute]
        group_size = dist.get_world_size(group)  # type: ignore[possibly-unbound-attribute]
        group_id = global_rank // group_size
        return group_id, rank_in_group


class DpSyncMonitor(L.Callback):
    """
    Monitors time for each DP group to reach synchronization point.
    Measures from batch start to before optimizer step to identify slow/fast groups.
    """

    def __init__(self, dp_group: DpGroupStrategy, schedule: Schedule | None = None) -> None:
        """
        Initialize the DP synchronization monitor.

        Args:
            dp_group: Strategy for determining DP group info (required).
            schedule: Controls when logging occurs. Defaults to ``Never``.
        """
        self.dp_group = dp_group
        self.schedule = schedule or Never()
        self.batch_start_time: float | None = None
        self._cb_logger: LightningLogger | None = None

    @override
    def setup(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        stage: str,
    ) -> None:
        if not self._cb_logger:
            self._cb_logger = CallbackLogger(trainer)

    @override
    def on_train_batch_start(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Start timing when batch processing begins."""
        self.batch_start_time = time.perf_counter()

    @override
    def on_before_optimizer_step(
        self,
        trainer: "L.Trainer",
        pl_module: "L.LightningModule",
        optimizer: torch.optim.Optimizer,
    ) -> None:
        """End timing when ready for sync (before optimizer step) and log if needed."""
        if self.batch_start_time is not None:
            sync_time_s = time.perf_counter() - self.batch_start_time
            # Log immediately since we're at the sync point, before any DP comms
            self._log_statistics(trainer, "train", 0, sync_time_s)
            self.batch_start_time = None

    def _log_statistics(self, trainer: "L.Trainer", stage: str, batch_idx: int, sync_time_s: float) -> None:
        """Log current group timing if schedule permits and this is DP group rank 0."""
        group_id, rank_in_group = self.dp_group.dp_group_info()
        if rank_in_group != 0:
            return

        if not self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            return

        if self._cb_logger:
            metrics = {f"dp_sync/group{group_id}/sync_s": sync_time_s}
            self._cb_logger.log_batch(metrics=metrics, timestamp=int(time.time()), step=trainer.global_step)
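
Both divmod-based strategies assign contiguous ranks to the same DP group, and only the first rank in each group (rank_in_group == 0) logs. A quick illustration of the mapping for dp_size=4 on 8 ranks:

    dp_size = 4
    for rank in range(8):
        group_id, rank_in_group = divmod(rank, dp_size)
        print(f"rank {rank} -> group {group_id}, position {rank_in_group}")
    # ranks 0-3 land in group 0, ranks 4-7 in group 1;
    # positions 0 (the group "leaders") do the logging.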

fkat/pytorch/callbacks/monitoring/hardware_stats.py

@@ -0,0 +1,135 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from time import time_ns
from typing import Any
from typing_extensions import override

import lightning as L
from lightning.pytorch.utilities import rank_zero_only
import psutil
import torch

from fkat.pytorch.callbacks.loggers import CallbackLogger
from fkat.pytorch.schedule import Schedule, Every

logger = logging.getLogger(__name__)


class HardwareStats(L.Callback):
    """Monitor and log hardware usage (CPU, RAM, and GPU) during training.

    Args:
        accelerator: Hardware accelerator to monitor ("gpu" or "cpu").
            If None or unsupported, only CPU/RAM are monitored.
        schedule: Controls when hardware stats are logged. Defaults to every batch.
    """

    def __init__(self, accelerator: str | None = None, schedule: Schedule | None = None) -> None:
        if accelerator not in ("gpu", "cpu", None):
            logger.warning(f"Unsupported accelerator: {accelerator}. Monitoring CPU/RAM only.")
            accelerator = "cpu"
        self.accelerator = accelerator
        self.schedule = schedule or Every(n_batches=1)
        self._cb_logger: CallbackLogger | None = None
        self._total_gpu_memory_gb: float | None = None

    @override
    def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
        self._cb_logger = CallbackLogger(trainer)
        if self.accelerator == "gpu":
            if torch.cuda.is_available():
                self._total_gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
                torch.cuda.reset_peak_memory_stats()
            else:
                logger.warning("GPU accelerator requested but CUDA not available. Monitoring CPU/RAM only.")

    @rank_zero_only
    @override
    def on_train_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        if self.accelerator == "gpu" and torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    def _get_stats(self) -> dict[str, float]:
        stats = {
            "cpu_usage_percent": psutil.cpu_percent(),
            "ram_used_percent": psutil.virtual_memory().percent,
            "ram_used_gb": psutil.virtual_memory().used / 1e9,
        }
        if self.accelerator == "gpu" and torch.cuda.is_available():
            stats.update(
                {
                    "gpu_memory_reserved_gb": torch.cuda.memory_reserved() / 1e9,
                    "gpu_memory_allocated_gb": torch.cuda.memory_allocated() / 1e9,
                    "gpu_peak_memory_allocated_gb": torch.cuda.max_memory_allocated() / 1e9,
                    "gpu_memory_total_per_rank_gb": self._total_gpu_memory_gb or 0.0,
                }
            )
        return stats

    def _log_stats(self, trainer: L.Trainer) -> None:
        if self._cb_logger:
            self._cb_logger.log_batch(metrics=self._get_stats(), step=trainer.global_step, timestamp=int(time_ns()))
        if self.accelerator == "gpu" and torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    @rank_zero_only
    @override
    def on_train_batch_start(
        self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int
    ) -> None:
        if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self._log_stats(trainer)

    @rank_zero_only
    @override
    def on_train_batch_end(
        self, trainer: L.Trainer, pl_module: L.LightningModule, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self._log_stats(trainer)

    @rank_zero_only
    @override
    def on_validation_batch_start(
        self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        if self.schedule.check(stage="validation", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self._log_stats(trainer)

    @rank_zero_only
    @override
    def on_validation_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        if self.schedule.check(stage="validation", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
            self._log_stats(trainer)

    @rank_zero_only
    @override
    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        self._log_stats(trainer)

    @rank_zero_only
    @override
    def on_before_zero_grad(self, trainer: L.Trainer, pl_module: L.LightningModule, optimizer: Any) -> None:
        if self.schedule.check(
            stage="train", batch_idx=trainer.fit_loop.batch_idx, step=trainer.global_step, trainer=trainer
        ):
            self._log_stats(trainer)

    @rank_zero_only
    @override
    def teardown(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
        self._log_stats(trainer)

    @rank_zero_only
    @override
    def on_exception(self, trainer: L.Trainer, pl_module: L.LightningModule, exception: BaseException) -> None:
        self._log_stats(trainer)