fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/logging/validation_metrics.py
@@ -0,0 +1,94 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import json
+ import tempfile
+ from typing_extensions import override
+
+ import fsspec
+ import lightning as L
+ from lightning.pytorch.utilities import rank_zero_only
+
+ from fkat.utils.logging import rank0_logger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+ logger = rank0_logger(__name__)
+
+
+ class ValidationMetrics(L.Callback):
+     """
+     Saves validation metrics after each validation epoch.
+
+     This callback persists validation metrics in JSON format, creating both
+     a versioned file (with epoch and step) and a "latest" file for easy access.
+
+     Args:
+         output_path (str | None): Directory path where metrics will be saved.
+             Supports any fsspec-compatible filesystem (local, s3://, gcs://, etc.).
+             If None, logs to MLflow artifacts. Defaults to None.
+
+     Example:
+         >>> # MLflow artifacts (default)
+         >>> callback = ValidationMetrics()
+         >>> # Local storage
+         >>> callback = ValidationMetrics(output_path="/tmp/metrics")
+         >>> # S3 storage
+         >>> callback = ValidationMetrics(output_path="s3://my-bucket/metrics")
+         >>> trainer = L.Trainer(callbacks=[callback])
+     """
+
+     def __init__(self, output_path: str | None = None) -> None:
+         self.output_path = output_path
+         self._cb_logger: CallbackLogger | None = None
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         if self.output_path is None:
+             self._cb_logger = CallbackLogger(trainer)
+
+     @override
+     @rank_zero_only
+     def on_validation_end(self, trainer: "L.Trainer", pl_module: "L.LightningModule") -> None:
+         if trainer.sanity_checking:
+             logger.info("Skipping validation metrics save during sanity checking")
+             return
+
+         # Extract metrics
+         metrics_dict = {k: v.item() for k, v in trainer.logged_metrics.items()}
+         metrics_json = json.dumps(metrics_dict)
+
+         # Filenames
+         versioned_name = f"validation-metrics-epoch={pl_module.current_epoch}-step={pl_module.global_step}.json"
+         latest_name = "validation-metrics-latest.json"
+
+         if self.output_path is None:
+             # Log to MLflow artifacts
+             logger.info("Saving validation metrics to MLflow artifacts")
+             with tempfile.TemporaryDirectory() as tmpdir:
+                 versioned_file = f"{tmpdir}/{versioned_name}"
+                 latest_file = f"{tmpdir}/{latest_name}"
+
+                 with open(versioned_file, "w") as f:
+                     f.write(metrics_json)
+                 with open(latest_file, "w") as f:
+                     f.write(metrics_json)
+
+                 if self._cb_logger:
+                     self._cb_logger.log_artifact(versioned_file, "validation_metrics")
+                     self._cb_logger.log_artifact(latest_file, "validation_metrics")
+
+                 logger.info("Validation metrics saved to MLflow artifacts")
+         else:
+             # Use fsspec for filesystem
+             logger.info(f"Saving validation metrics to {self.output_path}")
+             fs, _, paths = fsspec.get_fs_token_paths(self.output_path)
+             base_path = paths[0] if paths else self.output_path
+
+             versioned_file = f"{base_path}/{versioned_name}"
+             latest_file = f"{base_path}/{latest_name}"
+
+             with fs.open(versioned_file, "w") as f:
+                 f.write(metrics_json)
+             with fs.open(latest_file, "w") as f:
+                 f.write(metrics_json)
+
+             logger.info(f"Validation metrics saved to {versioned_file} and {latest_file}")
fkat/pytorch/callbacks/monitoring/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from .crash import CrashDetector
+ from .dp import DpSyncMonitor
+ from .hardware_stats import HardwareStats
+ from .shutdown import GracefulShutdown
+
+
+ __all__ = [
+     "CrashDetector",
+     "DpSyncMonitor",
+     "GracefulShutdown",
+     "HardwareStats",
+ ]
fkat/pytorch/callbacks/monitoring/crash.py
@@ -0,0 +1,162 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import logging
+ import traceback
+ import multiprocessing
+ import datetime as dt
+ import tempfile
+ from pathlib import Path
+ from typing_extensions import override
+
+ import lightning as L
+ from lightning.pytorch.utilities import rank_zero_only
+
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+ log = logging.getLogger(__name__)
+
+
+ def _monitor_process(queue: multiprocessing.Queue, parent_pid: int, rank: int) -> None:
+     """Monitor parent process and report crash info."""
+     try:
+         _, status = os.waitpid(parent_pid, 0)
+         exit_code = os.WEXITSTATUS(status) if os.WIFEXITED(status) else -1
+         signal_num = os.WTERMSIG(status) if os.WIFSIGNALED(status) else None
+
+         crash_info = {
+             "pid": parent_pid,
+             "rank": rank,
+             "exit_code": exit_code,
+             "signal": signal_num,
+             "timestamp": str(dt.datetime.now(dt.timezone.utc)),
+         }
+         queue.put(crash_info)
+     except Exception as e:
+         log.error(f"Error monitoring process {parent_pid}: {e}")
+
+
+ class CrashDetector(L.Callback):
+     """
+     Detects process crashes and logs detailed error information.
+
+     Monitors the training process and any spawned subprocesses for crashes.
+     Captures PID, rank, error details, and stack traces, logging them to
+     the configured Lightning logger.
+
+     Args:
+         error_tag: Tag for error messages (default: "error")
+         crash_info_tag: Tag for crash details (default: "crash_info")
+
+     Example:
+         >>> from fkat.pytorch.callbacks.monitoring import CrashDetector
+         >>> callback = CrashDetector()
+         >>> trainer = L.Trainer(callbacks=[callback])
+     """
+
+     def __init__(
+         self,
+         error_tag: str = "error",
+         crash_info_tag: str = "crash_info",
+     ) -> None:
+         self.error_tag = error_tag
+         self.crash_info_tag = crash_info_tag
+         self._cb_logger: LightningLogger | None = None
+         self._processes: list[multiprocessing.Process] = []
+         self._queue: multiprocessing.Queue | None = None
+
+     @override
+     def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         """Initialize crash detection."""
+         if trainer.local_rank == 0:
+             self._cb_logger = CallbackLogger(trainer)
+             self._queue = multiprocessing.Queue()
+
+             # Monitor main process
+             process = multiprocessing.Process(
+                 target=_monitor_process, args=(self._queue, os.getpid(), trainer.global_rank)
+             )
+             process.daemon = True
+             process.start()
+             self._processes.append(process)
+
+     @override
+     @rank_zero_only
+     def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+         """Log exception details."""
+         if not self._cb_logger:
+             return
+
+         exc_type = type(exception)
+         stacktrace = "".join(traceback.format_exception(exc_type, exception, exception.__traceback__))
+
+         error_msg = f"[{exc_type.__name__}]: {exception}"
+         crash_info = {
+             "pid": os.getpid(),
+             "rank": trainer.global_rank,
+             "error": error_msg,
+             "stacktrace": stacktrace,
+             "timestamp": str(dt.datetime.now(dt.timezone.utc)),
+         }
+
+         log.error(f"Exception: {error_msg}\n{stacktrace}")
+         self._cb_logger.log_tag(self.error_tag, error_msg)
+         self._cb_logger.log_tag(self.crash_info_tag, str(crash_info))
+
+         # Log to MLflow artifacts if available
+         self._log_to_mlflow_artifact(trainer, crash_info)
+
+     @override
+     def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+         """Check for crashes and cleanup."""
+         if trainer.global_rank == 0 and self._queue:
+             # Check for any crash reports
+             while not self._queue.empty():
+                 crash_info = self._queue.get_nowait()
+                 log.warning(f"Detected crash: {crash_info}")
+                 if self._cb_logger:
+                     self._cb_logger.log_tag(self.crash_info_tag, str(crash_info))
+                     self._log_to_mlflow_artifact(trainer, crash_info)
+
+         self._terminate_monitors()
+
+     def _log_to_mlflow_artifact(self, trainer: "L.Trainer", crash_info: dict) -> None:
+         """Log crash info to MLflow artifacts."""
+         try:
+             from lightning.pytorch.loggers import MLFlowLogger
+
+             mlflow_logger = next((logger for logger in trainer.loggers if isinstance(logger, MLFlowLogger)), None)
+             if not mlflow_logger:
+                 return
+
+             # Create filename: rank0-timestamp.txt
+             rank = crash_info.get("rank", 0)
+             timestamp = crash_info.get("timestamp", "unknown")
+             # Convert timestamp to filename-safe format
+             timestamp_safe = timestamp.replace(" ", "_").replace(":", "-")
+             filename = f"rank{rank}-{timestamp_safe}.txt"
+
+             # Create temp file with crash info
+             temp_dir = Path(tempfile.gettempdir())
+             temp_path = temp_dir / filename
+
+             with open(temp_path, "w") as f:
+                 f.write("Crash Information\n")
+                 f.write("=" * 80 + "\n\n")
+                 for key, value in crash_info.items():
+                     f.write(f"{key}: {value}\n")
+
+             # Log as artifact
+             mlflow_logger.experiment.log_artifact(mlflow_logger.run_id, str(temp_path), "crashes")
+             temp_path.unlink()
+             log.info(f"Logged crash info to MLflow artifacts: {filename}")
+         except Exception as e:
+             log.warning(f"Failed to log crash info to MLflow: {e}")
+
+     def _terminate_monitors(self) -> None:
+         """Terminate all monitoring processes."""
+         for process in self._processes:
+             if process.is_alive():
+                 process.kill()
+         self._processes.clear()
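
A minimal wiring sketch for CrashDetector follows. It assumes an MLflow tracking server is available so on_exception() can attach the crash artifact; the experiment name, tracking URI, and custom tag names are placeholders, and the model and datamodule are assumed to be defined elsewhere:

import lightning as L
from lightning.pytorch.loggers import MLFlowLogger

from fkat.pytorch.callbacks.monitoring import CrashDetector

# Tracking URI and experiment name are placeholders for this sketch.
mlf_logger = MLFlowLogger(experiment_name="demo", tracking_uri="http://localhost:5000")

# Custom tags are optional; these override the "error"/"crash_info" defaults.
crash_cb = CrashDetector(error_tag="train_error", crash_info_tag="train_crash_info")

trainer = L.Trainer(logger=mlf_logger, callbacks=[crash_cb])
# If fit() raises, on_exception() tags the run and uploads a rank{N}-{timestamp}.txt
# artifact under "crashes/"; teardown() drains any crash reports from the monitor queue.
# trainer.fit(model, datamodule=dm)  # model/dm are assumed to exist elsewhere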
fkat/pytorch/callbacks/monitoring/dp.py
@@ -0,0 +1,130 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import time
+ from typing import Any, Protocol
+ from typing_extensions import override
+
+ import lightning as L
+ import torch
+ import torch.distributed as dist
+
+ from fkat.pytorch.schedule import Schedule, Never
+ from fkat.pytorch.loggers import LightningLogger
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+ from fkat.utils.logging import rank0_logger
+
+ logger = rank0_logger(__name__)
+
+
+ class DpGroupStrategy(Protocol):
+     """Protocol for getting DP group info for the current rank."""
+
+     def dp_group_info(self) -> tuple[int, int]:
+         """Return (group_id, rank_in_group) for the current rank."""
+         ...
+
+
+ class DistDpGroup(DpGroupStrategy):
+     """Calculates DP group info based on dp_size using distributed rank."""
+
+     def __init__(self, dp_size: int) -> None:
+         self.dp_size = dp_size
+
+     def dp_group_info(self) -> tuple[int, int]:
+         return divmod(dist.get_rank(), self.dp_size)  # type: ignore[possibly-unbound-attribute]
+
+
+ class EnvDpGroup(DpGroupStrategy):
+     """Calculates DP group info based on dp_size using environment variables."""
+
+     def __init__(self, dp_size: int) -> None:
+         self.dp_size = dp_size
+
+     def dp_group_info(self) -> tuple[int, int]:
+         rank = int(os.environ.get("RANK", 0))
+         return divmod(rank, self.dp_size)
+
+
+ class MegatronDpGroup(DpGroupStrategy):
+     """Gets DP group info from Megatron parallel_state."""
+
+     def dp_group_info(self) -> tuple[int, int]:
+         from megatron.core import parallel_state  # type: ignore[import-not-found]
+
+         group = parallel_state.get_data_parallel_group()
+         rank_in_group = dist.get_rank(group)  # type: ignore[possibly-unbound-attribute]
+         # For Megatron, we need to calculate group_id differently
+         # This assumes we can derive it from global rank and group size
+         global_rank = dist.get_rank()  # Get global rank separately  # type: ignore[possibly-unbound-attribute]
+         group_size = dist.get_world_size(group)  # type: ignore[possibly-unbound-attribute]
+         group_id = global_rank // group_size
+         return group_id, rank_in_group
+
+
+ class DpSyncMonitor(L.Callback):
+     """
+     Monitors time for each DP group to reach synchronization point.
+     Measures from batch start to before optimizer step to identify slow/fast groups.
+     """
+
+     def __init__(self, dp_group: DpGroupStrategy, schedule: Schedule | None = None) -> None:
+         """
+         Initialize the DP synchronization monitor.
+
+         Args:
+             dp_group: Strategy for determining DP group info (required).
+             schedule: Controls when logging occurs. Defaults to ``Never``.
+         """
+         self.dp_group = dp_group
+         self.schedule = schedule or Never()
+         self.batch_start_time: float | None = None
+         self._cb_logger: LightningLogger | None = None
+
+     @override
+     def setup(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         stage: str,
+     ) -> None:
+         if not self._cb_logger:
+             self._cb_logger = CallbackLogger(trainer)
+
+     @override
+     def on_train_batch_start(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         batch: Any,
+         batch_idx: int,
+     ) -> None:
+         """Start timing when batch processing begins."""
+         self.batch_start_time = time.perf_counter()
+
+     @override
+     def on_before_optimizer_step(
+         self,
+         trainer: "L.Trainer",
+         pl_module: "L.LightningModule",
+         optimizer: torch.optim.Optimizer,
+     ) -> None:
+         """End timing when ready for sync (before optimizer step) and log if needed."""
+         if self.batch_start_time is not None:
+             sync_time_s = time.perf_counter() - self.batch_start_time
+             # Log immediately since we're at the sync point, before any DP comms
+             self._log_statistics(trainer, "train", 0, sync_time_s)
+             self.batch_start_time = None
+
+     def _log_statistics(self, trainer: "L.Trainer", stage: str, batch_idx: int, sync_time_s: float) -> None:
+         """Log current group timing if schedule permits and this is DP group rank 0."""
+         group_id, rank_in_group = self.dp_group.dp_group_info()
+         if rank_in_group != 0:
+             return
+
+         if not self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             return
+
+         if self._cb_logger:
+             metrics = {f"dp_sync/group{group_id}/sync_s": sync_time_s}
+             self._cb_logger.log_batch(metrics=metrics, timestamp=int(time.time()), step=trainer.global_step)
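
The sketch below shows one plausible way to configure DpSyncMonitor with the DistDpGroup strategy. The dp_size of 8 and the Every(n_batches=100) cadence are illustrative values, and the direct import of DistDpGroup from the dp module is inferred from the package layout rather than documented:

import lightning as L

from fkat.pytorch.callbacks.monitoring import DpSyncMonitor
from fkat.pytorch.callbacks.monitoring.dp import DistDpGroup
from fkat.pytorch.schedule import Every

# With dp_size=8, divmod(global_rank, 8) maps rank 10 to group 1, rank-in-group 2;
# only rank-in-group 0 of each DP group emits dp_sync/group{id}/sync_s.
dp_group = DistDpGroup(dp_size=8)

# The default schedule is Never(), so pass an explicit schedule to actually log.
monitor = DpSyncMonitor(dp_group=dp_group, schedule=Every(n_batches=100))

trainer = L.Trainer(callbacks=[monitor])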
fkat/pytorch/callbacks/monitoring/hardware_stats.py
@@ -0,0 +1,135 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import logging
+ from time import time_ns
+ from typing import Any
+ from typing_extensions import override
+
+ import lightning as L
+ from lightning.pytorch.utilities import rank_zero_only
+ import psutil
+ import torch
+
+ from fkat.pytorch.callbacks.loggers import CallbackLogger
+ from fkat.pytorch.schedule import Schedule, Every
+
+ logger = logging.getLogger(__name__)
+
+
+ class HardwareStats(L.Callback):
+     """Monitor and log hardware usage (CPU, RAM, and GPU) during training.
+
+     Args:
+         accelerator: Hardware accelerator to monitor ("gpu" or "cpu").
+             If None or unsupported, only CPU/RAM are monitored.
+         schedule: Controls when hardware stats are logged. Defaults to every batch.
+     """
+
+     def __init__(self, accelerator: str | None = None, schedule: Schedule | None = None) -> None:
+         if accelerator not in ("gpu", "cpu", None):
+             logger.warning(f"Unsupported accelerator: {accelerator}. Monitoring CPU/RAM only.")
+             accelerator = "cpu"
+         self.accelerator = accelerator
+         self.schedule = schedule or Every(n_batches=1)
+         self._cb_logger: CallbackLogger | None = None
+         self._total_gpu_memory_gb: float | None = None
+
+     @override
+     def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
+         self._cb_logger = CallbackLogger(trainer)
+         if self.accelerator == "gpu":
+             if torch.cuda.is_available():
+                 self._total_gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+                 torch.cuda.reset_peak_memory_stats()
+             else:
+                 logger.warning("GPU accelerator requested but CUDA not available. Monitoring CPU/RAM only.")
+
+     @rank_zero_only
+     @override
+     def on_train_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         if self.accelerator == "gpu" and torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+
+     def _get_stats(self) -> dict[str, float]:
+         stats = {
+             "cpu_usage_percent": psutil.cpu_percent(),
+             "ram_used_percent": psutil.virtual_memory().percent,
+             "ram_used_gb": psutil.virtual_memory().used / 1e9,
+         }
+         if self.accelerator == "gpu" and torch.cuda.is_available():
+             stats.update(
+                 {
+                     "gpu_memory_reserved_gb": torch.cuda.memory_reserved() / 1e9,
+                     "gpu_memory_allocated_gb": torch.cuda.memory_allocated() / 1e9,
+                     "gpu_peak_memory_allocated_gb": torch.cuda.max_memory_allocated() / 1e9,
+                     "gpu_memory_total_per_rank_gb": self._total_gpu_memory_gb or 0.0,
+                 }
+             )
+         return stats
+
+     def _log_stats(self, trainer: L.Trainer) -> None:
+         if self._cb_logger:
+             self._cb_logger.log_batch(metrics=self._get_stats(), step=trainer.global_step, timestamp=int(time_ns()))
+         if self.accelerator == "gpu" and torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+
+     @rank_zero_only
+     @override
+     def on_train_batch_start(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int
+     ) -> None:
+         if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def on_train_batch_end(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, outputs: Any, batch: Any, batch_idx: int
+     ) -> None:
+         if self.schedule.check(stage="train", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def on_validation_batch_start(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         if self.schedule.check(stage="validation", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def on_validation_batch_end(
+         self,
+         trainer: L.Trainer,
+         pl_module: L.LightningModule,
+         outputs: Any,
+         batch: Any,
+         batch_idx: int,
+         dataloader_idx: int = 0,
+     ) -> None:
+         if self.schedule.check(stage="validation", batch_idx=batch_idx, step=trainer.global_step, trainer=trainer):
+             self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
+         self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def on_before_zero_grad(self, trainer: L.Trainer, pl_module: L.LightningModule, optimizer: Any) -> None:
+         if self.schedule.check(
+             stage="train", batch_idx=trainer.fit_loop.batch_idx, step=trainer.global_step, trainer=trainer
+         ):
+             self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def teardown(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
+         self._log_stats(trainer)
+
+     @rank_zero_only
+     @override
+     def on_exception(self, trainer: L.Trainer, pl_module: L.LightningModule, exception: BaseException) -> None:
+         self._log_stats(trainer)
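
A short configuration sketch for HardwareStats, with an assumed Every(n_batches=50) cadence to throttle logging (any Schedule from fkat.pytorch.schedule should work the same way):

import lightning as L

from fkat.pytorch.callbacks.monitoring import HardwareStats
from fkat.pytorch.schedule import Every

# Default is Every(n_batches=1); n_batches=50 is just an illustrative throttle
# on how often CPU/RAM/GPU-memory stats reach the logger.
hw_stats = HardwareStats(accelerator="gpu", schedule=Every(n_batches=50))

trainer = L.Trainer(accelerator="gpu", callbacks=[hw_stats])
# Each logged batch carries cpu_usage_percent, ram_used_gb, and (on CUDA)
# gpu_memory_allocated_gb / gpu_peak_memory_allocated_gb; peak stats are reset
# after every log, so peaks are per logging interval.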