fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/cuda/xid.py
@@ -0,0 +1,173 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations
+ import logging
+ import multiprocessing
+ from typing import Any
+ from typing_extensions import override
+
+ import lightning as L
+ from lightning.pytorch.utilities import rank_zero_only
+
+ from fkat.utils.cuda.xid import detect_xid_errors
+ from fkat.pytorch.actions import LightningAction
+ from fkat.pytorch.schedule import Schedule
+ from fkat.pytorch.utilities import local_rank_zero_only
+
+ log = logging.getLogger(__name__)
+
+
+ class Xid(L.Callback):
+     """
+     A callback to monitor and log Xid errors in a separate process during training.
+
+     It uses a separate process to monitor these errors, ensuring that the main training process remains unaffected.
+     The monitoring process is started at the beginning of training and terminated either
+     upon an exception in training or at the end of the training/validation stage.
+     """
+
+     monitor: multiprocessing.Process | None = None
+
+     def __init__(self, actions: dict[str, LightningAction], schedule: Schedule) -> None:
+         """
+         Arguments:
+             actions: Dictionary mapping Xid ranges to actions
+                 Format: {
+                     "0-100": fkat.actions.log,
+                     "13,43,63-64,48,79,95": fkat.actions.ec2.reboot,
+                     "81": fkat.actions.ec2.terminate,
+                 }
+         """
+         super().__init__()
+         self.actions = self._parse_xid_ranges(actions)
+         self.schedule = schedule
+         self.xid_errors: multiprocessing.Queue[set[int]] = multiprocessing.Queue()
+         self.xid_check: multiprocessing.Event = multiprocessing.Event()  # type: ignore[attr-defined]
+
+     def _parse_xid_ranges(self, xid_actions: dict[str, LightningAction]) -> dict[int, LightningAction]:
+         actions = {}
+         for xid_range, action in xid_actions.items():
+             parts = xid_range.split(",")
+             for part in parts:
+                 part = part.strip()
+                 is_range = "-" in part
+                 if is_range:
+                     start, end = map(int, part.split("-"))
+                     for xid in range(start, end + 1):
+                         actions[xid] = action
+                 else:
+                     try:
+                         actions[int(part)] = action
+                     except ValueError:
+                         print(f"Warning: Invalid XID format: {part}")
+         return actions
+
+     @override
+     @local_rank_zero_only
+     def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
+         """
+         Initializes the Xid error monitoring process at the start of the training stage.
+
+         This method is automatically invoked by the PyTorch Lightning framework. It starts
+         a separate background process dedicated to monitoring Xid errors.
+
+         Args:
+             trainer (L.Trainer): The PyTorch Lightning Trainer instance.
+             pl_module (L.LightningModule): The PyTorch Lightning module being trained.
+             stage (str): The stage of the training process (e.g., 'fit', 'test').
+
+         Returns:
+             None.
+         """
+         log.info("Checking for Xid errors ...")
+         self.monitor = multiprocessing.Process(target=detect_xid_errors, args=(self.xid_check, self.xid_errors))
+         self.monitor.start()
+
+     def _terminate_monitor(self) -> None:
+         """
+         Terminates the separate process used for monitoring Xid errors if it is alive.
+
+         This is an internal method that checks if the monitoring process is active and,
+         if so, terminates it to clean up resources.
+
+         Returns:
+             None.
+         """
+         if self.monitor and self.monitor.is_alive():
+             log.info("Terminating Xid errors monitor")
+             self.monitor.kill()
+
+     @override
+     @local_rank_zero_only
+     def on_exception(self, trainer: L.Trainer, pl_module: L.LightningModule, exception: BaseException) -> None:
+         """
+         Callback method to handle exceptions during training.
+
+         If an exception occurs during the training process, this method ensures that the
+         Xid error monitoring process is terminated to prevent resource leakage.
+
+         Args:
+             trainer (L.Trainer): The PyTorch Lightning Trainer instance.
+             pl_module (L.LightningModule): The PyTorch Lightning module being trained.
+             exception (BaseException): The exception that occurred during training.
+
+         Returns:
+             None.
+         """
+         self._terminate_monitor()
+
+     @override
+     @local_rank_zero_only
+     def teardown(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
+         """
+         Ensures the Xid error monitoring process is terminated at the end of training.
+
+         This method is automatically called by the PyTorch Lightning framework at the
+         end of the training or validation stage to clean up the monitoring process.
+
+         Args:
+             trainer (L.Trainer): The PyTorch Lightning Trainer instance.
+             pl_module (L.LightningModule): The PyTorch Lightning module being trained.
+             stage (str): The stage of the training process (e.g., 'fit', 'test').
+
+         Returns:
+             None.
+         """
+         self._terminate_monitor()
+
+     @override
+     @rank_zero_only
+     def on_train_batch_start(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int
+     ) -> None:
+         self.check(trainer, "train", batch_idx)
+
+     @override
+     @rank_zero_only
+     def on_test_batch_start(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         self.check(trainer, "test", batch_idx)
+
+     @override
+     @rank_zero_only
+     def on_validation_batch_start(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         self.check(trainer, "validation", batch_idx)
+
+     @override
+     @rank_zero_only
+     def on_predict_batch_start(
+         self, trainer: L.Trainer, pl_module: L.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
+     ) -> None:
+         self.check(trainer, "predict", batch_idx)
+
+     def check(self, trainer: L.Trainer, stage: str, batch_idx: int) -> None:
+         if self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step):
+             self.xid_check.set()
+             while not self.xid_errors.empty():
+                 xids = self.xid_errors.get()
+                 for xid in xids:
+                     if action := self.actions.get(xid):
+                         action.perform(trainer=trainer, xid=xid)
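
The constructor docstring above maps comma-separated Xid values and dash-delimited ranges to actions. Below is a minimal wiring sketch; it assumes Xid is re-exported from fkat.pytorch.callbacks.cuda (that package's __init__.py is listed above but not shown), that the helpers referenced in the docstring (fkat.actions.log, fkat.actions.ec2.reboot, fkat.actions.ec2.terminate) resolve to LightningAction instances, and that Schedule accepts an every_n_batches-style argument; none of these details is confirmed by this diff.

import lightning as L

from fkat.pytorch.callbacks.cuda import Xid  # assumed re-export
from fkat.pytorch.schedule import Schedule   # constructor argument below is assumed
import fkat.actions as actions               # hypothetical action helpers named in the docstring

# Keys follow the docstring format: comma-separated Xids and dash-delimited ranges.
xid_callback = Xid(
    actions={
        "0-100": actions.log,                        # log any Xid between 0 and 100
        "13,43,63-64,48,79,95": actions.ec2.reboot,  # reboot the EC2 instance on these Xids
        "81": actions.ec2.terminate,                 # terminate the instance on Xid 81
    },
    schedule=Schedule(every_n_batches=100),          # assumed: poll the monitor every 100 batches
)

trainer = L.Trainer(callbacks=[xid_callback])

_parse_xid_ranges expands each key at construction time, so a range such as "63-64" becomes two dictionary entries (63 and 64) sharing the same action, and check() later looks up detected Xids in that flat mapping.
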
fkat/pytorch/callbacks/debugging/__init__.py
@@ -0,0 +1,9 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from .introspection import Introspection
+ from .optimizer import OptimizerSnapshot
+
+ __all__ = [
+     "Introspection",
+     "OptimizerSnapshot",
+ ]
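
This second file only re-exports the two debugging callbacks. Assuming both can be constructed with default arguments (their real signatures live in introspection.py and optimizer.py, which are not shown in this diff), they plug into a Trainer the same way as any other Lightning callback:

import lightning as L

from fkat.pytorch.callbacks.debugging import Introspection, OptimizerSnapshot

# Assumed: default construction works; the actual parameters are defined in the
# modules listed above (introspection.py +569, optimizer.py +45) but not shown here.
trainer = L.Trainer(callbacks=[Introspection(), OptimizerSnapshot()])
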