fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/pytorch/callbacks/monitoring/shutdown.py
@@ -0,0 +1,170 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import os
+import time
+import random
+import signal
+import logging
+import ast
+import multiprocessing
+import datetime as dt
+from typing import Any
+from typing_extensions import override
+
+import lightning as L
+from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
+from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
+from lightning.pytorch.utilities import rank_zero_only
+
+from fkat.pytorch.schedule import (
+    Schedule,
+    Elapsed,
+)
+from fkat.pytorch.loggers import LightningLogger
+from fkat.pytorch.callbacks.loggers import CallbackLogger
+
+log = logging.getLogger(__name__)
+
+
+def start_shutdown_detection_process(
+    logger: LightningLogger | None,
+    shutdown_tag: str,
+    trainer: "L.Trainer",
+) -> multiprocessing.Process | None:
+    """
+    Create a process for monitoring trainer errors by periodically detecting
+    the ``shutdown`` tag. Terminate the application if the tag is detected.
+    The process is spawned on local rank 0 to minimize the overhead.
+    """
+    process: multiprocessing.Process | None = None
+    if logger is not None and trainer.local_rank == 0:
+        log.info("Starting a shutdown detection process...")
+        process = multiprocessing.Process(
+            target=detect_shutdown_from_logger, args=(logger, shutdown_tag, os.getpid(), 60)
+        )
+        process.daemon = True
+        process.start()
+    return process
+
+
+def detect_shutdown_from_logger(
+    logger: LightningLogger,
+    shutdown_tag: str,
+    pid: int,
+    detection_interval_secs: int,
+) -> None:
+    """
+    Detect the ``shutdown_tag`` tag periodically. If the tag is found, send SIGABRT to the
+    process with the provided pid to shut down the training process.
+    """
+    sleep_duration = int(os.getenv("SHUTDOWN_DETECTION_INTERVAL", default=str(detection_interval_secs)))
+    log.debug(f"Shutdown detection frequency is {sleep_duration} secs")
+    try:
+        while True:
+            random_delay = random.uniform(0, sleep_duration * 0.5)
+            time.sleep(random_delay)
+            tags = logger.tags()
+            if shutdown_tag in tags:
+                log.info(f"Found {shutdown_tag}={tags[shutdown_tag]} tag. Shutting down process {pid}.")
+                os.kill(pid, signal.SIGABRT)
+            time.sleep(sleep_duration - random_delay)
+    except Exception as e:
+        log.error(f"Got error when querying mlflow SHUTDOWN tag: {e}")
+
+
+class GracefulShutdown(L.Callback):
+    def __init__(
+        self,
+        schedule: Schedule | None = None,
+        shutdown_tag: str = "shutdown",
+        shutdown_info_tag: str = "shutdown_info",
+    ) -> None:
+        self.shutdown_tag = shutdown_tag
+        self.shutdown_info_tag = shutdown_info_tag
+        self.schedule = schedule or Elapsed(dt.timedelta(minutes=5))
+        self._cb_logger: LightningLogger | None = None
+        self._process: multiprocessing.Process | None = None
+
+    @override
+    def setup(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+        self._cb_logger = CallbackLogger(trainer)
+        self._process = start_shutdown_detection_process(self._cb_logger, self.shutdown_tag, trainer)
+
+    def _maybe_stop(self, stage: str, trainer: "L.Trainer", batch_idx: int) -> None:
+        if (
+            trainer.should_stop
+            or not self._cb_logger
+            or not self.schedule.check(stage=stage, batch_idx=batch_idx, step=trainer.global_step, trainer=trainer)
+        ):
+            return
+        tags = self._cb_logger.tags()
+        shutdown_tag = tags.get(self.shutdown_tag)
+        trainer.should_stop = shutdown_tag is not None
+        if trainer.should_stop:
+            info_tag = tags.get(self.shutdown_info_tag)
+            if info_tag:
+                info = ast.literal_eval(info_tag)[-1]
+                strategy = info["Strategy"].upper()
+                log.info(f"Shutdown signal received. Using shutdown strategy {strategy}")
+            self._cb_logger.log_tag(self.shutdown_tag, "SHUTTING_DOWN")
+
+    def _update_shutdown_status(self, trainer: "L.Trainer", status: str) -> None:
+        if self._cb_logger:
+            log.info(f"update shutdown status {status} indicate job finished.")
+            self._cb_logger.log_tag(self.shutdown_tag, status)
+
+    @override
+    @rank_zero_only
+    def on_train_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
+        self._maybe_stop("train", trainer, batch_idx)
+
+    @rank_zero_only
+    def on_test_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._maybe_stop("test", trainer, batch_idx)
+
+    @override
+    @rank_zero_only
+    def on_validation_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._maybe_stop("validation", trainer, batch_idx)
+
+    @rank_zero_only
+    def on_predict_batch_start(
+        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        self._maybe_stop("predict", trainer, batch_idx)
+
+    @override
+    def teardown(self, trainer: "L.Trainer", pl_module: "L.LightningModule", stage: str) -> None:
+        if trainer.global_rank == 0:
+            if not self._tuning(trainer):
+                log.info("update job status before job finished")
+                self._update_shutdown_status(trainer, "JOB_FINISHED")
+        self._terminate_monitor()
+
+    @override
+    def on_exception(self, trainer: "L.Trainer", pl_module: "L.LightningModule", exception: BaseException) -> None:
+        self._terminate_monitor()
+
+    def _terminate_monitor(self) -> None:
+        """
+        Terminates the separate process used for monitoring trainer errors if it is alive.
+        """
+        if self._process and self._process.is_alive():
+            log.info("\nTerminating error monitor...")
+            self._process.kill()
+
+    def _tuning(self, trainer: "L.Trainer") -> bool:
+        num_tuning_cbs = sum(
+            isinstance(
+                cb,
+                LearningRateFinder | BatchSizeFinder,
+            )
+            for cb in trainer.callbacks  # type: ignore[attr-defined]
+        )
+        return num_tuning_cbs > 0
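
For context, GracefulShutdown is an ordinary Lightning callback: an operator sets the "shutdown" tag on the run's logger (e.g. in MLflow), the callback flips trainer.should_stop at the next scheduled check, and a daemon process on each node's local rank 0 polls the same tag and sends SIGABRT as a fallback. A minimal usage sketch follows; the trainer/model wiring is illustrative and not part of this package, and the 10-minute interval assumes Elapsed gates the check by wall-clock time, as its name and the default Elapsed(dt.timedelta(minutes=5)) suggest:

    import datetime as dt
    import lightning as L

    from fkat.pytorch.schedule import Elapsed
    from fkat.pytorch.callbacks.monitoring.shutdown import GracefulShutdown

    # Check the logger's tags at most once per 10-minute window.
    cb = GracefulShutdown(schedule=Elapsed(dt.timedelta(minutes=10)))

    trainer = L.Trainer(max_epochs=1, callbacks=[cb])
    # trainer.fit(model, datamodule=dm)  # model/dm are user-supplied placeholders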
fkat/pytorch/callbacks/profiling/__init__.py
@@ -0,0 +1,13 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from .viztracer import VizTracer
+from .memray import Memray
+from .flops import Flops
+from .torch import PyTorch
+
+__all__ = [
+    "PyTorch",
+    "VizTracer",
+    "Memray",
+    "Flops",
+]
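
This hunk only re-exports the four profiling callbacks so they can be imported from the subpackage directly. Their constructor signatures are not part of this diff, so the sketch shows only the import; presumably each is attached like any other Lightning callback via Trainer(callbacks=[...]):

    from fkat.pytorch.callbacks.profiling import Flops, Memray, PyTorch, VizTracer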